Skip to main content

crypto_bigint/modular/
safegcd.rs

1//! Implementation of Bernstein-Yang modular inversion and GCD algorithm (a.k.a. safegcd)
2//! as described in: <https://eprint.iacr.org/2019/266>.
3//!
4//! Adapted from the Apache 2.0+MIT licensed implementation originally from:
5//! <https://github.com/taikoxyz/halo2curves/pull/2>
6//! <https://github.com/privacy-scaling-explorations/halo2curves/pull/83>
7//!
8//! Copyright (c) 2023 Privacy Scaling Explorations Team
9
10// TODO(tarcieri): optimized implementation for 32-bit platforms (#380)
11
12#[cfg(feature = "alloc")]
13pub(crate) mod boxed;
14
15use crate::{Choice, CtOption, I64, Int, Limb, Odd, U64, Uint, bitlen, primitives::u32_min};
16use core::fmt;
17
/// Maximum number of divsteps executed per batch, i.e. per call to `jump`.
///
/// Each divstep doubles one row of the transition matrix, so the batch size bounds the
/// matrix scaling factor (2^62 here) and keeps its entries within an `i64`.
/// `SignedInt::lowest` provides 63 low bits, which is enough for 62 steps that each
/// consume one bit.
const GCD_BATCH_SIZE: u32 = 62;
19
20/// Modular multiplicative inverter based on the Bernstein-Yang method.
21///
22/// The inverter can be created for a specified modulus M and adjusting parameter A to compute
23/// the adjusted multiplicative inverses of positive integers, i.e. for computing
24/// (1 / x) * A (mod M) for a positive integer x.
25///
26/// The adjusting parameter allows computing the multiplicative inverses in the case of using the
27/// Montgomery representation for the input or the expected output. If R is the Montgomery
28/// factor, the multiplicative inverses in the appropriate representation can be computed
29/// provided that the value of A is chosen as follows:
30/// - A = 1, if both the input and the expected output are in the standard form
31/// - A = R^2 mod M, if both the input and the expected output are in the Montgomery form
32/// - A = R mod M, if either the input or the expected output is in the Montgomery form,
33///   but not both of them
34///
35/// The public methods of this type receive and return unsigned big integers as arrays of 64-bit
36/// chunks, the ordering of which is little-endian. Both the modulus and the integer to be
37/// inverted should not exceed 2 ^ (62 * L - 64).
38///
39/// For better understanding the implementation, the following resources are recommended:
40/// - D. Bernstein, B.-Y. Yang, "Fast constant-time gcd computation and modular inversion",
41///   <https://gcd.cr.yp.to/safegcd-20190413.pdf>
42/// - P. Wuille, "The safegcd implementation in libsecp256k1 explained",
43///   <https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md>
44#[derive(Clone, Debug)]
45pub(crate) struct SafeGcdInverter<const LIMBS: usize> {
46    /// Modulus
47    pub(super) modulus: Odd<Uint<LIMBS>>,
48
49    /// Multiplicative inverse of the modulus modulo 2^62
50    inverse: u64,
51
52    /// Adjusting parameter (see toplevel documentation).
53    adjuster: Uint<LIMBS>,
54}
55
56/// Type of the Bernstein-Yang transition matrix multiplied by 2^62
57type Matrix = [[i64; 2]; 2];
58
59impl<const LIMBS: usize> SafeGcdInverter<LIMBS> {
60    /// Creates the inverter for specified modulus and adjusting parameter.
61    #[cfg(test)]
62    pub(crate) const fn new(modulus: &Odd<Uint<LIMBS>>, adjuster: &Uint<LIMBS>) -> Self {
63        Self::new_with_inverse(
64            modulus,
65            U64::from_u64(modulus.as_uint_ref().invert_mod_u64()),
66            adjuster,
67        )
68    }
69
70    #[inline]
71    pub(crate) const fn new_with_inverse(
72        modulus: &Odd<Uint<LIMBS>>,
73        inverse: U64,
74        adjuster: &Uint<LIMBS>,
75    ) -> Self {
76        Self {
77            modulus: *modulus,
78            inverse: inverse.as_uint_ref().lowest_u64(),
79            adjuster: *adjuster,
80        }
81    }
82
83    /// Returns either the adjusted modular multiplicative inverse for the argument or `None`
84    /// depending on invertibility of the argument, i.e. its coprimality with the modulus.
85    pub const fn invert(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
86        invert_odd_mod_precomp::<LIMBS, false>(value, &self.modulus, self.inverse, &self.adjuster)
87    }
88
89    /// Returns either the adjusted modular multiplicative inverse for the argument or `None`
90    /// depending on invertibility of the argument, i.e. its coprimality with the modulus.
91    ///
92    /// This version is variable-time with respect to `value`.
93    pub const fn invert_vartime(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
94        invert_odd_mod_precomp::<LIMBS, true>(value, &self.modulus, self.inverse, &self.adjuster)
95    }
96}
97
98#[inline]
99pub const fn invert_odd_mod<const LIMBS: usize, const VARTIME: bool>(
100    a: &Uint<LIMBS>,
101    m: &Odd<Uint<LIMBS>>,
102) -> CtOption<Uint<LIMBS>> {
103    let mi = m.as_uint_ref().invert_mod_u64();
104    invert_odd_mod_precomp::<LIMBS, VARTIME>(a, m, mi, &Uint::ONE)
105}
106
107/// Calculate the multiplicative inverse of `a` modulo `m`.
108const fn invert_odd_mod_precomp<const LIMBS: usize, const VARTIME: bool>(
109    a: &Uint<LIMBS>,
110    m: &Odd<Uint<LIMBS>>,
111    mi: u64,
112    e: &Uint<LIMBS>,
113) -> CtOption<Uint<LIMBS>> {
114    let a_nonzero = a.is_nonzero();
115    let (mut f, mut g) = (SignedInt::from_uint(*m.as_ref()), SignedInt::from_uint(*a));
116    let (mut d, mut e) = (SignedInt::<LIMBS>::ZERO, SignedInt::from_uint(*e));
117    let mut steps = iterations(Uint::<LIMBS>::BITS);
118    let mut delta = 1;
119    let mut t;
120
121    while steps > 0 {
122        if VARTIME && g.is_zero_vartime() {
123            break;
124        }
125        let batch = u32_min(steps, GCD_BATCH_SIZE);
126        (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
127        (f, g) = update_fg(&f, &g, t, batch);
128        (d, e) = update_de(&d, &e, m.as_ref(), mi, t, batch);
129        steps -= batch;
130    }
131
132    let d = d.norm(f.is_negative(), m.as_ref());
133    CtOption::new(d, Uint::eq(&f.magnitude, &Uint::ONE).and(a_nonzero))
134}
135
136/// Calculate the greatest common denominator of odd `f`, and `g`.
137pub const fn gcd_odd<const LIMBS: usize, const VARTIME: bool>(
138    f: &Odd<Uint<LIMBS>>,
139    g: &Uint<LIMBS>,
140) -> Odd<Uint<LIMBS>> {
141    let (mut f, mut g) = (SignedInt::from_uint(*f.as_ref()), SignedInt::from_uint(*g));
142    let mut steps = iterations(Uint::<LIMBS>::BITS);
143    let mut delta = 1;
144    let mut t;
145
146    while steps > 0 {
147        if VARTIME && g.is_zero_vartime() {
148            break;
149        }
150        let batch = u32_min(steps, GCD_BATCH_SIZE);
151        (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
152        (f, g) = update_fg(&f, &g, t, batch);
153        steps -= batch;
154    }
155
156    f.magnitude().to_odd().expect_copied("odd by construction")
157}
158
159/// Perform `batch` steps of the gcd reduction process on signed tail values `f` and `g`.
160#[inline]
161const fn jump<const VARTIME: bool>(
162    mut f: i64,
163    mut g: i64,
164    mut delta: i64,
165    mut batch: u32,
166) -> (i64, Matrix) {
167    debug_assert!(f & 1 == 1, "f must be odd");
168    let mut t = [[1i64, 0], [0, 1]];
169    while batch > 0 {
170        (f, g, delta, t) = if VARTIME {
171            jump_step_vartime(f, g, delta, t)
172        } else {
173            jump_step(f, g, delta, t)
174        };
175        batch -= 1;
176    }
177    (delta, t)
178}
179
180/// Perform one step of the gcd reduction in constant time.
181/// This follows the half-delta variant of safegcd-bounds which reduces the round count.
182/// <https://github.com/sipa/safegcd-bounds>
183#[inline(always)]
184#[allow(clippy::cast_sign_loss)]
185const fn jump_step(
186    mut f: i64,
187    mut g: i64,
188    mut delta: i64,
189    mut t: Matrix,
190) -> (i64, i64, i64, Matrix) {
191    let d_gtz = Choice::from_u64_nz((delta & !(delta >> 63)) as u64);
192    let g_odd = Choice::from_u64_lsb((g & 1) as u64);
193    let g_adj = g_odd.select_i64(0, f);
194    let swap = d_gtz.and(g_odd);
195    delta = swap.select_i64(2i64.wrapping_add(delta), 2i64.wrapping_sub(delta));
196    f = swap.select_i64(f, g);
197    g = swap.select_i64(g.wrapping_add(g_adj), g.wrapping_sub(g_adj)) >> 1;
198    t = [
199        [
200            swap.select_i64(t[0][0], t[1][0]) << 1,
201            swap.select_i64(t[0][1], t[1][1]) << 1,
202        ],
203        [
204            t[1][0].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][0], -t[0][0]))),
205            t[1][1].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][1], -t[0][1]))),
206        ],
207    ];
208    (f, g, delta, t)
209}
210
211/// Perform one step of the gcd reduction in variable time.
212#[inline(always)]
213const fn jump_step_vartime(
214    mut f: i64,
215    mut g: i64,
216    mut delta: i64,
217    mut t: Matrix,
218) -> (i64, i64, i64, Matrix) {
219    if (g & 1) != 0 {
220        (f, g, delta, t) = if delta > 0 {
221            (
222                g,
223                g.wrapping_sub(f),
224                2i64.wrapping_sub(delta),
225                [
226                    t[1],
227                    [t[1][0].wrapping_sub(t[0][0]), t[1][1].wrapping_sub(t[0][1])],
228                ],
229            )
230        } else {
231            (
232                f,
233                g.wrapping_add(f),
234                2i64.wrapping_add(delta),
235                [
236                    t[0],
237                    [t[1][0].wrapping_add(t[0][0]), t[1][1].wrapping_add(t[0][1])],
238                ],
239            )
240        };
241    } else {
242        delta = 2i64.wrapping_add(delta);
243    }
244    g >>= 1;
245    t[0][0] <<= 1;
246    t[0][1] <<= 1;
247    (f, g, delta, t)
248}
249
250#[inline]
251const fn update_fg<const LIMBS: usize>(
252    a: &SignedInt<LIMBS>,
253    b: &SignedInt<LIMBS>,
254    t: Matrix,
255    shift: u32,
256) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
257    (
258        SignedInt::lincomb_int_reduce_shift(
259            a,
260            b,
261            &I64::from_i64(t[0][0]),
262            &I64::from_i64(t[0][1]),
263            shift,
264        ),
265        SignedInt::lincomb_int_reduce_shift(
266            a,
267            b,
268            &I64::from_i64(t[1][0]),
269            &I64::from_i64(t[1][1]),
270            shift,
271        ),
272    )
273}
274
275#[inline]
276const fn update_de<const LIMBS: usize>(
277    d: &SignedInt<LIMBS>,
278    e: &SignedInt<LIMBS>,
279    m: &Uint<LIMBS>,
280    mi: u64,
281    t: Matrix,
282    shift: u32,
283) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
284    (
285        SignedInt::lincomb_int_reduce_shift_mod(
286            d,
287            e,
288            &Int::from_i64(t[0][0]),
289            &Int::from_i64(t[0][1]),
290            shift,
291            m,
292            U64::from_u64(mi),
293        ),
294        SignedInt::lincomb_int_reduce_shift_mod(
295            d,
296            e,
297            &Int::from_i64(t[1][0]),
298            &Int::from_i64(t[1][1]),
299            shift,
300            m,
301            U64::from_u64(mi),
302        ),
303    )
304}
305
306/// Conditionally negate a wide Uint represented by `(lo, hi)`.
307#[inline]
308const fn conditional_negate_in_place_wide<const L: usize, const H: usize>(
309    lo: &mut Uint<L>,
310    hi: &mut Uint<H>,
311    flag: Choice,
312) {
313    let (neg, carry) = lo.carrying_neg();
314    let hi_neg = hi
315        .not()
316        .wrapping_add(&Uint::select(&Uint::ZERO, &Uint::ONE, carry));
317    *lo = Uint::select(lo, &neg, flag);
318    *hi = Uint::select(hi, &hi_neg, flag);
319}
320
321/// Right shift a wide Uint represented by `(lo, hi)` returning any remaining high bits.
322#[inline]
323const fn shr_in_place_wide<const L: usize, const H: usize>(
324    lo: &mut Uint<L>,
325    hi: &mut Uint<H>,
326    shift: u32,
327) {
328    debug_assert!(H <= L);
329    debug_assert!(shift < Uint::<H>::BITS);
330    let copy = hi.shl_vartime(Uint::<H>::BITS - shift);
331    *hi = hi.shr_vartime(shift);
332    *lo = lo.shr_vartime(shift);
333    let mut offs = bitlen::to_limbs(shift);
334    lo.limbs[L - offs] = lo.limbs[L - offs].bitor(copy.limbs[H - offs]);
335    loop {
336        offs -= 1;
337        if offs == 0 {
338            break;
339        }
340        lo.limbs[L - offs] = copy.limbs[H - offs];
341    }
342}
343
344/// Calculate the maximum number of iterations required according to
345/// safegcd-bounds: <https://github.com/sipa/safegcd-bounds>
346// NOTE: the division is non-constant-time, but this is used to compute the number of iterations we
347// perform which is leaked in timing information
348#[inline]
349#[allow(clippy::integer_division_remainder_used, reason = "public parameter")]
350const fn iterations(bits: u32) -> u32 {
351    (45907 * bits + 30179) / 19929
352}
353
354/// A `Uint` which carries a separate sign in order to maintain the same range.
355#[derive(Clone, Copy)]
356struct SignedInt<const LIMBS: usize> {
357    sign: Choice,
358    magnitude: Uint<LIMBS>,
359}
360
361impl<const LIMBS: usize> SignedInt<LIMBS> {
362    pub const ZERO: Self = Self::from_uint(Uint::ZERO);
363
364    /// Construct a new `SignedInt` from a `Uint`.
365    pub const fn from_uint(uint: Uint<LIMBS>) -> Self {
366        Self {
367            sign: Choice::FALSE,
368            magnitude: uint,
369        }
370    }
371
372    /// Construct a new `SignedInt` from a `Uint` and a sign flag.
373    pub const fn from_uint_sign(magnitude: Uint<LIMBS>, sign: Choice) -> Self {
374        Self { sign, magnitude }
375    }
376
377    /// Obtain the magnitude of the `SignedInt`, ie. its absolute value.
378    pub const fn magnitude(&self) -> Uint<LIMBS> {
379        self.magnitude
380    }
381
382    /// Determine if the `SignedInt` is non-zero.
383    pub const fn is_nonzero(&self) -> Choice {
384        self.magnitude.is_nonzero()
385    }
386
387    /// Determine if the `SignedInt` is zero in variable time.
388    pub const fn is_zero_vartime(&self) -> bool {
389        self.magnitude.is_zero_vartime()
390    }
391
392    /// Determine if the `SignedInt` is negative.
393    /// Note: `-0` is representable in this type, so it may be necessary
394    /// to check `self.is_nonzero()` as well.
395    pub const fn is_negative(&self) -> Choice {
396        self.sign
397    }
398
399    /// Extract the lowest 63 bits and convert to its signed representation.
400    #[allow(clippy::cast_possible_wrap)]
401    pub const fn lowest(&self) -> i64 {
402        let mag = (self.magnitude.as_uint_ref().lowest_u64() & (u64::MAX >> 1)) as i64;
403        self.sign.select_i64(mag, mag.wrapping_neg())
404    }
405
406    /// Compute the linear combination `a•b + c•d`, returning `(lo, hi, sign)`.
407    #[inline]
408    pub(crate) const fn lincomb_int<const RHS: usize>(
409        a: &SignedInt<LIMBS>,
410        b: &SignedInt<LIMBS>,
411        c: &Int<RHS>,
412        d: &Int<RHS>,
413    ) -> (Uint<LIMBS>, Uint<RHS>, Choice) {
414        let (c, c_sign) = c.abs_sign();
415        let (d, d_sign) = d.abs_sign();
416        // Each SignedInt • abs(Int) product leaves an empty upper bit.
417        let (mut x, mut x_hi) = a.magnitude.widening_mul(&c);
418        let x_neg = a.sign.xor(c_sign);
419        let (mut y, mut y_hi) = b.magnitude.widening_mul(&d);
420        let y_neg = b.sign.xor(d_sign);
421        let odd_neg = x_neg.xor(y_neg);
422
423        // Negate y if none or both of the multiplication results are negative.
424        conditional_negate_in_place_wide(&mut y, &mut y_hi, odd_neg.not());
425
426        let mut borrow;
427        (x, borrow) = x.borrowing_sub(&y, Limb::ZERO);
428        (x_hi, borrow) = x_hi.borrowing_sub(&y_hi, borrow);
429        let swap = borrow.is_nonzero().and(odd_neg);
430
431        // Negate the result if we did not negate y and there was a borrow,
432        // indicating that |y| > |x|.
433        conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
434
435        let sign = x_neg.and(swap.not()).or(y_neg.and(swap));
436        (x, x_hi, sign)
437    }
438
439    /// Compute the linear combination `a•b + c•d`, and shift the result
440    /// `shift` bits to the right, returning a signed value in the same range
441    /// as the `SignedInt` inputs.
442    pub(crate) const fn lincomb_int_reduce_shift<const S: usize>(
443        a: &Self,
444        b: &Self,
445        c: &Int<S>,
446        d: &Int<S>,
447        shift: u32,
448    ) -> Self {
449        debug_assert!(shift < Uint::<S>::BITS);
450        let (mut a, mut a_hi, a_sign) = Self::lincomb_int(a, b, c, d);
451        shr_in_place_wide(&mut a, &mut a_hi, shift);
452        SignedInt::from_uint_sign(a, a_sign)
453    }
454
455    /// Compute the linear combination `a•b + c•d`, and shift the result
456    /// `shift` bits to the right modulo `m`, returning a signed value in the
457    /// same range as the `SignedInt` inputs.
458    pub(crate) const fn lincomb_int_reduce_shift_mod<const S: usize>(
459        a: &Self,
460        b: &Self,
461        c: &Int<S>,
462        d: &Int<S>,
463        shift: u32,
464        m: &Uint<LIMBS>,
465        mi: Uint<S>,
466    ) -> SignedInt<LIMBS> {
467        debug_assert!(shift < Uint::<S>::BITS);
468        let (mut x, mut x_hi, mut x_sign) = SignedInt::lincomb_int(a, b, c, d);
469
470        // Compute the multiple of m that will clear the low N bits of (x, x_hi).
471        let mut mf = x.resize::<S>().wrapping_mul(&mi);
472        mf = mf.bitand(&Uint::MAX.shr_vartime(Uint::<S>::BITS - shift));
473        let (xa, xa_hi) = m.widening_mul(&mf);
474
475        // Subtract the adjustment from (x, x_hi) potentially producing a borrow.
476        let mut borrow;
477        (x, borrow) = x.borrowing_sub(&xa, Limb::ZERO);
478        (x_hi, borrow) = x_hi.borrowing_sub(&xa_hi, borrow);
479
480        // Negate (x, x_hi) if the subtraction borrowed.
481        let swap = borrow.is_nonzero();
482        conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
483        x_sign = x_sign.xor(swap);
484
485        // Shift the result, eliminating the trailing zeros.
486        shr_in_place_wide(&mut x, &mut x_hi, shift);
487        debug_assert!(
488            x_hi.shr1().is_nonzero().not().to_bool_vartime(),
489            "overflow was larger than one bit"
490        );
491
492        // The magnitude x is now in the range [0, 2m). We conditionally subtract
493        // m in order to keep the output in (-m, m).
494        x = x.try_sub_with_carry(x_hi.limbs[0], m).0;
495
496        SignedInt::from_uint_sign(x, x_sign)
497    }
498
499    /// Normalize the value to a `Uint` in the range `[0, m)`.
500    const fn norm(&self, f_sign: Choice, m: &Uint<LIMBS>) -> Uint<LIMBS> {
501        let swap = f_sign.xor(self.sign).and(self.is_nonzero());
502        Uint::select(&self.magnitude, &m.wrapping_sub(&self.magnitude), swap)
503    }
504
505    /// Compare two `SignedInt` in constant time.
506    pub const fn eq(a: &Self, b: &Self) -> Choice {
507        Uint::eq(&a.magnitude, &b.magnitude).and(a.sign.eq(b.sign).or(a.is_nonzero().not()))
508    }
509}
510
511impl<const LIMBS: usize> fmt::Debug for SignedInt<LIMBS> {
512    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
513        f.write_fmt(format_args!(
514            "{}0x{}",
515            if self.sign.to_bool_vartime() {
516                "-"
517            } else {
518                "+"
519            },
520            &self.magnitude
521        ))
522    }
523}
524
525impl<const LIMBS: usize> PartialEq for SignedInt<LIMBS> {
526    fn eq(&self, other: &Self) -> bool {
527        Self::eq(self, other).to_bool_vartime()
528    }
529}
530
531#[cfg(test)]
532mod tests {
533    use super::SafeGcdInverter;
534    use crate::{U128, U256, modular::safegcd::shr_in_place_wide};
535
536    #[test]
537    fn invert() {
538        let g =
539            U256::from_be_hex("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
540        let modulus =
541            U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
542                .to_odd()
543                .unwrap();
544        let inverter = SafeGcdInverter::new(&modulus, &U256::ONE);
545        let result = inverter.invert(&g).unwrap();
546        assert_eq!(
547            U256::from_be_hex("FB668F8F509790BC549B077098918604283D42901C92981062EB48BC723F617B"),
548            result
549        );
550    }
551
552    #[test]
553    fn shr_wide() {
554        let hi = U128::from_u128(0x11111111222222223333333344444444);
555        let lo = U256::MAX;
556        let (mut a, mut a_hi) = (lo, hi);
557        shr_in_place_wide(&mut a, &mut a_hi, 16);
558        assert_eq!(
559            a,
560            U256::from_be_hex("4444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
561        );
562        assert_eq!(a_hi, U128::from_u128(0x1111111122222222333333334444));
563        let (mut b, mut b_hi) = (lo, hi);
564        shr_in_place_wide(&mut b, &mut b_hi, 68);
565        assert_eq!(
566            b,
567            U256::from_be_hex("23333333344444444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
568        );
569        assert_eq!(b_hi, U128::from_u128(0x111111112222222));
570    }
571}