Skip to main content

crypto_bigint/modular/safegcd/
boxed.rs

1//! Implementation of Bernstein-Yang modular inversion and GCD algorithm (a.k.a. safegcd)
2//! as described in: <https://eprint.iacr.org/2019/266>.
3//!
4//! See parent module for more information.
5
6use super::{GCD_BATCH_SIZE, Matrix, iterations, jump};
7use crate::{
8    BoxedUint, Choice, ConcatenatingMul, CtAssign, CtOption, CtSelect, I64, Int, Limb, NonZero,
9    Odd, Resize, U64, Uint,
10    primitives::{u32_max, u32_min},
11};
12use core::fmt;
13
14/// Modular multiplicative inverter based on the Bernstein-Yang method.
15///
16/// See [`super::SafeGcdInverter`] for more information.
17#[derive(Clone, Debug)]
18pub(crate) struct BoxedSafeGcdInverter {
19    /// Modulus
20    pub(crate) modulus: Odd<BoxedUint>,
21
22    /// Multiplicative inverse of the modulus modulo 2^62
23    inverse: u64,
24
25    /// Adjusting parameter (see toplevel documentation).
26    adjuster: BoxedUint,
27}
28
29impl BoxedSafeGcdInverter {
30    /// Creates the inverter for specified modulus and adjusting parameter.
31    ///
32    /// Modulus must be odd. Returns `None` if it is not.
33    #[cfg(test)]
34    pub fn new(modulus: Odd<BoxedUint>, adjuster: BoxedUint) -> Self {
35        let inverse = U64::from_u64(modulus.as_uint_ref().invert_mod_u64());
36        Self::new_with_inverse(modulus, inverse, adjuster)
37    }
38
39    /// Creates the inverter for specified modulus and adjusting parameter.
40    ///
41    /// Modulus must be odd. Returns `None` if it is not.
42    pub(crate) fn new_with_inverse(
43        modulus: Odd<BoxedUint>,
44        inverse: U64,
45        mut adjuster: BoxedUint,
46    ) -> Self {
47        adjuster = adjuster.resize(modulus.bits_precision());
48        Self {
49            modulus,
50            inverse: inverse.as_uint_ref().lowest_u64(),
51            adjuster,
52        }
53    }
54
55    /// Perform constant-time modular inversion.
56    pub(crate) fn invert(&self, value: &BoxedUint) -> CtOption<BoxedUint> {
57        invert_odd_mod_precomp::<false>(
58            value,
59            &self.modulus,
60            self.inverse,
61            Some(self.adjuster.clone()),
62        )
63    }
64
65    /// Perform variable-time modular inversion.
66    pub(crate) fn invert_vartime(&self, value: &BoxedUint) -> CtOption<BoxedUint> {
67        invert_odd_mod_precomp::<true>(
68            value,
69            &self.modulus,
70            self.inverse,
71            Some(self.adjuster.clone()),
72        )
73    }
74}
75
76#[inline]
77pub fn invert_odd_mod<const VARTIME: bool>(
78    a: &BoxedUint,
79    m: &Odd<BoxedUint>,
80) -> CtOption<BoxedUint> {
81    let mi = m.as_uint_ref().invert_mod_u64();
82    invert_odd_mod_precomp::<VARTIME>(a, m, mi, None)
83}
84
85/// Calculate the multiplicative inverse of `a` modulo `m`.
86///
87fn invert_odd_mod_precomp<const VARTIME: bool>(
88    a: &BoxedUint,
89    m: &Odd<BoxedUint>,
90    mi: u64,
91    e: Option<BoxedUint>,
92) -> CtOption<BoxedUint> {
93    let a_nonzero = a.is_nonzero();
94    let bits_precision = u32_max(a.bits_precision(), m.as_ref().bits_precision());
95    let m = m.as_ref().resize(bits_precision);
96    let (mut f, mut g) = (
97        SignedBoxedInt::from_uint(m.clone()),
98        SignedBoxedInt::from_uint_with_precision(a, bits_precision),
99    );
100    let (mut d, mut e) = (
101        SignedBoxedInt::zero_with_precision(bits_precision),
102        SignedBoxedInt::from_uint(e.map_or_else(
103            || BoxedUint::one_with_precision(bits_precision),
104            |e| e.resize(bits_precision),
105        )),
106    );
107    let mut steps = iterations(bits_precision);
108    let mut delta = 1;
109    let mut t;
110
111    while steps > 0 {
112        if VARTIME && g.is_zero_vartime() {
113            break;
114        }
115        let batch = u32_min(steps, GCD_BATCH_SIZE);
116        (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
117        (f, g) = update_fg(&f, &g, t, batch);
118        (d, e) = update_de(&d, &e, &m, mi, t, batch);
119        steps -= batch;
120    }
121
122    let d = d
123        .norm(f.is_negative(), &m)
124        .resize_unchecked(a.bits_precision());
125
126    CtOption::new(d, f.magnitude().is_one() & a_nonzero)
127}
128
129/// Calculate the greatest common denominator of `f` and `g`.
130pub fn gcd<const VARTIME: bool>(f: &BoxedUint, g: &BoxedUint) -> BoxedUint {
131    let f_is_zero = f.is_zero();
132
133    // Note: is non-zero by construction
134    let f_nz = NonZero(BoxedUint::ct_select(
135        f,
136        &BoxedUint::one_with_precision(f.bits_precision()),
137        f_is_zero,
138    ));
139
140    // gcd of (0, g) is g
141    let mut r = gcd_nz::<VARTIME>(&f_nz, g).0;
142    r.ct_assign(g, f_is_zero);
143    r
144}
145
146/// Calculate the greatest common denominator of nonzero `f`, and `g`.
147pub fn gcd_nz<const VARTIME: bool>(f: &NonZero<BoxedUint>, g: &BoxedUint) -> NonZero<BoxedUint> {
148    // Note the following two GCD identity rules:
149    // 1) gcd(2f, 2g) = 2•gcd(f, g), and
150    // 2) gcd(a, 2g) = gcd(f, g) if f is odd.
151    //
152    // Combined, these rules imply that
153    // 3) gcd(2^i•f, 2^j•g) = 2^k•gcd(f, g), with k = min(i, j).
154    //
155    // However, to save ourselves having to divide out 2^j, we also note that
156    // 4) 2^k•gcd(f, g) = 2^k•gcd(a, 2^j•b)
157
158    let i = f.as_ref().trailing_zeros();
159    let k = u32_min(i, g.trailing_zeros());
160
161    let f_odd = Odd(f.as_ref().shr(i));
162    let mut r = gcd_odd::<VARTIME>(&f_odd, g).0;
163    r.shl_assign(k);
164    NonZero(r)
165}
166
167/// Calculate the greatest common denominator of odd `f`, and `g`.
168pub fn gcd_odd<const VARTIME: bool>(f: &Odd<BoxedUint>, g: &BoxedUint) -> Odd<BoxedUint> {
169    let bits_precision = u32_max(f.as_ref().bits_precision(), g.bits_precision());
170    let (mut f, mut g) = (
171        SignedBoxedInt::from_uint_with_precision(f.as_ref(), bits_precision),
172        SignedBoxedInt::from_uint_with_precision(g, bits_precision),
173    );
174    let mut steps = iterations(bits_precision);
175    let mut delta = 1;
176    let mut t;
177
178    while steps > 0 {
179        if VARTIME && g.is_zero_vartime() {
180            break;
181        }
182        let batch = u32_min(steps, GCD_BATCH_SIZE);
183        (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
184        (f, g) = update_fg(&f, &g, t, batch);
185        steps -= batch;
186    }
187
188    f.magnitude()
189        .resize_unchecked(bits_precision)
190        .to_odd()
191        .expect("odd by construction")
192}
193
194#[inline]
195fn update_fg(
196    a: &SignedBoxedInt,
197    b: &SignedBoxedInt,
198    t: Matrix,
199    shift: u32,
200) -> (SignedBoxedInt, SignedBoxedInt) {
201    (
202        SignedBoxedInt::lincomb_int_reduce_shift(
203            a,
204            b,
205            &I64::from_i64(t[0][0]),
206            &I64::from_i64(t[0][1]),
207            shift,
208        ),
209        SignedBoxedInt::lincomb_int_reduce_shift(
210            a,
211            b,
212            &I64::from_i64(t[1][0]),
213            &I64::from_i64(t[1][1]),
214            shift,
215        ),
216    )
217}
218
219#[inline]
220fn update_de(
221    d: &SignedBoxedInt,
222    e: &SignedBoxedInt,
223    m: &BoxedUint,
224    mi: u64,
225    t: Matrix,
226    shift: u32,
227) -> (SignedBoxedInt, SignedBoxedInt) {
228    (
229        SignedBoxedInt::lincomb_int_reduce_shift_mod(
230            d,
231            e,
232            &Int::from_i64(t[0][0]),
233            &Int::from_i64(t[0][1]),
234            shift,
235            m,
236            U64::from_u64(mi),
237        ),
238        SignedBoxedInt::lincomb_int_reduce_shift_mod(
239            d,
240            e,
241            &Int::from_i64(t[1][0]),
242            &Int::from_i64(t[1][1]),
243            shift,
244            m,
245            U64::from_u64(mi),
246        ),
247    )
248}
249
250/// A `Uint` which carries a separate sign in order to maintain the same range.
251#[derive(Clone)]
252struct SignedBoxedInt {
253    sign: Choice,
254    magnitude: BoxedUint,
255}
256
257impl SignedBoxedInt {
258    pub fn zero_with_precision(bits_precision: u32) -> Self {
259        Self::from_uint(BoxedUint::zero_with_precision(bits_precision))
260    }
261
262    /// Construct a new `SignedInt` from a `Uint`.
263    pub const fn from_uint(uint: BoxedUint) -> Self {
264        Self {
265            sign: Choice::FALSE,
266            magnitude: uint,
267        }
268    }
269
270    /// Construct a new `SignedInt` from a `Uint`.
271    pub fn from_uint_with_precision(uint: &BoxedUint, bits_precision: u32) -> Self {
272        Self {
273            sign: Choice::FALSE,
274            magnitude: uint.resize(bits_precision),
275        }
276    }
277
278    /// Construct a new `SignedInt` from a `Uint` and a sign flag.
279    pub const fn from_uint_sign(magnitude: BoxedUint, sign: Choice) -> Self {
280        Self { sign, magnitude }
281    }
282
283    /// Obtain the magnitude of the `SignedInt`, i.e. its absolute value.
284    pub const fn magnitude(&self) -> &BoxedUint {
285        &self.magnitude
286    }
287
288    /// Determine if the `SignedInt` is non-zero.
289    pub fn is_nonzero(&self) -> Choice {
290        self.magnitude.is_nonzero()
291    }
292
293    /// Determine if the `SignedInt` is zero in variable time.
294    pub fn is_zero_vartime(&self) -> bool {
295        self.magnitude.is_zero_vartime()
296    }
297
298    /// Determine if the `SignedInt` is negative.
299    /// Note: `-0` is representable in this type, so it may be necessary
300    /// to check `self.is_nonzero()` as well.
301    pub const fn is_negative(&self) -> Choice {
302        self.sign
303    }
304
305    /// Extract the lowest 63 bits and convert to its signed representation.
306    #[allow(clippy::cast_possible_wrap)]
307    pub fn lowest(&self) -> i64 {
308        let mag = (self.magnitude.as_uint_ref().lowest_u64() & (u64::MAX >> 1)) as i64;
309        self.sign.select_i64(mag, mag.wrapping_neg())
310    }
311
312    /// Compute the linear combination `a•b + c•d`, returning a widened result.
313    #[inline]
314    pub(crate) fn lincomb_int<const RHS: usize>(
315        a: &Self,
316        b: &Self,
317        c: &Int<RHS>,
318        d: &Int<RHS>,
319    ) -> Self {
320        debug_assert!(a.magnitude.bits_precision() == b.magnitude.bits_precision());
321        let (c, c_sign) = c.abs_sign();
322        let (d, d_sign) = d.abs_sign();
323        // Each SignedBoxedInt • abs(Int) product leaves an empty upper bit.
324        let mut x = a.magnitude.concatenating_mul(c);
325        let x_neg = a.sign.xor(c_sign);
326        let mut y = b.magnitude.concatenating_mul(d);
327        let y_neg = b.sign.xor(d_sign);
328        let odd_neg = x_neg.xor(y_neg);
329
330        // Negate y if none or both of the multiplication results are negative.
331        y.conditional_wrapping_neg_assign(odd_neg.not());
332
333        let borrow;
334        (x, borrow) = x.borrowing_sub(&y, Limb::ZERO);
335        let swap = borrow.is_nonzero().and(odd_neg);
336
337        // Negate the result if we did not negate y and there was a borrow,
338        // indicating that |y| > |x|.
339        x.conditional_wrapping_neg_assign(swap);
340
341        let sign = x_neg.and(swap.not()).or(y_neg.and(swap));
342        Self::from_uint_sign(x, sign)
343    }
344
345    /// Compute the linear combination `a•b + c•d`, and shift the result
346    /// `shift` bits to the right, returning a signed value in the same range
347    /// as the `SignedInt` inputs.
348    pub(crate) fn lincomb_int_reduce_shift<const S: usize>(
349        a: &Self,
350        b: &Self,
351        c: &Int<S>,
352        d: &Int<S>,
353        shift: u32,
354    ) -> Self {
355        debug_assert!(shift < Uint::<S>::BITS);
356        let SignedBoxedInt {
357            sign,
358            mut magnitude,
359        } = Self::lincomb_int(a, b, c, d);
360        magnitude.shr_assign(shift);
361        Self::from_uint_sign(
362            magnitude.resize_unchecked(a.magnitude.bits_precision()),
363            sign,
364        )
365    }
366
367    /// Compute the linear combination `a•b + c•d`, and shift the result
368    /// `shift` bits to the right modulo `m`, returning a signed value in the
369    /// same range as the `SignedInt` inputs.
370    pub(crate) fn lincomb_int_reduce_shift_mod<const S: usize>(
371        a: &Self,
372        b: &Self,
373        c: &Int<S>,
374        d: &Int<S>,
375        shift: u32,
376        m: &BoxedUint,
377        mi: Uint<S>,
378    ) -> Self {
379        debug_assert!(shift < Uint::<S>::BITS);
380        let SignedBoxedInt {
381            sign: mut x_sign,
382            magnitude: mut x,
383        } = Self::lincomb_int(a, b, c, d);
384
385        // Compute the multiple of m that will clear the low N bits of x.
386        let mut xs = Uint::<S>::ZERO;
387        xs.limbs.copy_from_slice(&x.limbs[..S]);
388        let mut mf = xs.wrapping_mul(&mi);
389        mf = mf.bitand(&Uint::MAX.shr_vartime(Uint::<S>::BITS - shift));
390        let xa = m.concatenating_mul(mf);
391
392        // Subtract the adjustment from x potentially producing a borrow.
393        let borrow = x.borrowing_sub_assign(&xa, Limb::ZERO);
394
395        // Negate x if the subtraction borrowed.
396        let swap = borrow.is_nonzero();
397        x.conditional_wrapping_neg_assign(swap);
398        x_sign = x_sign.xor(swap);
399
400        // Shift the result, eliminating the trailing zeros.
401        x.shr_assign(shift);
402
403        // The magnitude x is now in the range [0, 2m). We conditionally subtract
404        // m in order to keep the output in (-m, m).
405        let x_hi = x.limbs[m.nlimbs()];
406        x = x.resize_unchecked(m.bits_precision());
407        x.sub_assign_mod_with_carry(x_hi, m, m);
408
409        Self::from_uint_sign(x, x_sign)
410    }
411
412    /// Normalize the value to a `BoxedUint` in the range `[0, m)`.
413    fn norm(&self, f_sign: Choice, m: &BoxedUint) -> BoxedUint {
414        let swap = f_sign.xor(self.sign) & self.is_nonzero();
415        BoxedUint::ct_select(&self.magnitude, &m.wrapping_sub(&self.magnitude), swap)
416    }
417}
418
419impl fmt::Debug for SignedBoxedInt {
420    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
421        f.write_fmt(format_args!(
422            "{}0x{}",
423            if self.sign.to_bool_vartime() {
424                "-"
425            } else {
426                "+"
427            },
428            &self.magnitude
429        ))
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::BoxedSafeGcdInverter;
    use crate::BoxedUint;

    /// Parses a big-endian hex constant as a 256-bit `BoxedUint`.
    fn hex256(hex: &str) -> BoxedUint {
        BoxedUint::from_be_hex(hex, 256).unwrap()
    }

    #[test]
    fn invert() {
        let value = hex256("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
        let modulus = hex256("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
            .to_odd()
            .unwrap();
        let expected =
            hex256("FB668F8F509790BC549B077098918604283D42901C92981062EB48BC723F617B");

        let inverter = BoxedSafeGcdInverter::new(modulus, BoxedUint::one());
        let result = inverter.invert(&value).unwrap();
        assert_eq!(expected, result);
    }
}