Skip to main content

crypto_bigint/modular/bingcd/
xgcd.rs

1//! The Binary Extended GCD algorithm.
2use super::gcd::bingcd_step;
3use crate::modular::bingcd::matrix::{DividedIntMatrix, DividedPatternMatrix, PatternMatrix, Unit};
4use crate::primitives::u32_max;
5use crate::{Choice, Int, Limb, NonZeroUint, Odd, OddUint, U64, U128, Uint, Word};
6
7/// Binary XGCD update step.
8///
9/// This is a condensed, constant time execution of the following algorithm:
10/// ```text
11/// if a mod 2 == 1
12///    if a < b
13///        (a, b) ← (b, a)
14///        (f0, g0, f1, g1) ← (f1, g1, f0, g0)
15///    a ← a - b
16///    (f0, g0) ← (f0 - f1, g0 - g1)
17/// if a > 0
18///     a ← a/2
19///     (f1, g1) ← (2f1, 2g1)
20/// ```
21/// where `matrix` represents
22/// ```text
23///  (f0 g0)
24///  (f1 g1).
25/// ```
26///
27/// Note: this algorithm assumes `b` to be an odd integer. The algorithm will likely not yield
28/// the correct result when this is not the case.
29///
30/// Ref: Pornin, Algorithm 2, L8-17, <https://eprint.iacr.org/2020/972.pdf>.
31#[inline(always)]
32const fn binxgcd_step<const LIMBS: usize, const MATRIX_LIMBS: usize>(
33    a: &mut Uint<LIMBS>,
34    b: &mut Uint<LIMBS>,
35    matrix: &mut DividedPatternMatrix<MATRIX_LIMBS>,
36    halt_at_zero: Choice,
37) -> Word {
38    let (a_odd, swap, j) = bingcd_step(a, b);
39
40    // swap if a odd and a < b
41    matrix.conditional_swap_rows(swap);
42
43    // subtract b from a when a is odd
44    matrix.conditional_subtract_bottom_row_from_top(a_odd);
45
46    // Double the bottom row of the matrix when a was ≠ 0 and when not halting.
47    matrix.conditional_double_bottom_row(a.is_nonzero().or(halt_at_zero.not()));
48
49    j
50}
51
/// Container for the raw output of the Binary XGCD algorithm.
///
/// `MATRIX` is the matrix representation carried alongside the gcd; this file instantiates it
/// with [`DividedPatternMatrix`] and [`PatternMatrix`].
pub(crate) struct RawXgcdOutput<const LIMBS: usize, MATRIX> {
    /// The greatest common divisor of the two inputs; odd by type.
    gcd: OddUint<LIMBS>,
    /// Matrix produced by the algorithm, from which the accessors derive their results.
    matrix: MATRIX,
}

/// Raw XGCD output carrying a [`DividedPatternMatrix`].
///
/// The accessors in this file are defined on the [`PatternXgcdOutput`] form, which is obtained
/// through [`DividedPatternXgcdOutput::divide`].
pub(crate) type DividedPatternXgcdOutput<const LIMBS: usize> =
    RawXgcdOutput<LIMBS, DividedPatternMatrix<LIMBS>>;
60
61impl<const LIMBS: usize> DividedPatternXgcdOutput<LIMBS> {
62    /// Divide `self.matrix.inner` by `2^self.matrix.k`, allowing us to simplify `inner` from a
63    /// [`DividedPatternMatrix`] to a [`PatternMatrix`].
64    ///
65    /// The performed divisions are modulo `lhs/gcd` and `rhs/gcd` to maintain the correctness of
66    /// the XGCD state.
67    ///
68    /// This operation is 'fast' since it only applies the division to the top row of the matrix.
69    /// This is allowed since it is assumed that `self.matrix * (lhs, rhs) = (gcd, 0)`; dividing
70    /// the bottom row of the matrix by a constant has no impact since its inner-product with the
71    /// input vector is zero.
72    ///
73    /// Executes in variable time w.r.t. `k_upper_bound`.
74    pub(crate) const fn divide(self) -> PatternXgcdOutput<LIMBS> {
75        let DividedPatternMatrix {
76            inner: mut matrix,
77            k,
78            k_upper_bound,
79            ..
80        } = self.matrix;
81
82        let PatternMatrix {
83            m00: x,
84            m01: y,
85            m10: rhs_div_gcd,
86            m11: lhs_div_gcd,
87            ..
88        } = &mut matrix;
89
90        if k_upper_bound > 0 {
91            *x = x.bounded_div2k_mod_q(
92                k,
93                k_upper_bound,
94                &rhs_div_gcd.to_odd().expect_copied("odd by construction"),
95            );
96            *y = y.bounded_div2k_mod_q(
97                k,
98                k_upper_bound,
99                &lhs_div_gcd.to_odd().expect_copied("odd by construction"),
100            );
101        }
102
103        PatternXgcdOutput {
104            gcd: self.gcd,
105            matrix,
106        }
107    }
108}
109
/// XGCD output carrying a [`PatternMatrix`]; the form on which `gcd`, `bezout_coefficients` and
/// `quotients` are defined.
pub(crate) type PatternXgcdOutput<const LIMBS: usize> = RawXgcdOutput<LIMBS, PatternMatrix<LIMBS>>;
111
112impl<const LIMBS: usize> PatternXgcdOutput<LIMBS> {
113    /// Obtain the `gcd`.
114    pub(crate) const fn gcd(&self) -> OddUint<LIMBS> {
115        self.gcd
116    }
117
118    /// Obtain the bezout coefficients `(x, y)` such that `lhs * x + rhs * y = gcd`.
119    pub(crate) const fn bezout_coefficients(&self) -> (Int<LIMBS>, Int<LIMBS>) {
120        let PatternMatrix {
121            m00,
122            m01,
123            m10,
124            m11,
125            pattern,
126            ..
127        } = self.matrix;
128
129        // TODO: can we simplify this?
130        let m10_sub_m00 = m10.wrapping_sub(&m00);
131        let m11_sub_m01 = m11.wrapping_sub(&m01);
132        let apply = Uint::lte(&m10_sub_m00, &m00).and(Uint::lte(&m11_sub_m01, &m01));
133
134        let m00 = *Uint::select(&m00, &m10_sub_m00, apply)
135            .wrapping_neg_if(apply.xor(pattern.not()))
136            .as_int();
137        let m01 = *Uint::select(&m01, &m11_sub_m01, apply)
138            .wrapping_neg_if(apply.xor(pattern))
139            .as_int();
140        (m00, m01)
141    }
142
143    /// Obtain the quotients `lhs/gcd` and `rhs/gcd` from `matrix`.
144    pub(crate) const fn quotients(&self) -> (Uint<LIMBS>, Uint<LIMBS>) {
145        let PatternMatrix {
146            m10: rhs_div_gcd,
147            m11: lhs_div_gcd,
148            ..
149        } = self.matrix;
150        (lhs_div_gcd, rhs_div_gcd)
151    }
152}
153
/// Number of bits used by [`OddUint::optimized_binxgcd`] to represent a "compact" [`Uint`].
///
/// This is one bit less than the width of a [`U64`].
const SUMMARY_BITS: u32 = U64::BITS - 1;

/// Number of limbs used to represent [`SUMMARY_BITS`]; equal to the limb count of a [`U64`].
const SUMMARY_LIMBS: usize = U64::LIMBS;

/// Twice the number of limbs used to represent [`SUMMARY_BITS`], i.e., two times [`SUMMARY_LIMBS`];
/// equal to the limb count of a [`U128`].
const DOUBLE_SUMMARY_LIMBS: usize = U128::LIMBS;
162
163impl<const LIMBS: usize> OddUint<LIMBS> {
    /// The minimal number of binary GCD iterations required to guarantee successful completion.
    ///
    /// The classic algorithm runs exactly this many steps; the optimized algorithm spreads them
    /// over rounds of `K - 1` steps each.
    const MIN_BINXGCD_ITERATIONS: u32 = 2 * Self::BITS - 1;
166
167    /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`,
168    /// leveraging the Binary Extended GCD algorithm.
169    pub(crate) const fn binxgcd_nz(&self, rhs: &NonZeroUint<LIMBS>) -> PatternXgcdOutput<LIMBS> {
170        let (lhs_, rhs_) = (self.as_ref(), rhs.as_ref());
171
172        // The `xgcd` subroutine requires `rhs` to be odd.
173        // We leverage the equality gcd(lhs, rhs) = gcd(lhs, |lhs-rhs|) to deal with the case that
174        // `rhs` is even.
175        let rhs_is_even = rhs_.is_odd().not();
176        let (abs_diff, rhs_gt_lhs) = lhs_.abs_diff(rhs_);
177        let odd_rhs = Odd(Uint::select(rhs_, &abs_diff, rhs_is_even));
178
179        let mut output = self.binxgcd_odd(&odd_rhs);
180        let matrix = &mut output.matrix;
181
182        // Modify the output to negate the transformation applied to the input.
183        let case_one = rhs_is_even.and(rhs_gt_lhs);
184        matrix.conditional_subtract_right_column_from_left(case_one);
185
186        let case_two = rhs_is_even.and(rhs_gt_lhs.not());
187        matrix.conditional_add_right_column_to_left(case_two);
188        matrix.conditional_negate(case_two);
189
190        output
191    }
192
193    /// Execute the classic Extended GCD algorithm.
194    ///
195    /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`.
196    #[inline]
197    pub(crate) const fn binxgcd_odd(&self, rhs: &Self) -> PatternXgcdOutput<LIMBS> {
198        if LIMBS < 4 {
199            self.classic_binxgcd(rhs).divide()
200        } else {
201            self.optimized_binxgcd(rhs).divide()
202        }
203    }
204
205    /// Execute the classic Binary Extended GCD algorithm.
206    ///
207    /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`.
208    ///
209    /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 1.
210    /// <https://eprint.iacr.org/2020/972.pdf>.
211    pub(crate) const fn classic_binxgcd(&self, rhs: &Self) -> DividedPatternXgcdOutput<LIMBS> {
212        let (gcd, _, matrix, _) =
213            self.partial_binxgcd::<LIMBS>(rhs.as_ref(), Self::MIN_BINXGCD_ITERATIONS, Choice::TRUE);
214        DividedPatternXgcdOutput { gcd, matrix }
215    }
216
    /// Given `(self, rhs)`, computes `(g, x, y)` s.t. `self * x + rhs * y = g = gcd(self, rhs)`,
    /// leveraging the Binary Extended GCD algorithm.
    ///
    /// **Warning**: `self` and `rhs` must be contained in an [U128] or larger.
    ///
    /// Note: this algorithm becomes more efficient than the classical algorithm for [Uint]s with
    /// relatively many `LIMBS`. A best-effort threshold is presented in [`Self::binxgcd_odd`].
    ///
    /// Note: the full algorithm is generic over additional parameters; this function selects
    /// best-effort values for them ([`SUMMARY_BITS`], [`SUMMARY_LIMBS`] and
    /// [`DOUBLE_SUMMARY_LIMBS`]). You might be able to further tune your performance by calling
    /// the [`Self::optimized_binxgcd_`] function directly.
    ///
    /// # Panics
    /// If `Self::BITS` is smaller than `U128::BITS`.
    ///
    /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2.
    /// <https://eprint.iacr.org/2020/972.pdf>.
    pub(crate) const fn optimized_binxgcd(&self, rhs: &Self) -> DividedPatternXgcdOutput<LIMBS> {
        assert!(Self::BITS >= U128::BITS);
        self.optimized_binxgcd_::<SUMMARY_BITS, SUMMARY_LIMBS, DOUBLE_SUMMARY_LIMBS>(rhs)
    }
235
236    /// Given `(self, rhs)`, computes `(g, x, y)`, s.t. `self * x + rhs * y = g = gcd(self, rhs)`,
237    /// leveraging the optimized Binary Extended GCD algorithm.
238    ///
239    /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2.
240    /// <https://eprint.iacr.org/2020/972.pdf>
241    ///
242    /// In summary, the optimized algorithm does not operate on `self` and `rhs` directly, but
243    /// instead of condensed summaries that fit in few registers. Based on these summaries, an
244    /// update matrix is constructed by which `self` and `rhs` are updated in larger steps.
245    ///
246    /// This function is generic over the following three values:
247    /// - `K`: the number of bits used when summarizing `self` and `rhs` for the inner loop. The
248    ///   `K+1` top bits and `K-1` least significant bits are selected. It is recommended to keep
249    ///   `K` close to a (multiple of) the number of bits that fit in a single register.
250    /// - `LIMBS_K`: should be chosen as the minimum number s.t. `Uint::<LIMBS>::BITS ≥ K`,
251    /// - `LIMBS_2K`: should be chosen as the minimum number s.t. `Uint::<LIMBS>::BITS ≥ 2K`.
252    pub(crate) const fn optimized_binxgcd_<
253        const K: u32,
254        const LIMBS_K: usize,
255        const LIMBS_2K: usize,
256    >(
257        &self,
258        rhs: &Self,
259    ) -> DividedPatternXgcdOutput<LIMBS> {
260        let (mut a, mut b) = (*self.as_ref(), *rhs.as_ref());
261        let mut state = DividedIntMatrix::UNIT;
262
263        let (mut a_is_negative, mut b_is_negative);
264        let mut i = 0;
265        while i < Self::MIN_BINXGCD_ITERATIONS.div_ceil(K - 1) {
266            // Loop invariants:
267            //  i) each iteration of this loop, `a.bits() + b.bits()` shrinks by at least K-1,
268            //     until `b = 0`.
269            // ii) `a` is odd.
270            i += 1;
271
272            // Construct compact_a and compact_b as the summary of a and b, respectively.
273            let b_bits = b.bits();
274            let n = u32_max(2 * K, u32_max(a.bits(), b_bits));
275            let compact_a = a.compact::<K, LIMBS_2K>(n);
276            let compact_b = b.compact::<K, LIMBS_2K>(n);
277            let b_eq_compact_b =
278                Choice::from_u32_le(b_bits, K - 1).or(Choice::from_u32_eq(n, 2 * K));
279
280            // Compute the K-1 iteration update matrix from a_ and b_
281            let (.., update_matrix, _) = compact_a
282                .to_odd()
283                .expect_copied("a is always odd")
284                .partial_binxgcd::<LIMBS_K>(&compact_b, K - 1, b_eq_compact_b);
285
286            // Update `a` and `b` using the update matrix
287            let (updated_a, updated_b) = update_matrix.extended_apply_to::<LIMBS, K>((a, b));
288            (a, a_is_negative) = updated_a.dropped_abs_sign();
289            (b, b_is_negative) = updated_b.dropped_abs_sign();
290
291            state = update_matrix.mul_int_matrix(&state);
292
293            // Correct for the sign change in a/b after dropping the extension.
294            state.conditional_negate_top_row(a_is_negative);
295            state.conditional_negate_bottom_row(b_is_negative);
296        }
297
298        let gcd = a
299            .to_odd()
300            .expect_copied("gcd of an odd value with something else is always odd");
301
302        let matrix = state.to_divided_pattern_matrix();
303        DividedPatternXgcdOutput { gcd, matrix }
304    }
305
306    /// Executes the optimized Binary GCD inner loop.
307    ///
308    /// Ref: Pornin, Optimized Binary GCD for Modular Inversion, Algorithm 2.
309    /// <https://eprint.iacr.org/2020/972.pdf>.
310    ///
311    /// The function outputs the reduced values `(a, b)` for the input values `(self, rhs)` as well
312    /// as the matrix that yields the former two when multiplied with the latter two.
313    ///
314    /// Note: this implementation deviates slightly from the paper, in that it can be instructed to
315    /// "run in place" (i.e., execute iterations that do nothing) once `a` becomes zero.
316    /// This is done by passing a truthy `halt_at_zero`.
317    ///
318    /// The function executes in time variable in `iterations`.
319    #[inline(always)]
320    pub(super) const fn partial_binxgcd<const UPDATE_LIMBS: usize>(
321        &self,
322        rhs: &Uint<LIMBS>,
323        iterations: u32,
324        halt_at_zero: Choice,
325    ) -> (Self, Uint<LIMBS>, DividedPatternMatrix<UPDATE_LIMBS>, Word) {
326        let (mut a, mut b) = (*self.as_ref(), *rhs);
327        // This matrix corresponds with (f0, g0, f1, g1) in the paper.
328        let mut matrix = DividedPatternMatrix::UNIT;
329
330        // Compute the update matrix.
331        // Note: to be consistent with the paper, the `binxgcd_step` algorithm requires the second
332        // argument to be odd. Here, we have `a` odd, so we have to swap a and b before and after
333        // calling the subroutine. The columns of the matrix have to be swapped accordingly.
334        Uint::swap(&mut a, &mut b);
335        matrix.swap_rows();
336
337        let mut jacobi_neg = 0;
338        let mut i = 0;
339
340        while i < iterations {
341            jacobi_neg ^=
342                binxgcd_step::<LIMBS, UPDATE_LIMBS>(&mut a, &mut b, &mut matrix, halt_at_zero);
343            i += 1;
344        }
345
346        // Undo swap
347        Uint::swap(&mut a, &mut b);
348        matrix.swap_rows();
349
350        let a = a.to_odd().expect_copied("a is always odd");
351        (a, b, matrix, jacobi_neg)
352    }
353}
354
355impl<const LIMBS: usize> Uint<LIMBS> {
356    /// Compute the absolute difference between `self` and `rhs`.
357    /// In addition to the result, also returns whether `rhs > self`.
358    const fn abs_diff(&self, rhs: &Self) -> (Self, Choice) {
359        let (diff, borrow) = self.borrowing_sub(rhs, Limb::ZERO);
360        let rhs_gt_self = borrow.is_nonzero();
361        let abs_diff = diff.wrapping_neg_if(rhs_gt_self);
362        (abs_diff, rhs_gt_self)
363    }
364}
365
366#[cfg(all(test, not(miri)))]
367mod tests {
368    use crate::modular::bingcd::xgcd::PatternXgcdOutput;
369    use crate::{Concat, Uint};
370    use core::ops::Div;
371
372    mod test_extract_quotients {
373        use crate::modular::bingcd::matrix::DividedPatternMatrix;
374        use crate::modular::bingcd::xgcd::{DividedPatternXgcdOutput, RawXgcdOutput};
375        use crate::{Choice, U64, Uint};
376
377        fn raw_binxgcdoutput_setup<const LIMBS: usize>(
378            matrix: DividedPatternMatrix<LIMBS>,
379        ) -> DividedPatternXgcdOutput<LIMBS> {
380            RawXgcdOutput {
381                gcd: Uint::<LIMBS>::ONE.to_odd().unwrap(),
382                matrix,
383            }
384        }
385
386        #[test]
387        fn test_extract_quotients_unit() {
388            let output =
389                raw_binxgcdoutput_setup(DividedPatternMatrix::<{ U64::LIMBS }>::UNIT).divide();
390            let (lhs_on_gcd, rhs_on_gcd) = output.quotients();
391            assert_eq!(lhs_on_gcd, Uint::ONE);
392            assert_eq!(rhs_on_gcd, Uint::ZERO);
393        }
394
395        #[test]
396        fn test_extract_quotients_basic() {
397            let output = raw_binxgcdoutput_setup(DividedPatternMatrix::<{ U64::LIMBS }>::new_u64(
398                (0, 0, 5, 7),
399                Choice::FALSE,
400                0,
401                0,
402            ))
403            .divide();
404            let (lhs_on_gcd, rhs_on_gcd) = output.quotients();
405            assert_eq!(lhs_on_gcd, Uint::from(7u32));
406            assert_eq!(rhs_on_gcd, Uint::from(5u32));
407
408            let output = raw_binxgcdoutput_setup(DividedPatternMatrix::<{ U64::LIMBS }>::new_u64(
409                (0, 0, 7u64, 5u64),
410                Choice::TRUE,
411                0,
412                0,
413            ))
414            .divide();
415            let (lhs_on_gcd, rhs_on_gcd) = output.quotients();
416            assert_eq!(lhs_on_gcd, Uint::from(5u32));
417            assert_eq!(rhs_on_gcd, Uint::from(7u32));
418        }
419    }
420
421    mod test_derive_bezout_coefficients {
422        use crate::modular::bingcd::matrix::DividedPatternMatrix;
423        use crate::modular::bingcd::xgcd::RawXgcdOutput;
424        use crate::{Choice, Int, U64, Uint};
425
426        #[test]
427        fn test_derive_bezout_coefficients_unit() {
428            let output = RawXgcdOutput {
429                gcd: Uint::ONE.to_odd().unwrap(),
430                matrix: DividedPatternMatrix::<{ U64::LIMBS }>::UNIT,
431            }
432            .divide();
433            let (x, y) = output.bezout_coefficients();
434            assert_eq!(x, Int::ONE);
435            assert_eq!(y, Int::ZERO);
436        }
437
438        #[test]
439        fn test_derive_bezout_coefficients_basic() {
440            let output = RawXgcdOutput {
441                gcd: U64::ONE.to_odd().unwrap(),
442                matrix: DividedPatternMatrix::new_u64((2u64, 3u64, 5u64, 5u64), Choice::TRUE, 0, 0),
443            }
444            .divide();
445            let (x, y) = output.bezout_coefficients();
446            assert_eq!(x, Int::from(2i32));
447            assert_eq!(y, Int::from(-3i32));
448
449            let output = RawXgcdOutput {
450                gcd: U64::ONE.to_odd().unwrap(),
451                matrix: DividedPatternMatrix::new_u64(
452                    (2u64, 3u64, 3u64, 5u64),
453                    Choice::FALSE,
454                    0,
455                    1,
456                ),
457            }
458            .divide();
459            let (x, y) = output.bezout_coefficients();
460            assert_eq!(x, Int::from(1i32));
461            assert_eq!(y, Int::from(-2i32));
462        }
463
464        #[test]
465        fn test_derive_bezout_coefficients_removes_doublings_easy() {
466            let output = RawXgcdOutput {
467                gcd: U64::ONE.to_odd().unwrap(),
468                matrix: DividedPatternMatrix::new_u64((2u64, 6u64, 3u64, 5u64), Choice::TRUE, 1, 1),
469            }
470            .divide();
471            let (x, y) = output.bezout_coefficients();
472            assert_eq!(x, Int::ONE);
473            assert_eq!(y, Int::from(-3i32));
474
475            let output = RawXgcdOutput {
476                gcd: U64::ONE.to_odd().unwrap(),
477                matrix: DividedPatternMatrix::new_u64(
478                    (120u64, 64u64, 7u64, 5u64),
479                    Choice::FALSE,
480                    5,
481                    6,
482                ),
483            }
484            .divide();
485            let (x, y) = output.bezout_coefficients();
486            assert_eq!(x, Int::from(-9i32));
487            assert_eq!(y, Int::from(2i32));
488        }
489
490        #[test]
491        fn test_derive_bezout_coefficients_removes_doublings_for_odd_numbers() {
492            let output = RawXgcdOutput {
493                gcd: U64::ONE.to_odd().unwrap(),
494                matrix: DividedPatternMatrix::new_u64(
495                    (2u64, 6u64, 7u64, 5u64),
496                    Choice::FALSE,
497                    3,
498                    7,
499                ),
500            }
501            .divide();
502            let (x, y) = output.bezout_coefficients();
503            assert_eq!(x, Int::from(-2i32));
504            assert_eq!(y, Int::from(2i32));
505        }
506    }
507
508    mod test_partial_binxgcd {
509        use crate::modular::bingcd::matrix::DividedPatternMatrix;
510        use crate::{Choice, Gcd, Odd, U64};
511
512        const A: Odd<U64> = Odd::from_be_hex("CA048AFA63CD6A1F");
513        const B: U64 = U64::from_be_hex("AE693BF7BE8E5566");
514
515        #[test]
516        fn test_partial_binxgcd() {
517            let (.., matrix, _) = A.partial_binxgcd::<{ U64::LIMBS }>(&B, 5, Choice::TRUE);
518            assert_eq!(matrix.k, 5);
519            assert_eq!(
520                matrix,
521                DividedPatternMatrix::new_u64((8u64, 4u64, 2u64, 5u64), Choice::TRUE, 5, 5)
522            );
523        }
524
525        #[test]
526        fn test_partial_binxgcd_constructs_correct_matrix() {
527            let target_a = U64::from_be_hex("1CB3FB3FA1218FDB").to_odd().unwrap();
528            let target_b = U64::from_be_hex("0EA028AF0F8966B6");
529
530            let (new_a, new_b, matrix, _) =
531                A.partial_binxgcd::<{ U64::LIMBS }>(&B, 5, Choice::TRUE);
532
533            assert_eq!(new_a, target_a);
534            assert_eq!(new_b, target_b);
535
536            let (computed_a, computed_b) =
537                matrix.extended_apply_to::<{ U64::LIMBS }, 6>((A.get(), B));
538            let computed_a = computed_a.dropped_abs_sign().0;
539            let computed_b = computed_b.dropped_abs_sign().0;
540
541            assert_eq!(computed_a, target_a);
542            assert_eq!(computed_b, target_b);
543        }
544
545        const SMALL_A: Odd<U64> = Odd::from_be_hex("0000000003CD6A1F");
546        const SMALL_B: U64 = U64::from_be_hex("000000000E8E5566");
547
548        #[test]
549        fn test_partial_binxgcd_halts() {
550            let (gcd, _, matrix, _) =
551                SMALL_A.partial_binxgcd::<{ U64::LIMBS }>(&SMALL_B, 60, Choice::TRUE);
552            assert_eq!(matrix.k, 35);
553            assert_eq!(matrix.k_upper_bound, 60);
554            assert_eq!(gcd.get(), SMALL_A.gcd(&SMALL_B));
555        }
556
557        #[test]
558        fn test_partial_binxgcd_does_not_halt() {
559            let (gcd, .., matrix, _) =
560                SMALL_A.partial_binxgcd::<{ U64::LIMBS }>(&SMALL_B, 60, Choice::FALSE);
561            assert_eq!(matrix.k, 60);
562            assert_eq!(matrix.k_upper_bound, 60);
563            assert_eq!(gcd.get(), SMALL_A.gcd(&SMALL_B));
564        }
565    }
566
567    /// Helper function to effectively test xgcd.
568    fn test_xgcd<const LIMBS: usize, const DOUBLE: usize>(
569        lhs: Uint<LIMBS>,
570        rhs: Uint<LIMBS>,
571        output: PatternXgcdOutput<LIMBS>,
572    ) where
573        Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
574    {
575        // Test the gcd
576        assert_eq!(lhs.gcd(&rhs), output.gcd, "{lhs} {rhs}");
577
578        // Test the quotients
579        let (lhs_on_gcd, rhs_on_gcd) = output.quotients();
580        assert_eq!(lhs_on_gcd, lhs.div(output.gcd.as_nz_ref()));
581        assert_eq!(rhs_on_gcd, rhs.div(output.gcd.as_nz_ref()));
582
583        // Test the Bezout coefficients for correctness
584        let (x, y) = output.bezout_coefficients();
585        assert_eq!(
586            x.concatenating_mul_unsigned(&lhs) + y.concatenating_mul_unsigned(&rhs),
587            *output.gcd.resize().as_int(),
588        );
589
590        // Test the Bezout coefficients for minimality
591        assert!(x.abs() <= rhs.div(output.gcd.as_nz_ref()));
592        assert!(y.abs() <= lhs.div(output.gcd.as_nz_ref()));
593        if lhs != rhs {
594            assert!(x.abs() <= rhs_on_gcd.shr(1) || rhs_on_gcd.is_zero().to_bool());
595            assert!(y.abs() <= lhs_on_gcd.shr(1) || lhs_on_gcd.is_zero().to_bool());
596        }
597    }
598
599    mod test_binxgcd_nz {
600        use crate::modular::bingcd::xgcd::tests::test_xgcd;
601        use crate::{
602            Concat, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, U8192, Uint,
603        };
604
605        fn binxgcd_nz_test<const LIMBS: usize, const DOUBLE: usize>(
606            lhs: Uint<LIMBS>,
607            rhs: Uint<LIMBS>,
608        ) where
609            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
610        {
611            let output = lhs.to_odd().unwrap().binxgcd_nz(&rhs.to_nz().unwrap());
612            test_xgcd(lhs, rhs, output);
613        }
614
615        fn binxgcd_nz_tests<const LIMBS: usize, const DOUBLE: usize>()
616        where
617            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
618        {
619            let max_int = *Int::MAX.as_uint();
620            let int_abs_min = Int::MIN.abs();
621
622            binxgcd_nz_test(Uint::ONE, Uint::ONE);
623            binxgcd_nz_test(Uint::ONE, max_int);
624            binxgcd_nz_test(Uint::ONE, int_abs_min);
625            binxgcd_nz_test(Uint::ONE, Uint::MAX);
626            binxgcd_nz_test(max_int, Uint::ONE);
627            binxgcd_nz_test(max_int, max_int);
628            binxgcd_nz_test(max_int, int_abs_min);
629            binxgcd_nz_test(max_int, Uint::MAX);
630            binxgcd_nz_test(Uint::MAX, Uint::ONE);
631            binxgcd_nz_test(Uint::MAX, max_int);
632            binxgcd_nz_test(Uint::MAX, int_abs_min);
633            binxgcd_nz_test(Uint::MAX, Uint::MAX);
634        }
635
636        #[test]
637        fn test_binxgcd_nz() {
638            binxgcd_nz_tests::<{ U64::LIMBS }, { U128::LIMBS }>();
639            binxgcd_nz_tests::<{ U128::LIMBS }, { U256::LIMBS }>();
640            binxgcd_nz_tests::<{ U192::LIMBS }, { U384::LIMBS }>();
641            binxgcd_nz_tests::<{ U256::LIMBS }, { U512::LIMBS }>();
642            binxgcd_nz_tests::<{ U384::LIMBS }, { U768::LIMBS }>();
643            binxgcd_nz_tests::<{ U512::LIMBS }, { U1024::LIMBS }>();
644            binxgcd_nz_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>();
645            binxgcd_nz_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>();
646            binxgcd_nz_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>();
647        }
648    }
649
650    mod test_classic_binxgcd {
651        use crate::modular::bingcd::xgcd::tests::test_xgcd;
652        use crate::{
653            Concat, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, U8192, Uint,
654        };
655
656        fn classic_binxgcd_test<const LIMBS: usize, const DOUBLE: usize>(
657            lhs: Uint<LIMBS>,
658            rhs: Uint<LIMBS>,
659        ) where
660            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
661        {
662            let output = lhs
663                .to_odd()
664                .unwrap()
665                .classic_binxgcd(&rhs.to_odd().unwrap())
666                .divide();
667            test_xgcd(lhs, rhs, output);
668        }
669
670        fn classic_binxgcd_tests<const LIMBS: usize, const DOUBLE: usize>()
671        where
672            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
673        {
674            let max_int = *Int::MAX.as_uint();
675
676            classic_binxgcd_test(Uint::ONE, Uint::ONE);
677            classic_binxgcd_test(Uint::ONE, max_int);
678            classic_binxgcd_test(Uint::ONE, Uint::MAX);
679            classic_binxgcd_test(max_int, Uint::ONE);
680            classic_binxgcd_test(max_int, max_int);
681            classic_binxgcd_test(max_int, Uint::MAX);
682            classic_binxgcd_test(Uint::MAX, Uint::ONE);
683            classic_binxgcd_test(Uint::MAX, max_int);
684            classic_binxgcd_test(Uint::MAX, Uint::MAX);
685        }
686
687        #[test]
688        fn test_classic_binxgcd() {
689            classic_binxgcd_tests::<{ U64::LIMBS }, { U128::LIMBS }>();
690            classic_binxgcd_tests::<{ U128::LIMBS }, { U256::LIMBS }>();
691            classic_binxgcd_tests::<{ U192::LIMBS }, { U384::LIMBS }>();
692            classic_binxgcd_tests::<{ U256::LIMBS }, { U512::LIMBS }>();
693            classic_binxgcd_tests::<{ U384::LIMBS }, { U768::LIMBS }>();
694            classic_binxgcd_tests::<{ U512::LIMBS }, { U1024::LIMBS }>();
695            classic_binxgcd_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>();
696            classic_binxgcd_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>();
697            classic_binxgcd_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>();
698        }
699    }
700
701    mod test_optimized_binxgcd {
702        use crate::modular::bingcd::xgcd::tests::test_xgcd;
703        use crate::modular::bingcd::xgcd::{DOUBLE_SUMMARY_LIMBS, SUMMARY_BITS, SUMMARY_LIMBS};
704        use crate::{
705            Concat, Int, U64, U128, U192, U256, U384, U512, U768, U1024, U2048, U4096, U8192, Uint,
706        };
707
708        fn test<const LIMBS: usize, const DOUBLE: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>)
709        where
710            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
711        {
712            let output = lhs
713                .to_odd()
714                .unwrap()
715                .optimized_binxgcd(&rhs.to_odd().unwrap())
716                .divide();
717            test_xgcd(lhs, rhs, output);
718        }
719
720        fn run_tests<const LIMBS: usize, const DOUBLE: usize>()
721        where
722            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
723        {
724            let upper_bound = *Int::MAX.as_uint();
725            test(Uint::ONE, Uint::ONE);
726            test(Uint::ONE, upper_bound);
727            test(Uint::ONE, Uint::MAX);
728            test(upper_bound, Uint::ONE);
729            test(upper_bound, upper_bound);
730            test(upper_bound, Uint::MAX);
731            test(Uint::MAX, Uint::ONE);
732            test(Uint::MAX, upper_bound);
733            test(Uint::MAX, Uint::MAX);
734        }
735
736        #[test]
737        fn test_optimized_binxgcd_edge_cases() {
738            // If one of these tests fails, you have probably tweaked the SUMMARY_BITS,
739            // SUMMARY_LIMBS or DOUBLE_SUMMARY_LIMBS settings. Please make sure to update these
740            // tests accordingly.
741            assert_eq!(SUMMARY_BITS, 63);
742            assert_eq!(SUMMARY_LIMBS, U64::LIMBS);
743            assert_eq!(DOUBLE_SUMMARY_LIMBS, U128::LIMBS);
744
745            // Case #1: a > b but a.compact() < b.compact()
746            let a = U256::from_be_hex(
747                "1234567890ABCDEF80000000000000000000000000000000BEDCBA0987654321",
748            );
749            let b = U256::from_be_hex(
750                "1234567890ABCDEF800000000000000000000000000000007EDCBA0987654321",
751            );
752            assert!(a > b);
753            assert!(
754                a.compact::<SUMMARY_BITS, DOUBLE_SUMMARY_LIMBS>(U256::BITS)
755                    < b.compact::<SUMMARY_BITS, DOUBLE_SUMMARY_LIMBS>(U256::BITS)
756            );
757            test(a, b);
758
759            // Case #2: a < b but a.compact() > b.compact()
760            test(b, a);
761
762            // Case #3: a > b but a.compact() = b.compact()
763            let a = U256::from_be_hex(
764                "1234567890ABCDEF80000000000000000000000000000000FEDCBA0987654321",
765            );
766            let b = U256::from_be_hex(
767                "1234567890ABCDEF800000000000000000000000000000007EDCBA0987654321",
768            );
769            assert!(a > b);
770            assert_eq!(
771                a.compact::<SUMMARY_BITS, DOUBLE_SUMMARY_LIMBS>(U256::BITS),
772                b.compact::<SUMMARY_BITS, DOUBLE_SUMMARY_LIMBS>(U256::BITS)
773            );
774            test(a, b);
775
776            // Case #4: a < b but a.compact() = b.compact()
777            test(b, a);
778        }
779
780        #[test]
781        fn optimized_binxgcd() {
782            run_tests::<{ U128::LIMBS }, { U256::LIMBS }>();
783            run_tests::<{ U192::LIMBS }, { U384::LIMBS }>();
784            run_tests::<{ U256::LIMBS }, { U512::LIMBS }>();
785            run_tests::<{ U384::LIMBS }, { U768::LIMBS }>();
786            run_tests::<{ U512::LIMBS }, { U1024::LIMBS }>();
787            run_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>();
788            run_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>();
789            run_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>();
790        }
791    }
792}