Skip to main content

crypto_bigint/modular/safegcd/
boxed.rs

//! Implementation of Bernstein-Yang modular inversion and GCD algorithm (a.k.a. safegcd)
//! as described in: <https://eprint.iacr.org/2019/266>.
//!
//! See parent module for more information.
5
6use super::{GCD_BATCH_SIZE, Matrix, iterations, jump};
7use crate::{
8    BoxedUint, Choice, ConcatenatingMul, CtAssign, CtOption, CtSelect, I64, Int, Limb, NonZero,
9    Odd, Resize, U64, Uint,
10    primitives::{u32_max, u32_min},
11};
12use core::fmt;
13
14/// Modular multiplicative inverter based on the Bernstein-Yang method.
15///
16/// See [`super::SafeGcdInverter`] for more information.
17#[derive(Clone, Debug)]
18pub(crate) struct BoxedSafeGcdInverter {
19    /// Modulus
20    pub(crate) modulus: Odd<BoxedUint>,
21
22    /// Multiplicative inverse of the modulus modulo 2^62
23    inverse: u64,
24
25    /// Adjusting parameter (see toplevel documentation).
26    adjuster: BoxedUint,
27}
28
29impl BoxedSafeGcdInverter {
30    /// Creates the inverter for specified modulus and adjusting parameter.
31    ///
32    /// Modulus must be odd. Returns `None` if it is not.
33    #[cfg(test)]
34    pub fn new(modulus: Odd<BoxedUint>, adjuster: BoxedUint) -> Self {
35        let inverse = U64::from_u64(modulus.as_uint_ref().invert_mod_u64());
36        Self::new_with_inverse(modulus, inverse, adjuster)
37    }
38
39    /// Creates the inverter for specified modulus and adjusting parameter.
40    ///
41    /// Modulus must be odd. Returns `None` if it is not.
42    pub(crate) fn new_with_inverse(
43        modulus: Odd<BoxedUint>,
44        inverse: U64,
45        mut adjuster: BoxedUint,
46    ) -> Self {
47        adjuster = adjuster.resize(modulus.bits_precision());
48        Self {
49            modulus,
50            inverse: inverse.as_uint_ref().lowest_u64(),
51            adjuster,
52        }
53    }
54
55    /// Perform constant-time modular inversion.
56    pub(crate) fn invert(&self, value: &BoxedUint) -> CtOption<BoxedUint> {
57        invert_odd_mod_precomp::<false>(
58            value,
59            &self.modulus,
60            self.inverse,
61            Some(self.adjuster.clone()),
62        )
63    }
64
65    /// Perform variable-time modular inversion.
66    pub(crate) fn invert_vartime(&self, value: &BoxedUint) -> CtOption<BoxedUint> {
67        invert_odd_mod_precomp::<true>(
68            value,
69            &self.modulus,
70            self.inverse,
71            Some(self.adjuster.clone()),
72        )
73    }
74}
75
76#[inline]
77pub fn invert_odd_mod<const VARTIME: bool>(
78    a: &BoxedUint,
79    m: &Odd<BoxedUint>,
80) -> CtOption<BoxedUint> {
81    let mi = m.as_uint_ref().invert_mod_u64();
82    invert_odd_mod_precomp::<VARTIME>(a, m, mi, None)
83}
84
85/// Calculate the multiplicative inverse of `a` modulo `m`.
86///
87fn invert_odd_mod_precomp<const VARTIME: bool>(
88    a: &BoxedUint,
89    m: &Odd<BoxedUint>,
90    mi: u64,
91    e: Option<BoxedUint>,
92) -> CtOption<BoxedUint> {
93    let a_nonzero = a.is_nonzero();
94    let bits_precision = u32_max(a.bits_precision(), m.as_ref().bits_precision());
95    let m = m.as_ref().resize(bits_precision);
96    let (mut f, mut g) = (
97        SignedBoxedInt::from_uint(m.clone()),
98        SignedBoxedInt::from_uint_with_precision(a, bits_precision),
99    );
100    let (mut d, mut e) = (
101        SignedBoxedInt::zero_with_precision(bits_precision),
102        SignedBoxedInt::from_uint(e.map_or_else(
103            || BoxedUint::one_with_precision(bits_precision),
104            |e| e.resize(bits_precision),
105        )),
106    );
107    let mut steps = iterations(bits_precision);
108    let mut delta = 1;
109    let mut t;
110
111    while steps > 0 {
112        if VARTIME && g.is_zero_vartime() {
113            break;
114        }
115        let batch = u32_min(steps, GCD_BATCH_SIZE);
116        (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
117        (f, g) = update_fg(&f, &g, t, batch);
118        (d, e) = update_de(&d, &e, &m, mi, t, batch);
119        steps -= batch;
120    }
121
122    let d = d
123        .norm(f.is_negative(), &m)
124        .resize_unchecked(a.bits_precision());
125
126    CtOption::new(d, f.magnitude().is_one() & a_nonzero)
127}
128
129/// Calculate the greatest common denominator of `f` and `g`.
130pub fn gcd<const VARTIME: bool>(f: &BoxedUint, g: &BoxedUint) -> BoxedUint {
131    if VARTIME {
132        if let Some(f_nz) = f.as_nz_vartime() {
133            gcd_nz::<VARTIME>(f_nz, g).get()
134        } else {
135            // gcd of (0, g) is g
136            g.clone()
137        }
138    } else {
139        let (f_nz, f_is_nonzero) = f.to_nz_or_one();
140        // gcd of (0, g) is g
141        let mut r = gcd_nz::<VARTIME>(&f_nz, g).get();
142        r.ct_assign(g, !f_is_nonzero);
143        r
144    }
145}
146
147/// Calculate the greatest common denominator of nonzero `f`, and `g`.
148pub fn gcd_nz<const VARTIME: bool>(f: &NonZero<BoxedUint>, g: &BoxedUint) -> NonZero<BoxedUint> {
149    if VARTIME {
150        if let Some(f_odd) = f.as_odd_vartime() {
151            return gcd_odd::<VARTIME>(f_odd, g).into_nz();
152        }
153    }
154
155    // Note the following two GCD identity rules:
156    // 1) gcd(2f, 2g) = 2•gcd(f, g), and
157    // 2) gcd(a, 2g) = gcd(f, g) if f is odd.
158    //
159    // Combined, these rules imply that
160    // 3) gcd(2^i•f, 2^j•g) = 2^k•gcd(f, g), with k = min(i, j).
161    //
162    // However, to save ourselves having to divide out 2^j, we also note that
163    // 4) 2^k•gcd(f, g) = 2^k•gcd(a, 2^j•b)
164
165    let i = f.as_ref().trailing_zeros();
166    let k = u32_min(i, g.trailing_zeros());
167
168    let f_odd = Odd::new_unchecked(f.as_ref().shr(i));
169    let mut r = gcd_odd::<VARTIME>(&f_odd, g).get();
170    r.shl_assign(k);
171    NonZero::new_unchecked(r)
172}
173
174/// Calculate the greatest common denominator of odd `f`, and `g`.
175pub fn gcd_odd<const VARTIME: bool>(f: &Odd<BoxedUint>, g: &BoxedUint) -> Odd<BoxedUint> {
176    let bits_precision = u32_max(f.as_ref().bits_precision(), g.bits_precision());
177    let (mut f, mut g) = (
178        SignedBoxedInt::from_uint_with_precision(f.as_ref(), bits_precision),
179        SignedBoxedInt::from_uint_with_precision(g, bits_precision),
180    );
181    let mut steps = iterations(bits_precision);
182    let mut delta = 1;
183    let mut t;
184
185    while steps > 0 {
186        if VARTIME && g.is_zero_vartime() {
187            break;
188        }
189        let batch = u32_min(steps, GCD_BATCH_SIZE);
190        (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
191        (f, g) = update_fg(&f, &g, t, batch);
192        steps -= batch;
193    }
194
195    f.magnitude()
196        .resize_unchecked(bits_precision)
197        .to_odd()
198        .expect("odd by construction")
199}
200
201#[inline]
202fn update_fg(
203    a: &SignedBoxedInt,
204    b: &SignedBoxedInt,
205    t: Matrix,
206    shift: u32,
207) -> (SignedBoxedInt, SignedBoxedInt) {
208    (
209        SignedBoxedInt::lincomb_int_reduce_shift(
210            a,
211            b,
212            &I64::from_i64(t[0][0]),
213            &I64::from_i64(t[0][1]),
214            shift,
215        ),
216        SignedBoxedInt::lincomb_int_reduce_shift(
217            a,
218            b,
219            &I64::from_i64(t[1][0]),
220            &I64::from_i64(t[1][1]),
221            shift,
222        ),
223    )
224}
225
226#[inline]
227fn update_de(
228    d: &SignedBoxedInt,
229    e: &SignedBoxedInt,
230    m: &BoxedUint,
231    mi: u64,
232    t: Matrix,
233    shift: u32,
234) -> (SignedBoxedInt, SignedBoxedInt) {
235    (
236        SignedBoxedInt::lincomb_int_reduce_shift_mod(
237            d,
238            e,
239            &Int::from_i64(t[0][0]),
240            &Int::from_i64(t[0][1]),
241            shift,
242            m,
243            U64::from_u64(mi),
244        ),
245        SignedBoxedInt::lincomb_int_reduce_shift_mod(
246            d,
247            e,
248            &Int::from_i64(t[1][0]),
249            &Int::from_i64(t[1][1]),
250            shift,
251            m,
252            U64::from_u64(mi),
253        ),
254    )
255}
256
257/// A `Uint` which carries a separate sign in order to maintain the same range.
258#[derive(Clone)]
259struct SignedBoxedInt {
260    sign: Choice,
261    magnitude: BoxedUint,
262}
263
264impl SignedBoxedInt {
265    pub fn zero_with_precision(bits_precision: u32) -> Self {
266        Self::from_uint(BoxedUint::zero_with_precision(bits_precision))
267    }
268
269    /// Construct a new `SignedInt` from a `Uint`.
270    pub const fn from_uint(uint: BoxedUint) -> Self {
271        Self {
272            sign: Choice::FALSE,
273            magnitude: uint,
274        }
275    }
276
277    /// Construct a new `SignedInt` from a `Uint`.
278    pub fn from_uint_with_precision(uint: &BoxedUint, bits_precision: u32) -> Self {
279        Self {
280            sign: Choice::FALSE,
281            magnitude: uint.resize(bits_precision),
282        }
283    }
284
285    /// Construct a new `SignedInt` from a `Uint` and a sign flag.
286    pub const fn from_uint_sign(magnitude: BoxedUint, sign: Choice) -> Self {
287        Self { sign, magnitude }
288    }
289
290    /// Obtain the magnitude of the `SignedInt`, i.e. its absolute value.
291    pub const fn magnitude(&self) -> &BoxedUint {
292        &self.magnitude
293    }
294
295    /// Determine if the `SignedInt` is non-zero.
296    pub fn is_nonzero(&self) -> Choice {
297        self.magnitude.is_nonzero()
298    }
299
300    /// Determine if the `SignedInt` is zero in variable time.
301    pub fn is_zero_vartime(&self) -> bool {
302        self.magnitude.is_zero_vartime()
303    }
304
305    /// Determine if the `SignedInt` is negative.
306    /// Note: `-0` is representable in this type, so it may be necessary
307    /// to check `self.is_nonzero()` as well.
308    pub const fn is_negative(&self) -> Choice {
309        self.sign
310    }
311
312    /// Extract the lowest 63 bits and convert to its signed representation.
313    #[allow(clippy::cast_possible_wrap)]
314    pub fn lowest(&self) -> i64 {
315        let mag = (self.magnitude.as_uint_ref().lowest_u64() & (u64::MAX >> 1)) as i64;
316        self.sign.select_i64(mag, mag.wrapping_neg())
317    }
318
319    /// Compute the linear combination `a•b + c•d`, returning a widened result.
320    #[inline]
321    pub(crate) fn lincomb_int<const RHS: usize>(
322        a: &Self,
323        b: &Self,
324        c: &Int<RHS>,
325        d: &Int<RHS>,
326    ) -> Self {
327        debug_assert!(a.magnitude.bits_precision() == b.magnitude.bits_precision());
328        let (c, c_sign) = c.abs_sign();
329        let (d, d_sign) = d.abs_sign();
330        // Each SignedBoxedInt • abs(Int) product leaves an empty upper bit.
331        let mut x = a.magnitude.concatenating_mul(c);
332        let x_neg = a.sign.xor(c_sign);
333        let mut y = b.magnitude.concatenating_mul(d);
334        let y_neg = b.sign.xor(d_sign);
335        let odd_neg = x_neg.xor(y_neg);
336
337        // Negate y if none or both of the multiplication results are negative.
338        y.conditional_wrapping_neg_assign(odd_neg.not());
339
340        let borrow;
341        (x, borrow) = x.borrowing_sub(&y, Limb::ZERO);
342        let swap = borrow.is_nonzero().and(odd_neg);
343
344        // Negate the result if we did not negate y and there was a borrow,
345        // indicating that |y| > |x|.
346        x.conditional_wrapping_neg_assign(swap);
347
348        let sign = x_neg.and(swap.not()).or(y_neg.and(swap));
349        Self::from_uint_sign(x, sign)
350    }
351
352    /// Compute the linear combination `a•b + c•d`, and shift the result
353    /// `shift` bits to the right, returning a signed value in the same range
354    /// as the `SignedInt` inputs.
355    pub(crate) fn lincomb_int_reduce_shift<const S: usize>(
356        a: &Self,
357        b: &Self,
358        c: &Int<S>,
359        d: &Int<S>,
360        shift: u32,
361    ) -> Self {
362        debug_assert!(shift < Uint::<S>::BITS);
363        let SignedBoxedInt {
364            sign,
365            mut magnitude,
366        } = Self::lincomb_int(a, b, c, d);
367        magnitude.shr_assign(shift);
368        Self::from_uint_sign(
369            magnitude.resize_unchecked(a.magnitude.bits_precision()),
370            sign,
371        )
372    }
373
374    /// Compute the linear combination `a•b + c•d`, and shift the result
375    /// `shift` bits to the right modulo `m`, returning a signed value in the
376    /// same range as the `SignedInt` inputs.
377    pub(crate) fn lincomb_int_reduce_shift_mod<const S: usize>(
378        a: &Self,
379        b: &Self,
380        c: &Int<S>,
381        d: &Int<S>,
382        shift: u32,
383        m: &BoxedUint,
384        mi: Uint<S>,
385    ) -> Self {
386        debug_assert!(shift < Uint::<S>::BITS);
387        let SignedBoxedInt {
388            sign: mut x_sign,
389            magnitude: mut x,
390        } = Self::lincomb_int(a, b, c, d);
391
392        // Compute the multiple of m that will clear the low N bits of x.
393        let mut xs = Uint::<S>::ZERO;
394        xs.limbs.copy_from_slice(&x.limbs[..S]);
395        let mut mf = xs.wrapping_mul(&mi);
396        mf = mf.bitand(&Uint::MAX.shr_vartime(Uint::<S>::BITS - shift));
397        let xa = m.concatenating_mul(mf);
398
399        // Subtract the adjustment from x potentially producing a borrow.
400        let borrow = x.borrowing_sub_assign(&xa, Limb::ZERO);
401
402        // Negate x if the subtraction borrowed.
403        let swap = borrow.is_nonzero();
404        x.conditional_wrapping_neg_assign(swap);
405        x_sign = x_sign.xor(swap);
406
407        // Shift the result, eliminating the trailing zeros.
408        x.shr_assign(shift);
409
410        // The magnitude x is now in the range [0, 2m). We conditionally subtract
411        // m in order to keep the output in (-m, m).
412        let x_hi = x.limbs[m.nlimbs()];
413        x = x.resize_unchecked(m.bits_precision());
414        x.sub_assign_mod_with_carry(x_hi, m, m);
415
416        Self::from_uint_sign(x, x_sign)
417    }
418
419    /// Normalize the value to a `BoxedUint` in the range `[0, m)`.
420    fn norm(&self, f_sign: Choice, m: &BoxedUint) -> BoxedUint {
421        let swap = f_sign.xor(self.sign) & self.is_nonzero();
422        BoxedUint::ct_select(&self.magnitude, &m.wrapping_sub(&self.magnitude), swap)
423    }
424}
425
426impl fmt::Debug for SignedBoxedInt {
427    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
428        f.write_fmt(format_args!(
429            "{}0x{}",
430            if self.sign.to_bool_vartime() {
431                "-"
432            } else {
433                "+"
434            },
435            &self.magnitude
436        ))
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::BoxedSafeGcdInverter;
    use crate::BoxedUint;

    /// Parse a big-endian hex string into a 256-bit `BoxedUint`, panicking if
    /// parsing fails.
    fn hex_256(hex: &str) -> BoxedUint {
        BoxedUint::from_be_hex(hex, 256).unwrap()
    }

    /// Inverting a fixed value modulo a fixed odd 256-bit modulus yields the
    /// expected inverse.
    #[test]
    fn invert() {
        let value = hex_256("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
        let modulus = hex_256("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
            .to_odd()
            .unwrap();
        let expected = hex_256("FB668F8F509790BC549B077098918604283D42901C92981062EB48BC723F617B");

        let inverter = BoxedSafeGcdInverter::new(modulus, BoxedUint::one());
        let result = inverter.invert(&value).unwrap();

        assert_eq!(expected, result);
    }
}