Skip to main content

crypto_bigint/uint/
invert_mod.rs

1use super::Uint;
2use crate::{
3    Choice, CtOption, InvertMod, Limb, NonZero, Odd, U64, UintRef, bitlen, modular::safegcd,
4    mul::karatsuba,
5};
6
7/// Perform a modified recursive Hensel quadratic modular inversion to calculate
8/// `a^-1 mod w^p` given `a^-1 mod w^k` where `w` is the size of `Limb`.
9/// For reference see Algorithm 2: <https://arxiv.org/pdf/1209.6626>
10///
11/// `p` is determined by the length of the in-out buffer `buf`, which must be
12/// pre-populated with `a^-1 mod w^k` (constituting `k` limbs).
13///
14/// This method uses recursion, but the maximum depth is limited by
15/// the bit-width of the number of limbs being inverted (`p`).
16///
17/// This method is variable time in `k` and `p` only.
18///
19/// `scratch` must be a pair of mutable buffers, each with capacity at least `p`.
20#[inline]
21pub(crate) const fn expand_invert_mod2k(
22    a: &Odd<UintRef>,
23    buf: &mut UintRef,
24    mut k: usize,
25    scratch: (&mut UintRef, &mut UintRef),
26) {
27    assert!(k > 0);
28    let p = buf.nlimbs();
29    let zs = p.trailing_zeros();
30
31    // Calculate a target width at which we may need to trim the output of
32    // the doubling loop. We reduce the size of `p` by eliminating multiple factors
33    // of two or a single odd factor, recursing until the target width is small enough
34    // to calculate by doubling without significant overhead.
35    let mut target = if zs > 0 { p >> zs } else { p.div_ceil(2) };
36    if target > 8 {
37        expand_invert_mod2k(a, buf.leading_mut(target), k, (scratch.0, scratch.1));
38        k = target;
39        target = p;
40    } else if target <= k {
41        target = p;
42    }
43
44    // Perform the required number of doublings.
45    while k < p {
46        let mut k2 = k * 2;
47        // `target` represents the point at which we may need to trim the output before
48        // continuing with the doubling until we reach `p`.
49        if k2 >= target {
50            (k2, target) = (target, p);
51        }
52        expand_invert_mod2k_step(a, buf.leading_mut(k2), k, (scratch.0, scratch.1));
53        k = k2;
54    }
55}
56
57/// One step of the Hensel quadratic modular inverse calculation, doubling the width
58/// of the inverted output, and wrapping at capacity of `buf`.
59#[inline(always)]
60const fn expand_invert_mod2k_step(
61    a: &Odd<UintRef>,
62    buf: &mut UintRef,
63    buf_init_len: usize,
64    scratch: (&mut UintRef, &mut UintRef),
65) {
66    let new_len = buf.nlimbs();
67
68    assert!(
69        scratch.0.nlimbs() >= new_len
70            && scratch.1.nlimbs() >= new_len
71            && buf_init_len < new_len
72            && buf_init_len >= (new_len >> 1)
73    );
74
75    // Calculate u0^2, wrapping at `new_len` words
76    let u0_p2 = scratch.0.leading_mut(new_len);
77    u0_p2.fill(Limb::ZERO);
78    karatsuba::wrapping_square(buf.leading(buf_init_len), u0_p2);
79
80    // tmp = u0^2•a
81    let tmp = scratch.1.leading_mut(new_len);
82    tmp.fill(Limb::ZERO);
83    karatsuba::wrapping_mul(u0_p2, a.as_ref(), tmp, false);
84
85    // u1 = u0 << 1
86    buf.shl1_assign();
87    // u1 -= u0^2•a
88    buf.borrowing_sub_assign(tmp, Limb::ZERO);
89}
90
91impl<const LIMBS: usize> Uint<LIMBS> {
92    /// Computes 1/`self` mod `2^k`.
93    /// This method is constant-time w.r.t. `self` but not `k`.
94    ///
95    /// If the inverse does not exist (`k > 0` and `self` is even),
96    /// returns `Choice::FALSE` as the second element of the tuple,
97    /// otherwise returns `Choice::TRUE`.
98    #[deprecated(since = "0.7.0", note = "please use `invert_mod2k_vartime` instead")]
99    #[must_use]
100    pub const fn inv_mod2k_vartime(&self, k: u32) -> CtOption<Self> {
101        self.invert_mod2k_vartime(k)
102    }
103
104    /// Computes 1/`self` mod `2^k`.
105    /// This method is constant-time w.r.t. `self` but not `k`.
106    ///
107    /// If the inverse does not exist (`k > 0` and `self` is even, or `k > Self::BITS`),
108    /// returns `CtOption::none`, otherwise returns `CtOption::some`.
109    #[inline]
110    #[must_use]
111    pub const fn invert_mod2k_vartime(&self, k: u32) -> CtOption<Self> {
112        if k == 0 {
113            CtOption::some(Self::ZERO)
114        } else if k > Self::BITS {
115            CtOption::new(Self::ZERO, Choice::FALSE)
116        } else {
117            let (self_odd, is_some) = self.to_odd_or_one();
118            let inv = self_odd.invert_mod2k_vartime(k);
119            CtOption::new(inv, is_some)
120        }
121    }
122
123    /// Computes 1/`self` mod `2^k`.
124    ///
125    /// If the inverse does not exist (`k > 0` and `self` is even, `k > Self::BITS`),
126    /// returns `CtOption::none`, otherwise returns `CtOption::some`.
127    #[deprecated(since = "0.7.0", note = "please use `invert_mod2k` instead")]
128    #[must_use]
129    pub const fn inv_mod2k(&self, k: u32) -> CtOption<Self> {
130        self.invert_mod2k(k)
131    }
132
133    /// Computes 1/`self` mod `2^k`.
134    ///
135    /// If the inverse does not exist (`k > 0` and `self` is even, or `k > Self::BITS`),
136    /// returns `CtOption::none`, otherwise returns `CtOption::some`.
137    #[inline]
138    #[must_use]
139    pub const fn invert_mod2k(&self, k: u32) -> CtOption<Self> {
140        let (odd, is_odd) = self.to_odd_or_one();
141        let is_some =
142            Choice::from_u32_le(k, Self::BITS).and(Choice::from_u32_nz(k).not().or(is_odd));
143        let inv = odd.invert_mod_precision();
144        CtOption::new(inv.restrict_bits(k), is_some)
145    }
146
147    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
148    #[deprecated(since = "0.7.0", note = "please use `invert_odd_mod` instead")]
149    #[must_use]
150    pub const fn inv_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
151        self.invert_odd_mod(modulus)
152    }
153
154    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
155    #[must_use]
156    pub const fn invert_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
157        safegcd::invert_odd_mod::<LIMBS, false>(self, modulus)
158    }
159
160    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
161    ///
162    /// This method is variable-time with respect to `self`.
163    #[must_use]
164    pub const fn invert_odd_mod_vartime(&self, modulus: &Odd<Self>) -> CtOption<Self> {
165        safegcd::invert_odd_mod::<LIMBS, true>(self, modulus)
166    }
167
168    /// Computes the multiplicative inverse of `self` mod `modulus`.
169    ///
170    /// Returns some if an inverse exists, otherwise none.
171    #[deprecated(since = "0.7.0", note = "please use `invert_mod` instead")]
172    #[must_use]
173    pub const fn inv_mod(&self, modulus: &Self) -> CtOption<Self> {
174        let (m, is_nz) = modulus.to_nz_or_one();
175        self.invert_mod(&m).filter_by(is_nz)
176    }
177
178    /// Computes the multiplicative inverse of `self` mod `modulus`.
179    ///
180    /// Returns some if an inverse exists, otherwise none.
181    #[must_use]
182    pub const fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
183        // Decompose `modulus = s * 2^k` where `s` is odd
184        let k = modulus.as_ref().trailing_zeros();
185        let s = Odd::new_unchecked(modulus.as_ref().shr(k));
186
187        // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
188        // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
189        let maybe_a = self.invert_odd_mod(&s);
190
191        let maybe_b = self.invert_mod2k(k);
192        let is_some = maybe_a.is_some().and(maybe_b.is_some());
193
194        // Extract inner values to avoid mapping through CtOptions.
195        // if `a` or `b` don't exist, the returned CtOption will be None anyway.
196        let a = maybe_a.to_inner_unchecked();
197        let b = maybe_b.to_inner_unchecked();
198
199        // Restore from RNS:
200        // self^{-1} = a mod s = b mod 2^k
201        // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
202        // (essentially one step of the Garner's algorithm for recovery from RNS).
203
204        // `s` is odd, so this always exists
205        let m_odd_inv = s.invert_mod_precision();
206
207        // This part is mod 2^k
208        let t = b.wrapping_sub(&a).wrapping_mul(&m_odd_inv).restrict_bits(k);
209
210        // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
211        // so `a + s * t <= s * 2^k - 1 == modulus - 1`.
212        let result = a.wrapping_add(&s.as_ref().wrapping_mul(&t));
213        CtOption::new(result, is_some)
214    }
215}
216
217impl<const LIMBS: usize> Odd<Uint<LIMBS>> {
218    /// Compute a full-width quadratic inversion, `self^-1 mod 2^Self::BITS`.
219    #[inline]
220    pub(crate) const fn invert_mod_precision(&self) -> Uint<LIMBS> {
221        self.invert_mod2k_vartime(Self::BITS)
222    }
223
224    /// Compute a quadratic inversion, `self^-1 mod 2^k` where `k <= Self::BITS`.
225    ///
226    /// This method is variable-time in `k` only.
227    pub(crate) const fn invert_mod2k_vartime(&self, k: u32) -> Uint<LIMBS> {
228        assert!(k <= Self::BITS);
229
230        let k_limbs = bitlen::to_limbs(k);
231        let mut inv = U64::from_u64(self.as_uint_ref().invert_mod_u64()).resize::<LIMBS>();
232
233        if k_limbs <= U64::LIMBS {
234            // trim to k_limbs
235            inv.as_mut_uint_ref().trailing_mut(k_limbs).fill(Limb::ZERO);
236        } else {
237            // expand to k_limbs
238            let mut scratch = (Uint::<LIMBS>::ZERO, Uint::<LIMBS>::ZERO);
239            expand_invert_mod2k(
240                self.as_uint_ref(),
241                inv.as_mut_uint_ref().leading_mut(k_limbs),
242                U64::LIMBS,
243                (scratch.0.as_mut_uint_ref(), scratch.1.as_mut_uint_ref()),
244            );
245        }
246
247        // clear bits in the high limb if necessary
248        #[allow(clippy::integer_division_remainder_used, reason = "TODO")]
249        let k_bits = k % Limb::BITS;
250        if k_bits > 0 {
251            inv.limbs[k_limbs - 1] = inv.limbs[k_limbs - 1].restrict_bits(k_bits);
252        }
253
254        inv
255    }
256}
257
258impl<const LIMBS: usize> InvertMod for Uint<LIMBS> {
259    type Output = Self;
260
261    fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
262        self.invert_mod(modulus)
263    }
264}
265
#[cfg(test)]
mod tests {
    use crate::{Odd, U64, U256, U1024, Uint};

    #[test]
    fn invert_mod2k() {
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f");
        let e =
            U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf");
        let a = v.invert_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1");
        let a = v.invert_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        // Check that even if the number is >= 2^k, the inverse is still correct.

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("0000000000000000000000000000000000000000034613dbb4f20099aa774ec1");
        let a = v.invert_mod2k(90).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(90).unwrap();
        assert_eq!(e, a);

        // An inverse of an even number does not exist.

        let a = U256::from(10u64).invert_mod2k(4);
        assert!(a.is_none().to_bool_vartime());

        let a = U256::from(10u64).invert_mod2k_vartime(4);
        assert!(a.is_none().to_bool_vartime());

        // A degenerate case. An inverse mod 2^0 == 1 always exists even for even numbers.

        let a = U256::from(10u64).invert_mod2k_vartime(0).unwrap();
        assert_eq!(a, U256::ZERO);

        let a = U256::from(10u64).invert_mod2k(0).unwrap();
        assert_eq!(a, U256::ZERO);

        // Widths beyond the integer size are rejected.

        let a = U256::ONE.invert_mod2k_vartime(U256::BITS + 1);
        assert!(a.is_none().to_bool_vartime());
    }

    #[test]
    fn invert_mod2k_every_width() {
        // Exercises every trim / expand path of the lifting code for each `k`.
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        for k in 0..=U256::BITS {
            let ct = v.invert_mod2k(k).unwrap();
            let vt = v.invert_mod2k_vartime(k).unwrap();
            assert_eq!(ct, vt, "k = {k}");
            // v * v^-1 == 1 (mod 2^k)
            assert_eq!(
                v.wrapping_mul(&ct).restrict_bits(k),
                U256::ONE.restrict_bits(k),
                "k = {k}"
            );
        }
    }

    #[test]
    fn test_invert_odd() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
        ])
        .to_odd()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
            "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
            "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
            "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
        ]);

        let res = a.invert_odd_mod(&m).unwrap();
        assert_eq!(res, expected);

        // Even though it is less efficient, it still works
        let res = a.invert_mod(m.as_nz_ref()).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_odd_no_inverse() {
        // 2^128 - 159, a prime
        let p1 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff61");
        // 2^128 - 173, a prime
        let p2 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff53");

        let m = p1.wrapping_mul(&p2).to_odd().unwrap();

        // `m` is a multiple of `p1`, so no inverse exists
        let res = p1.invert_odd_mod(&m);
        assert!(res.is_none().to_bool_vartime());
    }

    #[test]
    fn test_invert_even() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
        ])
        .to_nz()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
            "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
            "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
            "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
        ]);

        let res = a.invert_mod(&m).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_small() {
        let a = U64::from(3u64);
        let m = U64::from(13u64).to_odd().unwrap();

        let res = a.invert_odd_mod(&m).unwrap();
        assert_eq!(U64::from(9u64), res);
    }

    #[test]
    fn test_no_inverse_small() {
        let a = U64::from(14u64);
        let m = U64::from(49u64).to_odd().unwrap();

        let res = a.invert_odd_mod(&m);
        assert!(res.is_none().to_bool_vartime());
    }

    #[test]
    fn test_invert_edge() {
        assert!(
            U256::ZERO
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .is_none()
                .to_bool_vartime()
        );
        assert_eq!(
            U256::ONE
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .unwrap(),
            U256::ZERO
        );
        assert_eq!(
            U256::ONE
                .invert_odd_mod(&U256::MAX.to_odd().unwrap())
                .unwrap(),
            U256::ONE
        );
        assert!(
            U256::MAX
                .invert_odd_mod(&U256::MAX.to_odd().unwrap())
                .is_none()
                .to_bool_vartime()
        );
        assert_eq!(
            U256::MAX
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .unwrap(),
            U256::ZERO
        );
    }

    #[test]
    fn invert_mod_precision() {
        const BIG: Odd<Uint<8>> = Odd::new_unchecked(Uint::MAX);

        fn test_invert_size<const LIMBS: usize>() {
            let a = BIG.resize::<LIMBS>();
            let a_inv = a.invert_mod_precision();
            assert_eq!(a.as_ref().wrapping_mul(&a_inv), Uint::ONE);

            // An odd value whose bits span every limb: 2^(BITS-1) - 1.
            let b = Uint::<LIMBS>::MAX.shr(1).to_odd().unwrap();
            let b_inv = b.invert_mod_precision();
            assert_eq!(b.as_ref().wrapping_mul(&b_inv), Uint::ONE);
        }

        test_invert_size::<1>();
        test_invert_size::<2>();
        test_invert_size::<3>();
        test_invert_size::<4>();
        test_invert_size::<5>();
        test_invert_size::<6>();
        test_invert_size::<7>();
        test_invert_size::<8>();
        test_invert_size::<9>();
        test_invert_size::<10>();
        // Widths whose reduced target exceeds 8 limbs take the recursive branch
        // of `expand_invert_mod2k`, which the sizes above never reach.
        test_invert_size::<17>();
        test_invert_size::<18>();
        test_invert_size::<24>();
        test_invert_size::<33>();
    }
}