1use super::Uint;
2use crate::{
3 Choice, CtOption, InvertMod, Limb, NonZero, Odd, U64, UintRef, bitlen, modular::safegcd,
4 mul::karatsuba,
5};
6
7#[inline]
21pub(crate) const fn expand_invert_mod2k(
22 a: &Odd<UintRef>,
23 buf: &mut UintRef,
24 mut k: usize,
25 scratch: (&mut UintRef, &mut UintRef),
26) {
27 assert!(k > 0);
28 let p = buf.nlimbs();
29 let zs = p.trailing_zeros();
30
31 let mut target = if zs > 0 { p >> zs } else { p.div_ceil(2) };
36 if target > 8 {
37 expand_invert_mod2k(a, buf.leading_mut(target), k, (scratch.0, scratch.1));
38 k = target;
39 target = p;
40 } else if target <= k {
41 target = p;
42 }
43
44 while k < p {
46 let mut k2 = k * 2;
47 if k2 >= target {
50 (k2, target) = (target, p);
51 }
52 expand_invert_mod2k_step(a, buf.leading_mut(k2), k, (scratch.0, scratch.1));
53 k = k2;
54 }
55}
56
57#[inline(always)]
60const fn expand_invert_mod2k_step(
61 a: &Odd<UintRef>,
62 buf: &mut UintRef,
63 buf_init_len: usize,
64 scratch: (&mut UintRef, &mut UintRef),
65) {
66 let new_len = buf.nlimbs();
67
68 assert!(
69 scratch.0.nlimbs() >= new_len
70 && scratch.1.nlimbs() >= new_len
71 && buf_init_len < new_len
72 && buf_init_len >= (new_len >> 1)
73 );
74
75 let u0_p2 = scratch.0.leading_mut(new_len);
77 u0_p2.fill(Limb::ZERO);
78 karatsuba::wrapping_square(buf.leading(buf_init_len), u0_p2);
79
80 let tmp = scratch.1.leading_mut(new_len);
82 tmp.fill(Limb::ZERO);
83 karatsuba::wrapping_mul(u0_p2, a.as_ref(), tmp, false);
84
85 buf.shl1_assign();
87 buf.borrowing_sub_assign(tmp, Limb::ZERO);
89}
90
91impl<const LIMBS: usize> Uint<LIMBS> {
92 #[deprecated(since = "0.7.0", note = "please use `invert_mod2k_vartime` instead")]
99 #[must_use]
100 pub const fn inv_mod2k_vartime(&self, k: u32) -> CtOption<Self> {
101 self.invert_mod2k_vartime(k)
102 }
103
104 #[inline]
110 #[must_use]
111 pub const fn invert_mod2k_vartime(&self, k: u32) -> CtOption<Self> {
112 if k == 0 {
113 CtOption::some(Self::ZERO)
114 } else if k > Self::BITS {
115 CtOption::new(Self::ZERO, Choice::FALSE)
116 } else {
117 let (self_odd, is_some) = self.to_odd_or_one();
118 let inv = self_odd.invert_mod2k_vartime(k);
119 CtOption::new(inv, is_some)
120 }
121 }
122
123 #[deprecated(since = "0.7.0", note = "please use `invert_mod2k` instead")]
128 #[must_use]
129 pub const fn inv_mod2k(&self, k: u32) -> CtOption<Self> {
130 self.invert_mod2k(k)
131 }
132
133 #[inline]
138 #[must_use]
139 pub const fn invert_mod2k(&self, k: u32) -> CtOption<Self> {
140 let (odd, is_odd) = self.to_odd_or_one();
141 let is_some =
142 Choice::from_u32_le(k, Self::BITS).and(Choice::from_u32_nz(k).not().or(is_odd));
143 let inv = odd.invert_mod_precision();
144 CtOption::new(inv.restrict_bits(k), is_some)
145 }
146
147 #[deprecated(since = "0.7.0", note = "please use `invert_odd_mod` instead")]
149 #[must_use]
150 pub const fn inv_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
151 self.invert_odd_mod(modulus)
152 }
153
154 #[must_use]
156 pub const fn invert_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
157 safegcd::invert_odd_mod::<LIMBS, false>(self, modulus)
158 }
159
160 #[must_use]
164 pub const fn invert_odd_mod_vartime(&self, modulus: &Odd<Self>) -> CtOption<Self> {
165 safegcd::invert_odd_mod::<LIMBS, true>(self, modulus)
166 }
167
168 #[deprecated(since = "0.7.0", note = "please use `invert_mod` instead")]
172 #[must_use]
173 pub const fn inv_mod(&self, modulus: &Self) -> CtOption<Self> {
174 let (m, is_nz) = modulus.to_nz_or_one();
175 self.invert_mod(&m).filter_by(is_nz)
176 }
177
178 #[must_use]
182 pub const fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
183 let k = modulus.as_ref().trailing_zeros();
185 let s = Odd::new_unchecked(modulus.as_ref().shr(k));
186
187 let maybe_a = self.invert_odd_mod(&s);
190
191 let maybe_b = self.invert_mod2k(k);
192 let is_some = maybe_a.is_some().and(maybe_b.is_some());
193
194 let a = maybe_a.to_inner_unchecked();
197 let b = maybe_b.to_inner_unchecked();
198
199 let m_odd_inv = s.invert_mod_precision();
206
207 let t = b.wrapping_sub(&a).wrapping_mul(&m_odd_inv).restrict_bits(k);
209
210 let result = a.wrapping_add(&s.as_ref().wrapping_mul(&t));
213 CtOption::new(result, is_some)
214 }
215}
216
217impl<const LIMBS: usize> Odd<Uint<LIMBS>> {
218 #[inline]
220 pub(crate) const fn invert_mod_precision(&self) -> Uint<LIMBS> {
221 self.invert_mod2k_vartime(Self::BITS)
222 }
223
224 pub(crate) const fn invert_mod2k_vartime(&self, k: u32) -> Uint<LIMBS> {
228 assert!(k <= Self::BITS);
229
230 let k_limbs = bitlen::to_limbs(k);
231 let mut inv = U64::from_u64(self.as_uint_ref().invert_mod_u64()).resize::<LIMBS>();
232
233 if k_limbs <= U64::LIMBS {
234 inv.as_mut_uint_ref().trailing_mut(k_limbs).fill(Limb::ZERO);
236 } else {
237 let mut scratch = (Uint::<LIMBS>::ZERO, Uint::<LIMBS>::ZERO);
239 expand_invert_mod2k(
240 self.as_uint_ref(),
241 inv.as_mut_uint_ref().leading_mut(k_limbs),
242 U64::LIMBS,
243 (scratch.0.as_mut_uint_ref(), scratch.1.as_mut_uint_ref()),
244 );
245 }
246
247 #[allow(clippy::integer_division_remainder_used, reason = "TODO")]
249 let k_bits = k % Limb::BITS;
250 if k_bits > 0 {
251 inv.limbs[k_limbs - 1] = inv.limbs[k_limbs - 1].restrict_bits(k_bits);
252 }
253
254 inv
255 }
256}
257
258impl<const LIMBS: usize> InvertMod for Uint<LIMBS> {
259 type Output = Self;
260
261 fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
262 self.invert_mod(modulus)
263 }
264}
265
#[cfg(test)]
mod tests {
    use crate::{Odd, U64, U256, U1024, Uint};

    #[test]
    fn invert_mod2k() {
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f");
        let e =
            U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf");
        let a = v.invert_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1");
        let a = v.invert_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        // Same value, truncated precision (not a multiple of the limb size).
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("0000000000000000000000000000000000000000034613dbb4f20099aa774ec1");
        let a = v.invert_mod2k(90).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(90).unwrap();
        assert_eq!(e, a);

        // Even values have no inverse modulo a power of two.
        let a = U256::from(10u64).invert_mod2k(4);
        assert!(a.is_none().to_bool_vartime());

        let a = U256::from(10u64).invert_mod2k_vartime(4);
        assert!(a.is_none().to_bool_vartime());

        // k == 0 always yields zero.
        let a = U256::from(10u64).invert_mod2k_vartime(0).unwrap();
        assert_eq!(a, U256::ZERO);
    }

    #[test]
    fn invert_mod2k_out_of_range() {
        // `k` larger than the integer width has no representable answer.
        let v = U256::from(3u64);
        assert!(v.invert_mod2k(257).is_none().to_bool_vartime());
        assert!(v.invert_mod2k_vartime(257).is_none().to_bool_vartime());
    }

    #[test]
    fn test_invert_odd() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
        ])
        .to_odd()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
            "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
            "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
            "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
        ]);

        let res = a.invert_odd_mod(&m).unwrap();
        assert_eq!(res, expected);

        // The general entry point must agree on an odd modulus.
        let res = a.invert_mod(m.as_nz_ref()).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_odd_no_inverse() {
        let p1 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff61");
        let p2 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff53");

        let m = p1.wrapping_mul(&p2).to_odd().unwrap();

        // p1 divides m, so it cannot be invertible modulo m.
        let res = p1.invert_odd_mod(&m);
        assert!(res.is_none().to_bool_vartime());
    }

    #[test]
    fn test_invert_even() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
        ])
        .to_nz()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
            "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
            "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
            "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
        ]);

        let res = a.invert_mod(&m).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_even_small() {
        // 12 = 3 * 2^2 exercises both CRT halves; 5 * 5 = 25 = 2 * 12 + 1.
        let m = U64::from(12u64).to_nz().unwrap();
        let res = U64::from(5u64).invert_mod(&m).unwrap();
        assert_eq!(res, U64::from(5u64));

        // 4 shares the factor 2 with the modulus; 3 shares the factor 3.
        assert!(U64::from(4u64).invert_mod(&m).is_none().to_bool_vartime());
        assert!(U64::from(3u64).invert_mod(&m).is_none().to_bool_vartime());
    }

    #[test]
    fn test_invert_small() {
        let a = U64::from(3u64);
        let m = U64::from(13u64).to_odd().unwrap();

        let res = a.invert_odd_mod(&m).unwrap();
        assert_eq!(U64::from(9u64), res);
    }

    #[test]
    fn test_no_inverse_small() {
        let a = U64::from(14u64);
        let m = U64::from(49u64).to_odd().unwrap();

        let res = a.invert_odd_mod(&m);
        assert!(res.is_none().to_bool_vartime());
    }

    #[test]
    fn test_invert_edge() {
        assert!(
            U256::ZERO
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .is_none()
                .to_bool_vartime()
        );
        assert_eq!(
            U256::ONE
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .unwrap(),
            U256::ZERO
        );
        assert_eq!(
            U256::ONE
                .invert_odd_mod(&U256::MAX.to_odd().unwrap())
                .unwrap(),
            U256::ONE
        );
        assert!(
            U256::MAX
                .invert_odd_mod(&U256::MAX.to_odd().unwrap())
                .is_none()
                .to_bool_vartime()
        );
        assert_eq!(
            U256::MAX
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .unwrap(),
            U256::ZERO
        );
    }

    #[test]
    fn invert_mod_precision() {
        const BIG: Odd<Uint<8>> = Odd::new_unchecked(Uint::MAX);

        fn test_invert_size<const LIMBS: usize>() {
            let a = BIG.resize::<LIMBS>();
            let a_inv = a.invert_mod_precision();
            assert_eq!(a.as_ref().wrapping_mul(&a_inv), Uint::ONE);
        }

        // Cover sizes on both sides of the recursion threshold in `expand_invert_mod2k`.
        test_invert_size::<1>();
        test_invert_size::<2>();
        test_invert_size::<3>();
        test_invert_size::<4>();
        test_invert_size::<5>();
        test_invert_size::<6>();
        test_invert_size::<7>();
        test_invert_size::<8>();
        test_invert_size::<9>();
        test_invert_size::<10>();
    }
}