Skip to main content

crypto_bigint/modular/
safegcd.rs

1//! Implementation of Bernstein-Yang modular inversion and GCD algorithm (a.k.a. safegcd)
2//! as described in: <https://eprint.iacr.org/2019/266>.
3//!
4//! Adapted from the Apache 2.0+MIT licensed implementation originally from:
5//! <https://github.com/taikoxyz/halo2curves/pull/2>
6//! <https://github.com/privacy-scaling-explorations/halo2curves/pull/83>
7//!
8//! Copyright (c) 2023 Privacy Scaling Explorations Team
9
10// TODO(tarcieri): optimized implementation for 32-bit platforms (#380)
11
12#[cfg(feature = "alloc")]
13pub(crate) mod boxed;
14
15use core::fmt;
16
17use crate::{Choice, CtOption, I64, Int, Limb, Odd, U64, Uint, primitives::u32_min};
18
/// Upper bound on the number of reduction steps performed per batch: each loop iteration of
/// the inversion/GCD routines runs `min(remaining steps, GCD_BATCH_SIZE)` steps through `jump`
/// before applying the resulting transition matrix to the full-width values.
const GCD_BATCH_SIZE: u32 = 62;
20
21/// Modular multiplicative inverter based on the Bernstein-Yang method.
22///
23/// The inverter can be created for a specified modulus M and adjusting parameter A to compute
24/// the adjusted multiplicative inverses of positive integers, i.e. for computing
25/// (1 / x) * A (mod M) for a positive integer x.
26///
27/// The adjusting parameter allows computing the multiplicative inverses in the case of using the
28/// Montgomery representation for the input or the expected output. If R is the Montgomery
29/// factor, the multiplicative inverses in the appropriate representation can be computed
30/// provided that the value of A is chosen as follows:
31/// - A = 1, if both the input and the expected output are in the standard form
32/// - A = R^2 mod M, if both the input and the expected output are in the Montgomery form
33/// - A = R mod M, if either the input or the expected output is in the Montgomery form,
34///   but not both of them
35///
36/// The public methods of this type receive and return unsigned big integers as arrays of 64-bit
37/// chunks, the ordering of which is little-endian. Both the modulus and the integer to be
38/// inverted should not exceed 2 ^ (62 * L - 64).
39///
40/// For better understanding the implementation, the following resources are recommended:
41/// - D. Bernstein, B.-Y. Yang, "Fast constant-time gcd computation and modular inversion",
42///   <https://gcd.cr.yp.to/safegcd-20190413.pdf>
43/// - P. Wuille, "The safegcd implementation in libsecp256k1 explained",
44///   <https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md>
#[derive(Clone, Debug)]
pub(crate) struct SafeGcdInverter<const LIMBS: usize> {
    /// Odd modulus `M` with respect to which inverses are computed.
    pub(super) modulus: Odd<Uint<LIMBS>>,

    /// Precomputed multiplicative inverse of the modulus modulo a power of two, handed to the
    /// inversion routine as its `mi` argument.
    // NOTE(review): historically documented as "modulo 2^62", but the test-only constructor
    // derives it with `invert_mod_u64()`, which suggests modulo 2^64 — confirm.
    inverse: u64,

    /// Adjusting parameter `A` (see toplevel documentation); the result of an inversion is
    /// `(1 / x) * A (mod M)`.
    adjuster: Uint<LIMBS>,
}
56
/// Type of the Bernstein-Yang transition matrix, stored as two rows of `i64` coefficients.
///
/// `jump` starts from the identity matrix and accumulates one reduction step at a time without
/// applying the per-step halving to the coefficients, so after a batch the matrix is scaled by
/// `2^batch` (i.e. `2^62` for a full batch of `GCD_BATCH_SIZE` steps). The callers undo that
/// scaling by shifting the updated values right by `batch` bits.
type Matrix = [[i64; 2]; 2];
59
60impl<const LIMBS: usize> SafeGcdInverter<LIMBS> {
61    /// Creates the inverter for specified modulus and adjusting parameter.
62    #[cfg(test)]
63    pub(crate) const fn new(modulus: &Odd<Uint<LIMBS>>, adjuster: &Uint<LIMBS>) -> Self {
64        Self::new_with_inverse(
65            modulus,
66            U64::from_u64(modulus.as_uint_ref().invert_mod_u64()),
67            adjuster,
68        )
69    }
70
71    #[inline]
72    pub(crate) const fn new_with_inverse(
73        modulus: &Odd<Uint<LIMBS>>,
74        inverse: U64,
75        adjuster: &Uint<LIMBS>,
76    ) -> Self {
77        Self {
78            modulus: *modulus,
79            inverse: inverse.as_uint_ref().lowest_u64(),
80            adjuster: *adjuster,
81        }
82    }
83
84    /// Returns either the adjusted modular multiplicative inverse for the argument or `None`
85    /// depending on invertibility of the argument, i.e. its coprimality with the modulus.
86    pub const fn invert(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
87        invert_odd_mod_precomp::<LIMBS, false>(value, &self.modulus, self.inverse, &self.adjuster)
88    }
89
90    /// Returns either the adjusted modular multiplicative inverse for the argument or `None`
91    /// depending on invertibility of the argument, i.e. its coprimality with the modulus.
92    ///
93    /// This version is variable-time with respect to `value`.
94    pub const fn invert_vartime(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
95        invert_odd_mod_precomp::<LIMBS, true>(value, &self.modulus, self.inverse, &self.adjuster)
96    }
97}
98
99#[inline]
100pub const fn invert_odd_mod<const LIMBS: usize, const VARTIME: bool>(
101    a: &Uint<LIMBS>,
102    m: &Odd<Uint<LIMBS>>,
103) -> CtOption<Uint<LIMBS>> {
104    let mi = m.as_uint_ref().invert_mod_u64();
105    invert_odd_mod_precomp::<LIMBS, VARTIME>(a, m, mi, &Uint::ONE)
106}
107
108/// Calculate the multiplicative inverse of `a` modulo `m`.
109const fn invert_odd_mod_precomp<const LIMBS: usize, const VARTIME: bool>(
110    a: &Uint<LIMBS>,
111    m: &Odd<Uint<LIMBS>>,
112    mi: u64,
113    e: &Uint<LIMBS>,
114) -> CtOption<Uint<LIMBS>> {
115    let a_nonzero = a.is_nonzero();
116    let (mut f, mut g) = (SignedInt::from_uint(*m.as_ref()), SignedInt::from_uint(*a));
117    let (mut d, mut e) = (SignedInt::<LIMBS>::ZERO, SignedInt::from_uint(*e));
118    let mut steps = iterations(Uint::<LIMBS>::BITS);
119    let mut delta = 1;
120    let mut t;
121
122    while steps > 0 {
123        if VARTIME && g.is_zero_vartime() {
124            break;
125        }
126        let batch = u32_min(steps, GCD_BATCH_SIZE);
127        (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
128        (f, g) = update_fg(&f, &g, t, batch);
129        (d, e) = update_de(&d, &e, m.as_ref(), mi, t, batch);
130        steps -= batch;
131    }
132
133    let d = d.norm(f.is_negative(), m.as_ref());
134    CtOption::new(d, Uint::eq(&f.magnitude, &Uint::ONE).and(a_nonzero))
135}
136
137/// Calculate the greatest common denominator of odd `f`, and `g`.
138pub const fn gcd_odd<const LIMBS: usize, const VARTIME: bool>(
139    f: &Odd<Uint<LIMBS>>,
140    g: &Uint<LIMBS>,
141) -> Odd<Uint<LIMBS>> {
142    let (mut f, mut g) = (SignedInt::from_uint(*f.as_ref()), SignedInt::from_uint(*g));
143    let mut steps = iterations(Uint::<LIMBS>::BITS);
144    let mut delta = 1;
145    let mut t;
146
147    while steps > 0 {
148        if VARTIME && g.is_zero_vartime() {
149            break;
150        }
151        let batch = u32_min(steps, GCD_BATCH_SIZE);
152        (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
153        (f, g) = update_fg(&f, &g, t, batch);
154        steps -= batch;
155    }
156
157    f.magnitude().to_odd().expect_copied("odd by construction")
158}
159
160/// Perform `batch` steps of the gcd reduction process on signed tail values `f` and `g`.
161#[inline]
162const fn jump<const VARTIME: bool>(
163    mut f: i64,
164    mut g: i64,
165    mut delta: i64,
166    mut batch: u32,
167) -> (i64, Matrix) {
168    debug_assert!(f & 1 == 1, "f must be odd");
169    let mut t = [[1i64, 0], [0, 1]];
170    while batch > 0 {
171        (f, g, delta, t) = if VARTIME {
172            jump_step_vartime(f, g, delta, t)
173        } else {
174            jump_step(f, g, delta, t)
175        };
176        batch -= 1;
177    }
178    (delta, t)
179}
180
181/// Perform one step of the gcd reduction in constant time.
182/// This follows the half-delta variant of safegcd-bounds which reduces the round count.
183/// <https://github.com/sipa/safegcd-bounds>
184#[inline(always)]
185#[allow(clippy::cast_sign_loss)]
186const fn jump_step(
187    mut f: i64,
188    mut g: i64,
189    mut delta: i64,
190    mut t: Matrix,
191) -> (i64, i64, i64, Matrix) {
192    let d_gtz = Choice::from_u64_nz((delta & !(delta >> 63)) as u64);
193    let g_odd = Choice::from_u64_lsb((g & 1) as u64);
194    let g_adj = g_odd.select_i64(0, f);
195    let swap = d_gtz.and(g_odd);
196    delta = swap.select_i64(2i64.wrapping_add(delta), 2i64.wrapping_sub(delta));
197    f = swap.select_i64(f, g);
198    g = swap.select_i64(g.wrapping_add(g_adj), g.wrapping_sub(g_adj)) >> 1;
199    t = [
200        [
201            swap.select_i64(t[0][0], t[1][0]) << 1,
202            swap.select_i64(t[0][1], t[1][1]) << 1,
203        ],
204        [
205            t[1][0].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][0], -t[0][0]))),
206            t[1][1].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][1], -t[0][1]))),
207        ],
208    ];
209    (f, g, delta, t)
210}
211
212/// Perform one step of the gcd reduction in variable time.
213#[inline(always)]
214const fn jump_step_vartime(
215    mut f: i64,
216    mut g: i64,
217    mut delta: i64,
218    mut t: Matrix,
219) -> (i64, i64, i64, Matrix) {
220    if (g & 1) != 0 {
221        (f, g, delta, t) = if delta > 0 {
222            (
223                g,
224                g.wrapping_sub(f),
225                2i64.wrapping_sub(delta),
226                [
227                    t[1],
228                    [t[1][0].wrapping_sub(t[0][0]), t[1][1].wrapping_sub(t[0][1])],
229                ],
230            )
231        } else {
232            (
233                f,
234                g.wrapping_add(f),
235                2i64.wrapping_add(delta),
236                [
237                    t[0],
238                    [t[1][0].wrapping_add(t[0][0]), t[1][1].wrapping_add(t[0][1])],
239                ],
240            )
241        };
242    } else {
243        delta = 2i64.wrapping_add(delta);
244    }
245    g >>= 1;
246    t[0][0] <<= 1;
247    t[0][1] <<= 1;
248    (f, g, delta, t)
249}
250
251#[inline]
252const fn update_fg<const LIMBS: usize>(
253    a: &SignedInt<LIMBS>,
254    b: &SignedInt<LIMBS>,
255    t: Matrix,
256    shift: u32,
257) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
258    (
259        SignedInt::lincomb_int_reduce_shift(
260            a,
261            b,
262            &I64::from_i64(t[0][0]),
263            &I64::from_i64(t[0][1]),
264            shift,
265        ),
266        SignedInt::lincomb_int_reduce_shift(
267            a,
268            b,
269            &I64::from_i64(t[1][0]),
270            &I64::from_i64(t[1][1]),
271            shift,
272        ),
273    )
274}
275
276#[inline]
277const fn update_de<const LIMBS: usize>(
278    d: &SignedInt<LIMBS>,
279    e: &SignedInt<LIMBS>,
280    m: &Uint<LIMBS>,
281    mi: u64,
282    t: Matrix,
283    shift: u32,
284) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
285    (
286        SignedInt::lincomb_int_reduce_shift_mod(
287            d,
288            e,
289            &Int::from_i64(t[0][0]),
290            &Int::from_i64(t[0][1]),
291            shift,
292            m,
293            U64::from_u64(mi),
294        ),
295        SignedInt::lincomb_int_reduce_shift_mod(
296            d,
297            e,
298            &Int::from_i64(t[1][0]),
299            &Int::from_i64(t[1][1]),
300            shift,
301            m,
302            U64::from_u64(mi),
303        ),
304    )
305}
306
307/// Conditionally negate a wide Uint represented by `(lo, hi)`.
308#[inline]
309const fn conditional_negate_in_place_wide<const L: usize, const H: usize>(
310    lo: &mut Uint<L>,
311    hi: &mut Uint<H>,
312    flag: Choice,
313) {
314    let (neg, carry) = lo.carrying_neg();
315    let hi_neg = hi
316        .not()
317        .wrapping_add(&Uint::select(&Uint::ZERO, &Uint::ONE, carry));
318    *lo = Uint::select(lo, &neg, flag);
319    *hi = Uint::select(hi, &hi_neg, flag);
320}
321
322/// Right shift a wide Uint represented by `(lo, hi)` returning any remaining high bits.
323#[inline]
324const fn shr_in_place_wide<const L: usize, const H: usize>(
325    lo: &mut Uint<L>,
326    hi: &mut Uint<H>,
327    shift: u32,
328) {
329    debug_assert!(H <= L);
330    debug_assert!(shift < Uint::<H>::BITS);
331    let copy = hi.shl_vartime(Uint::<H>::BITS - shift);
332    *hi = hi.shr_vartime(shift);
333    *lo = lo.shr_vartime(shift);
334    let mut offs = shift.div_ceil(Limb::BITS) as usize;
335    lo.limbs[L - offs] = lo.limbs[L - offs].bitor(copy.limbs[H - offs]);
336    loop {
337        offs -= 1;
338        if offs == 0 {
339            break;
340        }
341        lo.limbs[L - offs] = copy.limbs[H - offs];
342    }
343}
344
/// Calculate the maximum number of iterations required according to
/// safegcd-bounds: <https://github.com/sipa/safegcd-bounds>
///
/// Evaluates `(45907 * bits + 30179) / 19929`, rounded down. The intermediate value is
/// computed in 64 bits so that the multiplication cannot overflow for any `bits`; with 32-bit
/// arithmetic it would overflow once `bits` exceeds 93,557.
///
/// # Panics
/// If the resulting iteration count does not fit in a `u32`.
// NOTE: the division is non-constant-time, but this is used to compute the number of iterations we
// perform which is leaked in timing information
#[inline]
#[allow(clippy::integer_division_remainder_used, reason = "public parameter")]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_lossless,
    reason = "widening casts are lossless and the narrowing cast is guarded by an assertion"
)]
const fn iterations(bits: u32) -> u32 {
    let count = (45907 * (bits as u64) + 30179) / 19929;
    assert!(count <= u32::MAX as u64, "iteration count exceeds u32");
    count as u32
}
354
/// A `Uint` which carries a separate sign in order to maintain the same range.
///
/// The value is stored as a sign flag plus an absolute value, so both `+0` and `-0` are
/// representable.
#[derive(Clone, Copy)]
struct SignedInt<const LIMBS: usize> {
    /// Set when the value is negative.
    sign: Choice,
    /// Absolute value.
    magnitude: Uint<LIMBS>,
}
361
362impl<const LIMBS: usize> SignedInt<LIMBS> {
363    pub const ZERO: Self = Self::from_uint(Uint::ZERO);
364
365    /// Construct a new `SignedInt` from a `Uint`.
366    pub const fn from_uint(uint: Uint<LIMBS>) -> Self {
367        Self {
368            sign: Choice::FALSE,
369            magnitude: uint,
370        }
371    }
372
373    /// Construct a new `SignedInt` from a `Uint` and a sign flag.
374    pub const fn from_uint_sign(magnitude: Uint<LIMBS>, sign: Choice) -> Self {
375        Self { sign, magnitude }
376    }
377
378    /// Obtain the magnitude of the `SignedInt`, ie. its absolute value.
379    pub const fn magnitude(&self) -> Uint<LIMBS> {
380        self.magnitude
381    }
382
383    /// Determine if the `SignedInt` is non-zero.
384    pub const fn is_nonzero(&self) -> Choice {
385        self.magnitude.is_nonzero()
386    }
387
388    /// Determine if the `SignedInt` is zero in variable time.
389    pub const fn is_zero_vartime(&self) -> bool {
390        self.magnitude.is_zero_vartime()
391    }
392
393    /// Determine if the `SignedInt` is negative.
394    /// Note: `-0` is representable in this type, so it may be necessary
395    /// to check `self.is_nonzero()` as well.
396    pub const fn is_negative(&self) -> Choice {
397        self.sign
398    }
399
400    /// Extract the lowest 63 bits and convert to its signed representation.
401    #[allow(clippy::cast_possible_wrap)]
402    pub const fn lowest(&self) -> i64 {
403        let mag = (self.magnitude.as_uint_ref().lowest_u64() & (u64::MAX >> 1)) as i64;
404        self.sign.select_i64(mag, mag.wrapping_neg())
405    }
406
407    /// Compute the linear combination `a•b + c•d`, returning `(lo, hi, sign)`.
408    #[inline]
409    pub(crate) const fn lincomb_int<const RHS: usize>(
410        a: &SignedInt<LIMBS>,
411        b: &SignedInt<LIMBS>,
412        c: &Int<RHS>,
413        d: &Int<RHS>,
414    ) -> (Uint<LIMBS>, Uint<RHS>, Choice) {
415        let (c, c_sign) = c.abs_sign();
416        let (d, d_sign) = d.abs_sign();
417        // Each SignedInt • abs(Int) product leaves an empty upper bit.
418        let (mut x, mut x_hi) = a.magnitude.widening_mul(&c);
419        let x_neg = a.sign.xor(c_sign);
420        let (mut y, mut y_hi) = b.magnitude.widening_mul(&d);
421        let y_neg = b.sign.xor(d_sign);
422        let odd_neg = x_neg.xor(y_neg);
423
424        // Negate y if none or both of the multiplication results are negative.
425        conditional_negate_in_place_wide(&mut y, &mut y_hi, odd_neg.not());
426
427        let mut borrow;
428        (x, borrow) = x.borrowing_sub(&y, Limb::ZERO);
429        (x_hi, borrow) = x_hi.borrowing_sub(&y_hi, borrow);
430        let swap = borrow.is_nonzero().and(odd_neg);
431
432        // Negate the result if we did not negate y and there was a borrow,
433        // indicating that |y| > |x|.
434        conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
435
436        let sign = x_neg.and(swap.not()).or(y_neg.and(swap));
437        (x, x_hi, sign)
438    }
439
440    /// Compute the linear combination `a•b + c•d`, and shift the result
441    /// `shift` bits to the right, returning a signed value in the same range
442    /// as the `SignedInt` inputs.
443    pub(crate) const fn lincomb_int_reduce_shift<const S: usize>(
444        a: &Self,
445        b: &Self,
446        c: &Int<S>,
447        d: &Int<S>,
448        shift: u32,
449    ) -> Self {
450        debug_assert!(shift < Uint::<S>::BITS);
451        let (mut a, mut a_hi, a_sign) = Self::lincomb_int(a, b, c, d);
452        shr_in_place_wide(&mut a, &mut a_hi, shift);
453        SignedInt::from_uint_sign(a, a_sign)
454    }
455
456    /// Compute the linear combination `a•b + c•d`, and shift the result
457    /// `shift` bits to the right modulo `m`, returning a signed value in the
458    /// same range as the `SignedInt` inputs.
459    pub(crate) const fn lincomb_int_reduce_shift_mod<const S: usize>(
460        a: &Self,
461        b: &Self,
462        c: &Int<S>,
463        d: &Int<S>,
464        shift: u32,
465        m: &Uint<LIMBS>,
466        mi: Uint<S>,
467    ) -> SignedInt<LIMBS> {
468        debug_assert!(shift < Uint::<S>::BITS);
469        let (mut x, mut x_hi, mut x_sign) = SignedInt::lincomb_int(a, b, c, d);
470
471        // Compute the multiple of m that will clear the low N bits of (x, x_hi).
472        let mut mf = x.resize::<S>().wrapping_mul(&mi);
473        mf = mf.bitand(&Uint::MAX.shr_vartime(Uint::<S>::BITS - shift));
474        let (xa, xa_hi) = m.widening_mul(&mf);
475
476        // Subtract the adjustment from (x, x_hi) potentially producing a borrow.
477        let mut borrow;
478        (x, borrow) = x.borrowing_sub(&xa, Limb::ZERO);
479        (x_hi, borrow) = x_hi.borrowing_sub(&xa_hi, borrow);
480
481        // Negate (x, x_hi) if the subtraction borrowed.
482        let swap = borrow.is_nonzero();
483        conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
484        x_sign = x_sign.xor(swap);
485
486        // Shift the result, eliminating the trailing zeros.
487        shr_in_place_wide(&mut x, &mut x_hi, shift);
488        debug_assert!(
489            x_hi.shr1().is_nonzero().not().to_bool_vartime(),
490            "overflow was larger than one bit"
491        );
492
493        // The magnitude x is now in the range [0, 2m). We conditionally subtract
494        // m in order to keep the output in (-m, m).
495        x = x.try_sub_with_carry(x_hi.limbs[0], m).0;
496
497        SignedInt::from_uint_sign(x, x_sign)
498    }
499
500    /// Normalize the value to a `Uint` in the range `[0, m)`.
501    const fn norm(&self, f_sign: Choice, m: &Uint<LIMBS>) -> Uint<LIMBS> {
502        let swap = f_sign.xor(self.sign).and(self.is_nonzero());
503        Uint::select(&self.magnitude, &m.wrapping_sub(&self.magnitude), swap)
504    }
505
506    /// Compare two `SignedInt` in constant time.
507    pub const fn eq(a: &Self, b: &Self) -> Choice {
508        Uint::eq(&a.magnitude, &b.magnitude).and(a.sign.eq(b.sign).or(a.is_nonzero().not()))
509    }
510}
511
512impl<const LIMBS: usize> fmt::Debug for SignedInt<LIMBS> {
513    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514        f.write_fmt(format_args!(
515            "{}0x{}",
516            if self.sign.to_bool_vartime() {
517                "-"
518            } else {
519                "+"
520            },
521            &self.magnitude
522        ))
523    }
524}
525
526impl<const LIMBS: usize> PartialEq for SignedInt<LIMBS> {
527    fn eq(&self, other: &Self) -> bool {
528        Self::eq(self, other).to_bool_vartime()
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::SafeGcdInverter;
    use crate::{U128, U256, modular::safegcd::shr_in_place_wide};

    /// Inverts a fixed 256-bit value modulo a fixed odd 256-bit modulus (adjuster of one) and
    /// compares the result against a known answer.
    #[test]
    fn invert() {
        let g =
            U256::from_be_hex("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
        let modulus =
            U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
                .to_odd()
                .unwrap();
        let inverter = SafeGcdInverter::new(&modulus, &U256::ONE);
        let result = inverter.invert(&g).unwrap();
        assert_eq!(
            U256::from_be_hex("FB668F8F509790BC549B077098918604283D42901C92981062EB48BC723F617B"),
            result
        );
    }

    /// Checks the wide right shift for a shift smaller than 64 bits (16) and for one larger
    /// than 64 bits (68): the low bits of `hi` must reappear at the top of `lo`.
    #[test]
    fn shr_wide() {
        let hi = U128::from_u128(0x11111111222222223333333344444444);
        let lo = U256::MAX;
        // Shift by 16 bits: the low 16 bits of `hi` (0x4444) move into the top of `lo`.
        let (mut a, mut a_hi) = (lo, hi);
        shr_in_place_wide(&mut a, &mut a_hi, 16);
        assert_eq!(
            a,
            U256::from_be_hex("4444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
        );
        assert_eq!(a_hi, U128::from_u128(0x1111111122222222333333334444));
        // Shift by 68 bits: the low 68 bits of `hi` move into the top of `lo`.
        let (mut b, mut b_hi) = (lo, hi);
        shr_in_place_wide(&mut b, &mut b_hi, 68);
        assert_eq!(
            b,
            U256::from_be_hex("23333333344444444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
        );
        assert_eq!(b_hi, U128::from_u128(0x111111112222222));
    }
}