Skip to main content

crypto_bigint/modular/bingcd/
matrix.rs

1use crate::modular::bingcd::extension::ExtendedInt;
2use crate::{Choice, Uint};
3use ctutils::CtEq;
4
/// Matrix-like types that expose their unit (identity) element as an associated constant.
pub trait Unit: Sized {
    /// The unit element of `Self`; for the matrix types in this module, the unit matrix.
    const UNIT: Self;
}
9
10type Vector<T> = (T, T);
11
12/// [`Int`] with an extra limb.
13type ExtraLimbInt<const LIMBS: usize> = ExtendedInt<LIMBS, 1>;
14
15/// 2x2-Matrix with integers for elements.
16///
17/// ### Representation
18/// The internal state represents the matrix
19/// ```text
20/// [ m00  m01 ]
21/// [ m10  m11 ]
22/// ```
23#[derive(Debug, Clone, Copy, PartialEq)]
24pub(crate) struct IntMatrix<const LIMBS: usize> {
25    m00: ExtraLimbInt<LIMBS>,
26    m01: ExtraLimbInt<LIMBS>,
27    m10: ExtraLimbInt<LIMBS>,
28    m11: ExtraLimbInt<LIMBS>,
29}
30
31impl<const LIMBS: usize> Unit for IntMatrix<LIMBS> {
32    const UNIT: Self = Self {
33        m00: ExtraLimbInt::ONE,
34        m01: ExtraLimbInt::ZERO,
35        m10: ExtraLimbInt::ZERO,
36        m11: ExtraLimbInt::ONE,
37    };
38}
39
40impl<const LIMBS: usize> IntMatrix<LIMBS> {
41    /// Negate the top row of this matrix if `negate`; otherwise do nothing.
42    pub(super) const fn conditional_negate_top_row(&mut self, negate: Choice) {
43        self.m00 = self.m00.wrapping_neg_if(negate);
44        self.m01 = self.m01.wrapping_neg_if(negate);
45    }
46
47    /// Negate the bottom row of this matrix if `negate`; otherwise do nothing.
48    pub(super) const fn conditional_negate_bottom_row(&mut self, negate: Choice) {
49        self.m10 = self.m10.wrapping_neg_if(negate);
50        self.m11 = self.m11.wrapping_neg_if(negate);
51    }
52
53    pub(super) const fn to_pattern_matrix(self) -> PatternMatrix<LIMBS> {
54        let (abs_m00, m00_is_negative) = self.m00.abs_sign();
55        let (abs_m01, m01_is_negative) = self.m01.abs_sign();
56        let (abs_m10, m10_is_negative) = self.m10.abs_sign();
57        let (abs_m11, m11_is_negative) = self.m11.abs_sign();
58
59        // Construct the pattern.
60        let m00_is_zero = abs_m00.is_zero();
61        let m01_is_zero = abs_m01.is_zero();
62        let pattern_vote_1 = m00_is_zero.not().and(m00_is_negative.not());
63        let pattern_vote_2 = m01_is_zero.not().and(m01_is_negative);
64
65        let m00_and_m01_are_zero = m00_is_zero.and(m01_is_zero);
66        let m10_is_zero = abs_m10.is_zero();
67        let m11_is_zero = abs_m11.is_zero();
68        let pattern_vote_3 = m00_and_m01_are_zero.and(m10_is_zero.not().and(m10_is_negative));
69        let pattern_vote_4 = m00_and_m01_are_zero.and(m11_is_zero.not().and(m11_is_negative.not()));
70        let pattern = pattern_vote_1
71            .or(pattern_vote_2)
72            .or(pattern_vote_3)
73            .or(pattern_vote_4);
74
75        PatternMatrix {
76            m00: abs_m00.checked_drop_extension().expect_copied("m00 fits"),
77            m01: abs_m01.checked_drop_extension().expect_copied("m01 fits"),
78            m10: abs_m10.checked_drop_extension().expect_copied("m10 fits"),
79            m11: abs_m11.checked_drop_extension().expect_copied("m11 fits"),
80            pattern,
81        }
82    }
83}
84
85/// 2x2-Matrix where either the diagonal or off-diagonal elements are negative.
86///
87/// The internal state represents the matrix
88/// ```text
89///      true                 false
90/// [  m00 -m01 ]         [ -m00  m01 ]
91/// [ -m10  m11 ]   or    [  m10 -m11 ]
92/// ```
93/// depending on whether `pattern` is respectively truthy or not.
94#[derive(Debug, Clone, Copy)]
95pub(crate) struct PatternMatrix<const LIMBS: usize> {
96    pub m00: Uint<LIMBS>,
97    pub m01: Uint<LIMBS>,
98    pub m10: Uint<LIMBS>,
99    pub m11: Uint<LIMBS>,
100    pub pattern: Choice,
101}
102
103impl<const LIMBS: usize> PatternMatrix<LIMBS> {
104    pub const UNIT: Self = Self {
105        m00: Uint::ONE,
106        m01: Uint::ZERO,
107        m10: Uint::ZERO,
108        m11: Uint::ONE,
109        pattern: Choice::TRUE,
110    };
111
112    /// Apply this matrix to a vector of [Uint]s, returning the result as a vector of
113    /// [`ExtendedInt`]s.
114    #[inline]
115    pub(crate) const fn extended_apply_to<const VEC_LIMBS: usize>(
116        &self,
117        vec: Vector<Uint<VEC_LIMBS>>,
118    ) -> Vector<ExtendedInt<VEC_LIMBS, LIMBS>> {
119        let (a, b) = vec;
120        let m00a = ExtendedInt::from_product(a, self.m00);
121        let m10a = ExtendedInt::from_product(a, self.m10);
122        let m01b = ExtendedInt::from_product(b, self.m01);
123        let m11b = ExtendedInt::from_product(b, self.m11);
124        (
125            m00a.wrapping_sub(&m01b).wrapping_neg_if(self.pattern.not()),
126            m11b.wrapping_sub(&m10a).wrapping_neg_if(self.pattern.not()),
127        )
128    }
129
130    /// Wrapping apply this matrix to `rhs`. Return the result in a [`IntMatrix<RHS_LIMBS>`].
131    #[inline]
132    pub(super) const fn mul_int_matrix<const RHS_LIMBS: usize>(
133        &self,
134        rhs: &IntMatrix<RHS_LIMBS>,
135    ) -> IntMatrix<RHS_LIMBS> {
136        let a0 = rhs.m00.wrapping_mul((&self.m00, &self.pattern.not()));
137        let a1 = rhs.m10.wrapping_mul((&self.m01, &self.pattern));
138        let m00 = a0.wrapping_add(&a1);
139
140        let b0 = rhs.m01.wrapping_mul((&self.m00, &self.pattern.not()));
141        let b1 = rhs.m11.wrapping_mul((&self.m01, &self.pattern));
142        let m01 = b0.wrapping_add(&b1);
143
144        let c0 = rhs.m00.wrapping_mul((&self.m10, &self.pattern));
145        let c1 = rhs.m10.wrapping_mul((&self.m11, &self.pattern.not()));
146        let m10 = c0.wrapping_add(&c1);
147
148        let d0 = rhs.m01.wrapping_mul((&self.m10, &self.pattern));
149        let d1 = rhs.m11.wrapping_mul((&self.m11, &self.pattern.not()));
150        let m11 = d0.wrapping_add(&d1);
151
152        IntMatrix { m00, m01, m10, m11 }
153    }
154
155    /// Swap the rows of this matrix if `swap` is truthy. Otherwise, do nothing.
156    #[inline]
157    pub(crate) const fn conditional_swap_rows(&mut self, swap: Choice) {
158        Uint::conditional_swap(&mut self.m00, &mut self.m10, swap);
159        Uint::conditional_swap(&mut self.m01, &mut self.m11, swap);
160        self.pattern = self.pattern.xor(swap);
161    }
162
163    /// Subtract the bottom row from the top if `subtract` is truthy. Otherwise, do nothing.
164    #[inline]
165    pub(crate) const fn conditional_subtract_bottom_row_from_top(&mut self, subtract: Choice) {
166        // Note: because the signs of the internal representation are stored in `pattern`,
167        // subtracting one row from another involves _adding_ these rows instead.
168        self.m00 = Uint::select(&self.m00, &self.m00.wrapping_add(&self.m10), subtract);
169        self.m01 = Uint::select(&self.m01, &self.m01.wrapping_add(&self.m11), subtract);
170    }
171
172    /// Subtract the right column from the left if `subtract` is truthy. Otherwise, do nothing.
173    #[inline]
174    pub(crate) const fn conditional_subtract_right_column_from_left(&mut self, subtract: Choice) {
175        // Note: because the signs of the internal representation are stored in `pattern`,
176        // subtracting one column from another involves _adding_ these columns instead.
177        self.m00 = Uint::select(&self.m00, &self.m00.wrapping_add(&self.m01), subtract);
178        self.m10 = Uint::select(&self.m10, &self.m10.wrapping_add(&self.m11), subtract);
179    }
180
181    /// If `add` is truthy, add the right column to the left. Otherwise, do nothing.
182    #[inline]
183    pub(crate) const fn conditional_add_right_column_to_left(&mut self, add: Choice) {
184        // Note: because the signs of the internal representation are stored in `pattern`,
185        // subtracting one column from another involves _adding_ these columns instead.
186        self.m00 = Uint::select(&self.m00, &self.m01.wrapping_sub(&self.m00), add);
187        self.m10 = Uint::select(&self.m10, &self.m11.wrapping_sub(&self.m10), add);
188    }
189
190    /// Double the bottom row of this matrix if `double` is truthy. Otherwise, do nothing.
191    #[inline]
192    pub(crate) const fn conditional_double_bottom_row(&mut self, double: Choice) {
193        self.m10 = Uint::select(&self.m10, &self.m10.shl1(), double);
194        self.m11 = Uint::select(&self.m11, &self.m11.shl1(), double);
195    }
196
197    /// Negate the elements in this matrix if `negate` is truthy. Otherwise, do nothing.
198    #[inline]
199    pub(crate) const fn conditional_negate(&mut self, negate: Choice) {
200        self.pattern = self.pattern.xor(negate);
201    }
202}
203
204impl<const LIMBS: usize> PartialEq for PatternMatrix<LIMBS> {
205    fn eq(&self, other: &Self) -> bool {
206        (self.m00.ct_eq(&other.m00)
207            & self.m01.ct_eq(&other.m01)
208            & self.m10.ct_eq(&other.m10)
209            & self.m11.ct_eq(&other.m11)
210            & self.pattern.ct_eq(&other.pattern))
211        .into()
212    }
213}
214
215/// A matrix whose elements still need to be divided by `2^k`.
216///
217/// Since some of the operations conditionally increase `k`, this struct furthermore keeps track of
218/// `k_upper_bound`; an upper bound on the value of `k`.
219#[derive(Debug, Clone, Copy, PartialEq)]
220pub(crate) struct DividedMatrix<const LIMBS: usize, MATRIX: Unit> {
221    pub(super) inner: MATRIX,
222    pub k: u32,
223    pub k_upper_bound: u32,
224}
225
226impl<const LIMBS: usize, Matrix: Unit> Unit for DividedMatrix<LIMBS, Matrix> {
227    const UNIT: Self = Self {
228        inner: Matrix::UNIT,
229        k: 0,
230        k_upper_bound: 0,
231    };
232}
233
234/// Variation on [`PatternMatrix`], where the contents of the matrix need to be divided by
235/// `2^k`.
236/// The internal state represents the matrix
237/// ```text
238///      true                       false
239/// [  m00 -m01 ]               [ -m00  m01 ]
240/// [ -m10  m11 ] / 2^k   or    [  m10 -m11 ] / 2^k
241/// ```
242/// depending on whether `pattern` is respectively truthy or not.
243///
244/// Since some of the operations conditionally increase `k`, this struct furthermore keeps track of
245/// `k_upper_bound`; an upper bound on the value of `k`.
246#[derive(Debug, Clone, Copy, PartialEq)]
247pub(crate) struct DividedPatternMatrix<const LIMBS: usize> {
248    pub(super) inner: PatternMatrix<LIMBS>,
249    pub k: u32,
250    pub k_upper_bound: u32,
251}
252
253impl<const LIMBS: usize> DividedPatternMatrix<LIMBS> {
254    /// The unit matrix.
255    pub const UNIT: Self = Self {
256        inner: PatternMatrix::UNIT,
257        k: 0,
258        k_upper_bound: 0,
259    };
260
261    /// Apply this matrix to a vector of [Uint]s, returning the result as a vector of
262    /// [`ExtendedInt`]s.
263    #[inline]
264    pub const fn extended_apply_to<const VEC_LIMBS: usize, const UPPER_BOUND: u32>(
265        &self,
266        vec: Vector<Uint<VEC_LIMBS>>,
267    ) -> Vector<ExtendedInt<VEC_LIMBS, LIMBS>> {
268        let (a, b) = self.inner.extended_apply_to(vec);
269        (
270            a.bounded_div_2k::<UPPER_BOUND>(self.k),
271            b.bounded_div_2k::<UPPER_BOUND>(self.k),
272        )
273    }
274
275    /// Apply this matrix to a vector of [Uint]s, returning the result as a vector of
276    /// [`ExtendedInt`]s.
277    #[inline]
278    pub const fn extended_apply_to_vartime<const VEC_LIMBS: usize>(
279        &self,
280        vec: Vector<Uint<VEC_LIMBS>>,
281    ) -> Vector<ExtendedInt<VEC_LIMBS, LIMBS>> {
282        let (a, b) = self.inner.extended_apply_to(vec);
283        (a.div_2k_vartime(self.k), b.div_2k_vartime(self.k))
284    }
285
286    /// Multiply `self` with `rhs`. Return the result as a [`DividedIntMatrix<LIMBS>`].
287    #[inline]
288    pub const fn mul_int_matrix<const RHS_LIMBS: usize>(
289        &self,
290        rhs: &DividedIntMatrix<RHS_LIMBS>,
291    ) -> DividedIntMatrix<RHS_LIMBS> {
292        DividedIntMatrix {
293            inner: self.inner.mul_int_matrix(&rhs.inner),
294            k: self.k + rhs.k,
295            k_upper_bound: self.k_upper_bound + rhs.k_upper_bound,
296        }
297    }
298
299    /// Swap the rows of this matrix if `swap` is truthy. Otherwise, do nothing.
300    #[inline]
301    pub const fn conditional_swap_rows(&mut self, swap: Choice) {
302        self.inner.conditional_swap_rows(swap);
303    }
304
305    /// Swap the rows of this matrix.
306    #[inline]
307    pub const fn swap_rows(&mut self) {
308        self.conditional_swap_rows(Choice::TRUE);
309    }
310
311    /// Subtract the bottom row from the top if `subtract` is truthy. Otherwise, do nothing.
312    #[inline]
313    pub const fn conditional_subtract_bottom_row_from_top(&mut self, subtract: Choice) {
314        self.inner
315            .conditional_subtract_bottom_row_from_top(subtract);
316    }
317
318    /// Double the bottom row of this matrix if `double` is truthy. Otherwise, do nothing.
319    #[inline]
320    pub const fn conditional_double_bottom_row(&mut self, double: Choice) {
321        self.inner.conditional_double_bottom_row(double);
322        self.k = double.select_u32(self.k, self.k + 1);
323        self.k_upper_bound += 1;
324    }
325}
326
327pub(crate) type DividedIntMatrix<const LIMBS: usize> = DividedMatrix<LIMBS, IntMatrix<LIMBS>>;
328
329impl<const LIMBS: usize> DividedIntMatrix<LIMBS> {
330    /// Negate the top row of this matrix if `negate`; otherwise do nothing.
331    pub(super) const fn conditional_negate_top_row(&mut self, negate: Choice) {
332        self.inner.conditional_negate_top_row(negate);
333    }
334
335    /// Negate the bottom row of this matrix if `negate`; otherwise do nothing.
336    pub(super) const fn conditional_negate_bottom_row(&mut self, negate: Choice) {
337        self.inner.conditional_negate_bottom_row(negate);
338    }
339
340    pub(super) const fn to_divided_pattern_matrix(self) -> DividedPatternMatrix<LIMBS> {
341        DividedPatternMatrix {
342            inner: self.inner.to_pattern_matrix(),
343            k: self.k,
344            k_upper_bound: self.k_upper_bound,
345        }
346    }
347}
348
#[cfg(test)]
mod tests {
    use crate::modular::bingcd::matrix::{
        DividedIntMatrix, DividedPatternMatrix, IntMatrix, PatternMatrix, Unit,
    };
    use crate::{Choice, U64, U256, Uint};

    impl<const LIMBS: usize> PatternMatrix<LIMBS> {
        /// Build a [`PatternMatrix`] from `(m00, m01, m10, m11)` magnitudes and a sign pattern.
        pub(crate) const fn new_u64(matrix: (u64, u64, u64, u64), pattern: Choice) -> Self {
            Self {
                m00: Uint::from_u64(matrix.0),
                m01: Uint::from_u64(matrix.1),
                m10: Uint::from_u64(matrix.2),
                m11: Uint::from_u64(matrix.3),
                pattern,
            }
        }
    }

    impl<const LIMBS: usize> DividedPatternMatrix<LIMBS> {
        /// Build a [`DividedPatternMatrix`] from magnitudes, a sign pattern, `k` and its bound.
        pub(crate) const fn new_u64(
            matrix: (u64, u64, u64, u64),
            pattern: Choice,
            k: u32,
            k_upper_bound: u32,
        ) -> Self {
            Self {
                inner: PatternMatrix::new_u64(matrix, pattern),
                k,
                k_upper_bound,
            }
        }
    }

    /// Shared fixture: `[1 -7; -23 53] / 2^6`, with `k_upper_bound = 8`.
    const X: DividedPatternMatrix<{ U256::LIMBS }> =
        DividedPatternMatrix::new_u64((1u64, 7u64, 23u64, 53u64), Choice::TRUE, 6, 8);

    /// Input vector shared by the `extended_apply_to*` tests.
    fn apply_inputs() -> (U64, U64) {
        (
            U64::from_be_hex("CA048AFA63CD6A1F"),
            U64::from_be_hex("AE693BF7BE8E5566"),
        )
    }

    /// Matrix shared by the `extended_apply_to*` tests.
    fn apply_matrix() -> DividedPatternMatrix<{ U64::LIMBS }> {
        DividedPatternMatrix::<{ U64::LIMBS }>::new_u64(
            (288, 208, 310, 679),
            Choice::TRUE,
            17,
            17,
        )
    }

    #[test]
    fn test_extended_apply_to() {
        let (a, b) = apply_inputs();
        let (a_, b_) = apply_matrix().extended_apply_to::<{ U64::LIMBS }, 18>((a, b));
        assert_eq!(
            a_.dropped_abs_sign().0,
            Uint::from_be_hex("002AC7CDD032B9B9")
        );
        assert_eq!(
            b_.dropped_abs_sign().0,
            Uint::from_be_hex("006CFBCEE172C863")
        );
    }

    #[test]
    fn test_extended_apply_to_vartime() {
        // The variable-time variant must agree with the constant-time one.
        let (a, b) = apply_inputs();
        let (a_, b_) = apply_matrix().extended_apply_to_vartime((a, b));
        assert_eq!(
            a_.dropped_abs_sign().0,
            Uint::from_be_hex("002AC7CDD032B9B9")
        );
        assert_eq!(
            b_.dropped_abs_sign().0,
            Uint::from_be_hex("006CFBCEE172C863")
        );
    }

    #[test]
    fn test_unit_extended_apply_to() {
        let (a, b) = apply_inputs();
        let (a_, b_) = DividedPatternMatrix::<{ U64::LIMBS }>::UNIT.extended_apply_to_vartime((a, b));
        assert_eq!(a_.dropped_abs_sign().0, a);
        assert_eq!(b_.dropped_abs_sign().0, b);
    }

    #[test]
    fn test_swap() {
        let mut y = X;
        y.swap_rows();
        let target = DividedPatternMatrix::new_u64((23, 53, 1, 7), Choice::FALSE, 6, 8);
        assert_eq!(y, target);
    }

    #[test]
    fn test_conditional_swap() {
        let mut y = X;
        y.conditional_swap_rows(Choice::FALSE);
        assert_eq!(y, X);
        y.conditional_swap_rows(Choice::TRUE);
        let target = DividedPatternMatrix::new_u64((23, 53, 1, 7), Choice::FALSE, 6, 8);
        assert_eq!(y, target);
    }

    #[test]
    fn test_conditional_subtract_bottom_row_from_top() {
        let mut y = X;
        y.conditional_subtract_bottom_row_from_top(Choice::FALSE);
        assert_eq!(y, X);
        y.conditional_subtract_bottom_row_from_top(Choice::TRUE);
        let target =
            DividedPatternMatrix::new_u64((24u64, 60u64, 23u64, 53u64), Choice::TRUE, 6, 8);
        assert_eq!(y, target);
    }

    #[test]
    fn test_conditional_double() {
        let mut y = X;
        y.conditional_double_bottom_row(Choice::FALSE);
        let target = DividedPatternMatrix::new_u64((1u64, 7u64, 23u64, 53u64), Choice::TRUE, 6, 9);
        assert_eq!(y, target);
        y.conditional_double_bottom_row(Choice::TRUE);
        let target =
            DividedPatternMatrix::new_u64((1u64, 7u64, 46u64, 106u64), Choice::TRUE, 7, 10);
        assert_eq!(y, target);
    }

    #[test]
    fn test_conditional_add_right_column_to_left() {
        let mut y = X.inner;
        y.conditional_add_right_column_to_left(Choice::FALSE);
        assert_eq!(y, X.inner);
        y.conditional_add_right_column_to_left(Choice::TRUE);

        // [1 -7; -23 53] has left column L = (1, -23) and right column R = (-7, 53); the new
        // left column is -(L + R) = (6, -30).
        let target = PatternMatrix::new_u64((6u64, 7u64, 30u64, 53u64), Choice::TRUE);
        assert_eq!(y, target);
    }

    #[test]
    fn test_conditional_subtract_right_column_from_left() {
        let mut y = X.inner;
        y.conditional_subtract_right_column_from_left(Choice::FALSE);
        assert_eq!(y, X.inner);
        y.conditional_subtract_right_column_from_left(Choice::TRUE);
        let target = PatternMatrix::new_u64((8u64, 7u64, 76u64, 53u64), Choice::TRUE);
        assert_eq!(y, target);
    }

    #[test]
    fn test_conditional_negate() {
        let mut y = X.inner;
        y.conditional_negate(Choice::FALSE);
        assert_eq!(y, X.inner);
        y.conditional_negate(Choice::TRUE);
        let target = PatternMatrix::new_u64((1u64, 7u64, 23u64, 53u64), Choice::FALSE);
        assert_eq!(y, target);
    }

    #[test]
    fn test_int_matrix_unit_to_pattern_matrix() {
        let converted = IntMatrix::<{ U64::LIMBS }>::UNIT.to_pattern_matrix();
        assert_eq!(converted, PatternMatrix::<{ U64::LIMBS }>::UNIT);
    }

    #[test]
    fn test_pattern_unit_times_int_unit() {
        let product =
            PatternMatrix::<{ U64::LIMBS }>::UNIT.mul_int_matrix(&IntMatrix::<{ U64::LIMBS }>::UNIT);
        assert_eq!(product, IntMatrix::<{ U64::LIMBS }>::UNIT);
    }

    #[test]
    fn test_mul_int_matrix_round_trip() {
        // Multiplying by the unit and converting back must reproduce the original matrix,
        // including its `k` and `k_upper_bound`.
        let product = X.mul_int_matrix(&DividedIntMatrix::<{ U256::LIMBS }>::UNIT);
        assert_eq!(product.to_divided_pattern_matrix(), X);
    }
}