1use crate::modular::bingcd::extension::ExtendedInt;
2use crate::{Choice, Uint};
3use ctutils::CtEq;
4
/// Types that provide a distinguished constant value called `UNIT`.
///
/// The matrix types in this file implement it with ones on the main diagonal and zeros
/// elsewhere.
pub trait Unit: Sized {
    /// The distinguished "unit" value of the implementing type.
    const UNIT: Self;
}
9
/// A pair of values of the same type, used as the two-component vector a matrix is applied to.
type Vector<T> = (T, T);
11
12type ExtraLimbInt<const LIMBS: usize> = ExtendedInt<LIMBS, 1>;
14
15#[derive(Debug, Clone, Copy, PartialEq)]
24pub(crate) struct IntMatrix<const LIMBS: usize> {
25 m00: ExtraLimbInt<LIMBS>,
26 m01: ExtraLimbInt<LIMBS>,
27 m10: ExtraLimbInt<LIMBS>,
28 m11: ExtraLimbInt<LIMBS>,
29}
30
31impl<const LIMBS: usize> Unit for IntMatrix<LIMBS> {
32 const UNIT: Self = Self {
33 m00: ExtraLimbInt::ONE,
34 m01: ExtraLimbInt::ZERO,
35 m10: ExtraLimbInt::ZERO,
36 m11: ExtraLimbInt::ONE,
37 };
38}
39
40impl<const LIMBS: usize> IntMatrix<LIMBS> {
41 pub(super) const fn conditional_negate_top_row(&mut self, negate: Choice) {
43 self.m00 = self.m00.wrapping_neg_if(negate);
44 self.m01 = self.m01.wrapping_neg_if(negate);
45 }
46
47 pub(super) const fn conditional_negate_bottom_row(&mut self, negate: Choice) {
49 self.m10 = self.m10.wrapping_neg_if(negate);
50 self.m11 = self.m11.wrapping_neg_if(negate);
51 }
52
53 pub(super) const fn to_pattern_matrix(self) -> PatternMatrix<LIMBS> {
54 let (abs_m00, m00_is_negative) = self.m00.abs_sign();
55 let (abs_m01, m01_is_negative) = self.m01.abs_sign();
56 let (abs_m10, m10_is_negative) = self.m10.abs_sign();
57 let (abs_m11, m11_is_negative) = self.m11.abs_sign();
58
59 let m00_is_zero = abs_m00.is_zero();
61 let m01_is_zero = abs_m01.is_zero();
62 let pattern_vote_1 = m00_is_zero.not().and(m00_is_negative.not());
63 let pattern_vote_2 = m01_is_zero.not().and(m01_is_negative);
64
65 let m00_and_m01_are_zero = m00_is_zero.and(m01_is_zero);
66 let m10_is_zero = abs_m10.is_zero();
67 let m11_is_zero = abs_m11.is_zero();
68 let pattern_vote_3 = m00_and_m01_are_zero.and(m10_is_zero.not().and(m10_is_negative));
69 let pattern_vote_4 = m00_and_m01_are_zero.and(m11_is_zero.not().and(m11_is_negative.not()));
70 let pattern = pattern_vote_1
71 .or(pattern_vote_2)
72 .or(pattern_vote_3)
73 .or(pattern_vote_4);
74
75 PatternMatrix {
76 m00: abs_m00.checked_drop_extension().expect_copied("m00 fits"),
77 m01: abs_m01.checked_drop_extension().expect_copied("m01 fits"),
78 m10: abs_m10.checked_drop_extension().expect_copied("m10 fits"),
79 m11: abs_m11.checked_drop_extension().expect_copied("m11 fits"),
80 pattern,
81 }
82 }
83}
84
85#[derive(Debug, Clone, Copy)]
95pub(crate) struct PatternMatrix<const LIMBS: usize> {
96 pub m00: Uint<LIMBS>,
97 pub m01: Uint<LIMBS>,
98 pub m10: Uint<LIMBS>,
99 pub m11: Uint<LIMBS>,
100 pub pattern: Choice,
101}
102
103impl<const LIMBS: usize> PatternMatrix<LIMBS> {
104 pub const UNIT: Self = Self {
105 m00: Uint::ONE,
106 m01: Uint::ZERO,
107 m10: Uint::ZERO,
108 m11: Uint::ONE,
109 pattern: Choice::TRUE,
110 };
111
112 #[inline]
115 pub(crate) const fn extended_apply_to<const VEC_LIMBS: usize>(
116 &self,
117 vec: Vector<Uint<VEC_LIMBS>>,
118 ) -> Vector<ExtendedInt<VEC_LIMBS, LIMBS>> {
119 let (a, b) = vec;
120 let m00a = ExtendedInt::from_product(a, self.m00);
121 let m10a = ExtendedInt::from_product(a, self.m10);
122 let m01b = ExtendedInt::from_product(b, self.m01);
123 let m11b = ExtendedInt::from_product(b, self.m11);
124 (
125 m00a.wrapping_sub(&m01b).wrapping_neg_if(self.pattern.not()),
126 m11b.wrapping_sub(&m10a).wrapping_neg_if(self.pattern.not()),
127 )
128 }
129
130 #[inline]
132 pub(super) const fn mul_int_matrix<const RHS_LIMBS: usize>(
133 &self,
134 rhs: &IntMatrix<RHS_LIMBS>,
135 ) -> IntMatrix<RHS_LIMBS> {
136 let a0 = rhs.m00.wrapping_mul((&self.m00, &self.pattern.not()));
137 let a1 = rhs.m10.wrapping_mul((&self.m01, &self.pattern));
138 let m00 = a0.wrapping_add(&a1);
139
140 let b0 = rhs.m01.wrapping_mul((&self.m00, &self.pattern.not()));
141 let b1 = rhs.m11.wrapping_mul((&self.m01, &self.pattern));
142 let m01 = b0.wrapping_add(&b1);
143
144 let c0 = rhs.m00.wrapping_mul((&self.m10, &self.pattern));
145 let c1 = rhs.m10.wrapping_mul((&self.m11, &self.pattern.not()));
146 let m10 = c0.wrapping_add(&c1);
147
148 let d0 = rhs.m01.wrapping_mul((&self.m10, &self.pattern));
149 let d1 = rhs.m11.wrapping_mul((&self.m11, &self.pattern.not()));
150 let m11 = d0.wrapping_add(&d1);
151
152 IntMatrix { m00, m01, m10, m11 }
153 }
154
155 #[inline]
157 pub(crate) const fn conditional_swap_rows(&mut self, swap: Choice) {
158 Uint::conditional_swap(&mut self.m00, &mut self.m10, swap);
159 Uint::conditional_swap(&mut self.m01, &mut self.m11, swap);
160 self.pattern = self.pattern.xor(swap);
161 }
162
163 #[inline]
165 pub(crate) const fn conditional_subtract_bottom_row_from_top(&mut self, subtract: Choice) {
166 self.m00 = Uint::select(&self.m00, &self.m00.wrapping_add(&self.m10), subtract);
169 self.m01 = Uint::select(&self.m01, &self.m01.wrapping_add(&self.m11), subtract);
170 }
171
172 #[inline]
174 pub(crate) const fn conditional_subtract_right_column_from_left(&mut self, subtract: Choice) {
175 self.m00 = Uint::select(&self.m00, &self.m00.wrapping_add(&self.m01), subtract);
178 self.m10 = Uint::select(&self.m10, &self.m10.wrapping_add(&self.m11), subtract);
179 }
180
181 #[inline]
183 pub(crate) const fn conditional_add_right_column_to_left(&mut self, add: Choice) {
184 self.m00 = Uint::select(&self.m00, &self.m01.wrapping_sub(&self.m00), add);
187 self.m10 = Uint::select(&self.m10, &self.m11.wrapping_sub(&self.m10), add);
188 }
189
190 #[inline]
192 pub(crate) const fn conditional_double_bottom_row(&mut self, double: Choice) {
193 self.m10 = Uint::select(&self.m10, &self.m10.shl1(), double);
194 self.m11 = Uint::select(&self.m11, &self.m11.shl1(), double);
195 }
196
197 #[inline]
199 pub(crate) const fn conditional_negate(&mut self, negate: Choice) {
200 self.pattern = self.pattern.xor(negate);
201 }
202}
203
204impl<const LIMBS: usize> PartialEq for PatternMatrix<LIMBS> {
205 fn eq(&self, other: &Self) -> bool {
206 (self.m00.ct_eq(&other.m00)
207 & self.m01.ct_eq(&other.m01)
208 & self.m10.ct_eq(&other.m10)
209 & self.m11.ct_eq(&other.m11)
210 & self.pattern.ct_eq(&other.pattern))
211 .into()
212 }
213}
214
215#[derive(Debug, Clone, Copy, PartialEq)]
220pub(crate) struct DividedMatrix<const LIMBS: usize, MATRIX: Unit> {
221 pub(super) inner: MATRIX,
222 pub k: u32,
223 pub k_upper_bound: u32,
224}
225
226impl<const LIMBS: usize, Matrix: Unit> Unit for DividedMatrix<LIMBS, Matrix> {
227 const UNIT: Self = Self {
228 inner: Matrix::UNIT,
229 k: 0,
230 k_upper_bound: 0,
231 };
232}
233
234#[derive(Debug, Clone, Copy, PartialEq)]
247pub(crate) struct DividedPatternMatrix<const LIMBS: usize> {
248 pub(super) inner: PatternMatrix<LIMBS>,
249 pub k: u32,
250 pub k_upper_bound: u32,
251}
252
253impl<const LIMBS: usize> DividedPatternMatrix<LIMBS> {
254 pub const UNIT: Self = Self {
256 inner: PatternMatrix::UNIT,
257 k: 0,
258 k_upper_bound: 0,
259 };
260
261 #[inline]
264 pub const fn extended_apply_to<const VEC_LIMBS: usize, const UPPER_BOUND: u32>(
265 &self,
266 vec: Vector<Uint<VEC_LIMBS>>,
267 ) -> Vector<ExtendedInt<VEC_LIMBS, LIMBS>> {
268 let (a, b) = self.inner.extended_apply_to(vec);
269 (
270 a.bounded_div_2k::<UPPER_BOUND>(self.k),
271 b.bounded_div_2k::<UPPER_BOUND>(self.k),
272 )
273 }
274
275 #[inline]
278 pub const fn extended_apply_to_vartime<const VEC_LIMBS: usize>(
279 &self,
280 vec: Vector<Uint<VEC_LIMBS>>,
281 ) -> Vector<ExtendedInt<VEC_LIMBS, LIMBS>> {
282 let (a, b) = self.inner.extended_apply_to(vec);
283 (a.div_2k_vartime(self.k), b.div_2k_vartime(self.k))
284 }
285
286 #[inline]
288 pub const fn mul_int_matrix<const RHS_LIMBS: usize>(
289 &self,
290 rhs: &DividedIntMatrix<RHS_LIMBS>,
291 ) -> DividedIntMatrix<RHS_LIMBS> {
292 DividedIntMatrix {
293 inner: self.inner.mul_int_matrix(&rhs.inner),
294 k: self.k + rhs.k,
295 k_upper_bound: self.k_upper_bound + rhs.k_upper_bound,
296 }
297 }
298
299 #[inline]
301 pub const fn conditional_swap_rows(&mut self, swap: Choice) {
302 self.inner.conditional_swap_rows(swap);
303 }
304
305 #[inline]
307 pub const fn swap_rows(&mut self) {
308 self.conditional_swap_rows(Choice::TRUE);
309 }
310
311 #[inline]
313 pub const fn conditional_subtract_bottom_row_from_top(&mut self, subtract: Choice) {
314 self.inner
315 .conditional_subtract_bottom_row_from_top(subtract);
316 }
317
318 #[inline]
320 pub const fn conditional_double_bottom_row(&mut self, double: Choice) {
321 self.inner.conditional_double_bottom_row(double);
322 self.k = double.select_u32(self.k, self.k + 1);
323 self.k_upper_bound += 1;
324 }
325}
326
327pub(crate) type DividedIntMatrix<const LIMBS: usize> = DividedMatrix<LIMBS, IntMatrix<LIMBS>>;
328
329impl<const LIMBS: usize> DividedIntMatrix<LIMBS> {
330 pub(super) const fn conditional_negate_top_row(&mut self, negate: Choice) {
332 self.inner.conditional_negate_top_row(negate);
333 }
334
335 pub(super) const fn conditional_negate_bottom_row(&mut self, negate: Choice) {
337 self.inner.conditional_negate_bottom_row(negate);
338 }
339
340 pub(super) const fn to_divided_pattern_matrix(self) -> DividedPatternMatrix<LIMBS> {
341 DividedPatternMatrix {
342 inner: self.inner.to_pattern_matrix(),
343 k: self.k,
344 k_upper_bound: self.k_upper_bound,
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
    use crate::modular::bingcd::matrix::{DividedPatternMatrix, PatternMatrix};
    use crate::{Choice, U64, U256, Uint};

    impl<const LIMBS: usize> PatternMatrix<LIMBS> {
        /// Test helper: builds a matrix from `(m00, m01, m10, m11)` magnitudes and a pattern.
        pub(crate) const fn new_u64(matrix: (u64, u64, u64, u64), pattern: Choice) -> Self {
            let (m00, m01, m10, m11) = matrix;
            Self {
                m00: Uint::from_u64(m00),
                m01: Uint::from_u64(m01),
                m10: Uint::from_u64(m10),
                m11: Uint::from_u64(m11),
                pattern,
            }
        }
    }

    impl<const LIMBS: usize> DividedPatternMatrix<LIMBS> {
        /// Test helper: like [`PatternMatrix::new_u64`], with explicit `k` and `k_upper_bound`.
        pub(crate) const fn new_u64(
            matrix: (u64, u64, u64, u64),
            pattern: Choice,
            k: u32,
            k_upper_bound: u32,
        ) -> Self {
            let inner = PatternMatrix::new_u64(matrix, pattern);
            Self {
                inner,
                k,
                k_upper_bound,
            }
        }
    }

    /// Shared fixture used by most of the tests below.
    const X: DividedPatternMatrix<{ U256::LIMBS }> =
        DividedPatternMatrix::new_u64((1u64, 7u64, 23u64, 53u64), Choice::TRUE, 6, 8);

    #[test]
    fn test_wrapping_apply_to() {
        let lhs = U64::from_be_hex("CA048AFA63CD6A1F");
        let rhs = U64::from_be_hex("AE693BF7BE8E5566");
        let matrix = DividedPatternMatrix::<{ U64::LIMBS }>::new_u64(
            (288, 208, 310, 679),
            Choice::TRUE,
            17,
            17,
        );

        let (top, bottom) = matrix.extended_apply_to::<{ U64::LIMBS }, 18>((lhs, rhs));
        assert_eq!(
            top.dropped_abs_sign().0,
            Uint::from_be_hex("002AC7CDD032B9B9")
        );
        assert_eq!(
            bottom.dropped_abs_sign().0,
            Uint::from_be_hex("006CFBCEE172C863")
        );
    }

    #[test]
    fn test_swap() {
        let mut m = X;
        m.swap_rows();

        // Rows are exchanged and the pattern is toggled; the counters stay put.
        let expected = DividedPatternMatrix::new_u64((23, 53, 1, 7), Choice::FALSE, 6, 8);
        assert_eq!(m, expected);
    }

    #[test]
    fn test_conditional_swap() {
        let mut m = X;

        m.conditional_swap_rows(Choice::FALSE);
        assert_eq!(m, X);

        m.conditional_swap_rows(Choice::TRUE);
        let expected = DividedPatternMatrix::new_u64((23, 53, 1, 7), Choice::FALSE, 6, 8);
        assert_eq!(m, expected);
    }

    #[test]
    fn test_conditional_subtract_bottom_row_from_top() {
        let mut m = X;

        m.conditional_subtract_bottom_row_from_top(Choice::FALSE);
        assert_eq!(m, X);

        m.conditional_subtract_bottom_row_from_top(Choice::TRUE);
        let expected =
            DividedPatternMatrix::new_u64((24u64, 60u64, 23u64, 53u64), Choice::TRUE, 6, 8);
        assert_eq!(m, expected);
    }

    #[test]
    fn test_conditional_double() {
        let mut m = X;

        // Not doubling still bumps `k_upper_bound`.
        m.conditional_double_bottom_row(Choice::FALSE);
        let expected =
            DividedPatternMatrix::new_u64((1u64, 7u64, 23u64, 53u64), Choice::TRUE, 6, 9);
        assert_eq!(m, expected);

        // Doubling bumps both counters.
        m.conditional_double_bottom_row(Choice::TRUE);
        let expected =
            DividedPatternMatrix::new_u64((1u64, 7u64, 46u64, 106u64), Choice::TRUE, 7, 10);
        assert_eq!(m, expected);
    }

    #[test]
    fn test_conditional_add_right_column_to_left() {
        let mut m = X.inner;

        m.conditional_add_right_column_to_left(Choice::FALSE);
        assert_eq!(m, X.inner);

        m.conditional_add_right_column_to_left(Choice::TRUE);
        let expected = PatternMatrix::new_u64((6u64, 7u64, 30u64, 53u64), Choice::TRUE);
        assert_eq!(m, expected);
    }

    #[test]
    fn test_conditional_subtract_right_column_from_left() {
        let mut m = X.inner;

        m.conditional_subtract_right_column_from_left(Choice::FALSE);
        assert_eq!(m, X.inner);

        m.conditional_subtract_right_column_from_left(Choice::TRUE);
        let expected = PatternMatrix::new_u64((8u64, 7u64, 76u64, 53u64), Choice::TRUE);
        assert_eq!(m, expected);
    }
}