1#[cfg(feature = "alloc")]
13pub(crate) mod boxed;
14
15use core::fmt;
16
17use crate::{Choice, CtOption, I64, Int, Limb, Odd, U64, Uint, primitives::u32_min};
18
/// Upper limit on the number of division steps handed to one `jump` call; the
/// driver loops clamp every batch to this value.
// NOTE(review): presumably chosen so the doubled matrix entries stay inside `i64` — confirm.
const GCD_BATCH_SIZE: u32 = 62;
20
21#[derive(Clone, Debug)]
46pub(crate) struct SafeGcdInverter<const LIMBS: usize> {
47 pub(super) modulus: Odd<Uint<LIMBS>>,
49
50 inverse: u64,
52
53 adjuster: Uint<LIMBS>,
55}
56
/// 2x2 matrix of `i64` entries accumulated by `jump` and applied to the
/// `(f, g)` and `(d, e)` pairs by `update_fg` / `update_de`.
type Matrix = [[i64; 2]; 2];
59
60impl<const LIMBS: usize> SafeGcdInverter<LIMBS> {
61 #[cfg(test)]
63 pub(crate) const fn new(modulus: &Odd<Uint<LIMBS>>, adjuster: &Uint<LIMBS>) -> Self {
64 Self::new_with_inverse(
65 modulus,
66 U64::from_u64(modulus.as_uint_ref().invert_mod_u64()),
67 adjuster,
68 )
69 }
70
71 #[inline]
72 pub(crate) const fn new_with_inverse(
73 modulus: &Odd<Uint<LIMBS>>,
74 inverse: U64,
75 adjuster: &Uint<LIMBS>,
76 ) -> Self {
77 Self {
78 modulus: *modulus,
79 inverse: inverse.as_uint_ref().lowest_u64(),
80 adjuster: *adjuster,
81 }
82 }
83
84 pub const fn invert(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
87 invert_odd_mod_precomp::<LIMBS, false>(value, &self.modulus, self.inverse, &self.adjuster)
88 }
89
90 pub const fn invert_vartime(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
95 invert_odd_mod_precomp::<LIMBS, true>(value, &self.modulus, self.inverse, &self.adjuster)
96 }
97}
98
99#[inline]
100pub const fn invert_odd_mod<const LIMBS: usize, const VARTIME: bool>(
101 a: &Uint<LIMBS>,
102 m: &Odd<Uint<LIMBS>>,
103) -> CtOption<Uint<LIMBS>> {
104 let mi = m.as_uint_ref().invert_mod_u64();
105 invert_odd_mod_precomp::<LIMBS, VARTIME>(a, m, mi, &Uint::ONE)
106}
107
108const fn invert_odd_mod_precomp<const LIMBS: usize, const VARTIME: bool>(
110 a: &Uint<LIMBS>,
111 m: &Odd<Uint<LIMBS>>,
112 mi: u64,
113 e: &Uint<LIMBS>,
114) -> CtOption<Uint<LIMBS>> {
115 let a_nonzero = a.is_nonzero();
116 let (mut f, mut g) = (SignedInt::from_uint(*m.as_ref()), SignedInt::from_uint(*a));
117 let (mut d, mut e) = (SignedInt::<LIMBS>::ZERO, SignedInt::from_uint(*e));
118 let mut steps = iterations(Uint::<LIMBS>::BITS);
119 let mut delta = 1;
120 let mut t;
121
122 while steps > 0 {
123 if VARTIME && g.is_zero_vartime() {
124 break;
125 }
126 let batch = u32_min(steps, GCD_BATCH_SIZE);
127 (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
128 (f, g) = update_fg(&f, &g, t, batch);
129 (d, e) = update_de(&d, &e, m.as_ref(), mi, t, batch);
130 steps -= batch;
131 }
132
133 let d = d.norm(f.is_negative(), m.as_ref());
134 CtOption::new(d, Uint::eq(&f.magnitude, &Uint::ONE).and(a_nonzero))
135}
136
137pub const fn gcd_odd<const LIMBS: usize, const VARTIME: bool>(
139 f: &Odd<Uint<LIMBS>>,
140 g: &Uint<LIMBS>,
141) -> Odd<Uint<LIMBS>> {
142 let (mut f, mut g) = (SignedInt::from_uint(*f.as_ref()), SignedInt::from_uint(*g));
143 let mut steps = iterations(Uint::<LIMBS>::BITS);
144 let mut delta = 1;
145 let mut t;
146
147 while steps > 0 {
148 if VARTIME && g.is_zero_vartime() {
149 break;
150 }
151 let batch = u32_min(steps, GCD_BATCH_SIZE);
152 (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
153 (f, g) = update_fg(&f, &g, t, batch);
154 steps -= batch;
155 }
156
157 f.magnitude().to_odd().expect_copied("odd by construction")
158}
159
160#[inline]
162const fn jump<const VARTIME: bool>(
163 mut f: i64,
164 mut g: i64,
165 mut delta: i64,
166 mut batch: u32,
167) -> (i64, Matrix) {
168 debug_assert!(f & 1 == 1, "f must be odd");
169 let mut t = [[1i64, 0], [0, 1]];
170 while batch > 0 {
171 (f, g, delta, t) = if VARTIME {
172 jump_step_vartime(f, g, delta, t)
173 } else {
174 jump_step(f, g, delta, t)
175 };
176 batch -= 1;
177 }
178 (delta, t)
179}
180
181#[inline(always)]
185#[allow(clippy::cast_sign_loss)]
186const fn jump_step(
187 mut f: i64,
188 mut g: i64,
189 mut delta: i64,
190 mut t: Matrix,
191) -> (i64, i64, i64, Matrix) {
192 let d_gtz = Choice::from_u64_nz((delta & !(delta >> 63)) as u64);
193 let g_odd = Choice::from_u64_lsb((g & 1) as u64);
194 let g_adj = g_odd.select_i64(0, f);
195 let swap = d_gtz.and(g_odd);
196 delta = swap.select_i64(2i64.wrapping_add(delta), 2i64.wrapping_sub(delta));
197 f = swap.select_i64(f, g);
198 g = swap.select_i64(g.wrapping_add(g_adj), g.wrapping_sub(g_adj)) >> 1;
199 t = [
200 [
201 swap.select_i64(t[0][0], t[1][0]) << 1,
202 swap.select_i64(t[0][1], t[1][1]) << 1,
203 ],
204 [
205 t[1][0].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][0], -t[0][0]))),
206 t[1][1].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][1], -t[0][1]))),
207 ],
208 ];
209 (f, g, delta, t)
210}
211
212#[inline(always)]
214const fn jump_step_vartime(
215 mut f: i64,
216 mut g: i64,
217 mut delta: i64,
218 mut t: Matrix,
219) -> (i64, i64, i64, Matrix) {
220 if (g & 1) != 0 {
221 (f, g, delta, t) = if delta > 0 {
222 (
223 g,
224 g.wrapping_sub(f),
225 2i64.wrapping_sub(delta),
226 [
227 t[1],
228 [t[1][0].wrapping_sub(t[0][0]), t[1][1].wrapping_sub(t[0][1])],
229 ],
230 )
231 } else {
232 (
233 f,
234 g.wrapping_add(f),
235 2i64.wrapping_add(delta),
236 [
237 t[0],
238 [t[1][0].wrapping_add(t[0][0]), t[1][1].wrapping_add(t[0][1])],
239 ],
240 )
241 };
242 } else {
243 delta = 2i64.wrapping_add(delta);
244 }
245 g >>= 1;
246 t[0][0] <<= 1;
247 t[0][1] <<= 1;
248 (f, g, delta, t)
249}
250
251#[inline]
252const fn update_fg<const LIMBS: usize>(
253 a: &SignedInt<LIMBS>,
254 b: &SignedInt<LIMBS>,
255 t: Matrix,
256 shift: u32,
257) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
258 (
259 SignedInt::lincomb_int_reduce_shift(
260 a,
261 b,
262 &I64::from_i64(t[0][0]),
263 &I64::from_i64(t[0][1]),
264 shift,
265 ),
266 SignedInt::lincomb_int_reduce_shift(
267 a,
268 b,
269 &I64::from_i64(t[1][0]),
270 &I64::from_i64(t[1][1]),
271 shift,
272 ),
273 )
274}
275
276#[inline]
277const fn update_de<const LIMBS: usize>(
278 d: &SignedInt<LIMBS>,
279 e: &SignedInt<LIMBS>,
280 m: &Uint<LIMBS>,
281 mi: u64,
282 t: Matrix,
283 shift: u32,
284) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
285 (
286 SignedInt::lincomb_int_reduce_shift_mod(
287 d,
288 e,
289 &Int::from_i64(t[0][0]),
290 &Int::from_i64(t[0][1]),
291 shift,
292 m,
293 U64::from_u64(mi),
294 ),
295 SignedInt::lincomb_int_reduce_shift_mod(
296 d,
297 e,
298 &Int::from_i64(t[1][0]),
299 &Int::from_i64(t[1][1]),
300 shift,
301 m,
302 U64::from_u64(mi),
303 ),
304 )
305}
306
307#[inline]
309const fn conditional_negate_in_place_wide<const L: usize, const H: usize>(
310 lo: &mut Uint<L>,
311 hi: &mut Uint<H>,
312 flag: Choice,
313) {
314 let (neg, carry) = lo.carrying_neg();
315 let hi_neg = hi
316 .not()
317 .wrapping_add(&Uint::select(&Uint::ZERO, &Uint::ONE, carry));
318 *lo = Uint::select(lo, &neg, flag);
319 *hi = Uint::select(hi, &hi_neg, flag);
320}
321
322#[inline]
324const fn shr_in_place_wide<const L: usize, const H: usize>(
325 lo: &mut Uint<L>,
326 hi: &mut Uint<H>,
327 shift: u32,
328) {
329 debug_assert!(H <= L);
330 debug_assert!(shift < Uint::<H>::BITS);
331 let copy = hi.shl_vartime(Uint::<H>::BITS - shift);
332 *hi = hi.shr_vartime(shift);
333 *lo = lo.shr_vartime(shift);
334 let mut offs = shift.div_ceil(Limb::BITS) as usize;
335 lo.limbs[L - offs] = lo.limbs[L - offs].bitor(copy.limbs[H - offs]);
336 loop {
337 offs -= 1;
338 if offs == 0 {
339 break;
340 }
341 lo.limbs[L - offs] = copy.limbs[H - offs];
342 }
343}
344
/// Number of division steps the driver loops run for operands of `bits` bits,
/// computed as `(45907 * bits + 30179) / 19929` (rounded down).
// NOTE(review): the constants presumably come from the published safegcd step bound — confirm.
///
/// The arithmetic is carried out in `u64` so that wide operands
/// (`bits > 93_558`) no longer overflow the intermediate product.
///
/// # Panics
///
/// Panics if the resulting count does not fit in a `u32`.
#[inline]
#[allow(clippy::integer_division_remainder_used, reason = "public parameter")]
#[allow(
    clippy::cast_lossless,
    clippy::cast_possible_truncation,
    reason = "`From` is not usable in const fn; the value is range-checked before narrowing"
)]
const fn iterations(bits: u32) -> u32 {
    let steps = (45907 * (bits as u64) + 30179) / 19929;
    assert!(steps <= u32::MAX as u64, "iteration count exceeds u32");
    steps as u32
}
354
355#[derive(Clone, Copy)]
357struct SignedInt<const LIMBS: usize> {
358 sign: Choice,
359 magnitude: Uint<LIMBS>,
360}
361
362impl<const LIMBS: usize> SignedInt<LIMBS> {
363 pub const ZERO: Self = Self::from_uint(Uint::ZERO);
364
365 pub const fn from_uint(uint: Uint<LIMBS>) -> Self {
367 Self {
368 sign: Choice::FALSE,
369 magnitude: uint,
370 }
371 }
372
373 pub const fn from_uint_sign(magnitude: Uint<LIMBS>, sign: Choice) -> Self {
375 Self { sign, magnitude }
376 }
377
378 pub const fn magnitude(&self) -> Uint<LIMBS> {
380 self.magnitude
381 }
382
383 pub const fn is_nonzero(&self) -> Choice {
385 self.magnitude.is_nonzero()
386 }
387
388 pub const fn is_zero_vartime(&self) -> bool {
390 self.magnitude.is_zero_vartime()
391 }
392
393 pub const fn is_negative(&self) -> Choice {
397 self.sign
398 }
399
400 #[allow(clippy::cast_possible_wrap)]
402 pub const fn lowest(&self) -> i64 {
403 let mag = (self.magnitude.as_uint_ref().lowest_u64() & (u64::MAX >> 1)) as i64;
404 self.sign.select_i64(mag, mag.wrapping_neg())
405 }
406
407 #[inline]
409 pub(crate) const fn lincomb_int<const RHS: usize>(
410 a: &SignedInt<LIMBS>,
411 b: &SignedInt<LIMBS>,
412 c: &Int<RHS>,
413 d: &Int<RHS>,
414 ) -> (Uint<LIMBS>, Uint<RHS>, Choice) {
415 let (c, c_sign) = c.abs_sign();
416 let (d, d_sign) = d.abs_sign();
417 let (mut x, mut x_hi) = a.magnitude.widening_mul(&c);
419 let x_neg = a.sign.xor(c_sign);
420 let (mut y, mut y_hi) = b.magnitude.widening_mul(&d);
421 let y_neg = b.sign.xor(d_sign);
422 let odd_neg = x_neg.xor(y_neg);
423
424 conditional_negate_in_place_wide(&mut y, &mut y_hi, odd_neg.not());
426
427 let mut borrow;
428 (x, borrow) = x.borrowing_sub(&y, Limb::ZERO);
429 (x_hi, borrow) = x_hi.borrowing_sub(&y_hi, borrow);
430 let swap = borrow.is_nonzero().and(odd_neg);
431
432 conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
435
436 let sign = x_neg.and(swap.not()).or(y_neg.and(swap));
437 (x, x_hi, sign)
438 }
439
440 pub(crate) const fn lincomb_int_reduce_shift<const S: usize>(
444 a: &Self,
445 b: &Self,
446 c: &Int<S>,
447 d: &Int<S>,
448 shift: u32,
449 ) -> Self {
450 debug_assert!(shift < Uint::<S>::BITS);
451 let (mut a, mut a_hi, a_sign) = Self::lincomb_int(a, b, c, d);
452 shr_in_place_wide(&mut a, &mut a_hi, shift);
453 SignedInt::from_uint_sign(a, a_sign)
454 }
455
456 pub(crate) const fn lincomb_int_reduce_shift_mod<const S: usize>(
460 a: &Self,
461 b: &Self,
462 c: &Int<S>,
463 d: &Int<S>,
464 shift: u32,
465 m: &Uint<LIMBS>,
466 mi: Uint<S>,
467 ) -> SignedInt<LIMBS> {
468 debug_assert!(shift < Uint::<S>::BITS);
469 let (mut x, mut x_hi, mut x_sign) = SignedInt::lincomb_int(a, b, c, d);
470
471 let mut mf = x.resize::<S>().wrapping_mul(&mi);
473 mf = mf.bitand(&Uint::MAX.shr_vartime(Uint::<S>::BITS - shift));
474 let (xa, xa_hi) = m.widening_mul(&mf);
475
476 let mut borrow;
478 (x, borrow) = x.borrowing_sub(&xa, Limb::ZERO);
479 (x_hi, borrow) = x_hi.borrowing_sub(&xa_hi, borrow);
480
481 let swap = borrow.is_nonzero();
483 conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
484 x_sign = x_sign.xor(swap);
485
486 shr_in_place_wide(&mut x, &mut x_hi, shift);
488 debug_assert!(
489 x_hi.shr1().is_nonzero().not().to_bool_vartime(),
490 "overflow was larger than one bit"
491 );
492
493 x = x.try_sub_with_carry(x_hi.limbs[0], m).0;
496
497 SignedInt::from_uint_sign(x, x_sign)
498 }
499
500 const fn norm(&self, f_sign: Choice, m: &Uint<LIMBS>) -> Uint<LIMBS> {
502 let swap = f_sign.xor(self.sign).and(self.is_nonzero());
503 Uint::select(&self.magnitude, &m.wrapping_sub(&self.magnitude), swap)
504 }
505
506 pub const fn eq(a: &Self, b: &Self) -> Choice {
508 Uint::eq(&a.magnitude, &b.magnitude).and(a.sign.eq(b.sign).or(a.is_nonzero().not()))
509 }
510}
511
512impl<const LIMBS: usize> fmt::Debug for SignedInt<LIMBS> {
513 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514 f.write_fmt(format_args!(
515 "{}0x{}",
516 if self.sign.to_bool_vartime() {
517 "-"
518 } else {
519 "+"
520 },
521 &self.magnitude
522 ))
523 }
524}
525
526impl<const LIMBS: usize> PartialEq for SignedInt<LIMBS> {
527 fn eq(&self, other: &Self) -> bool {
528 Self::eq(self, other).to_bool_vartime()
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::SafeGcdInverter;
    use crate::{U128, U256, modular::safegcd::shr_in_place_wide};

    /// Inverts a fixed 256-bit value modulo a fixed odd modulus and checks the
    /// result against a known answer.
    #[test]
    fn invert() {
        let modulus =
            U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
                .to_odd()
                .unwrap();
        let value =
            U256::from_be_hex("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
        let expected =
            U256::from_be_hex("FB668F8F509790BC549B077098918604283D42901C92981062EB48BC723F617B");

        let actual = SafeGcdInverter::new(&modulus, &U256::ONE)
            .invert(&value)
            .unwrap();

        assert_eq!(expected, actual);
    }

    /// Shifts a 128-bit high part over an all-ones 256-bit low part, once by
    /// less than a limb and once by more than 64 bits.
    #[test]
    fn shr_wide() {
        let hi_start = U128::from_u128(0x11111111222222223333333344444444);
        let lo_start = U256::MAX;

        let mut lo = lo_start;
        let mut hi = hi_start;
        shr_in_place_wide(&mut lo, &mut hi, 16);
        assert_eq!(
            lo,
            U256::from_be_hex("4444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
        );
        assert_eq!(hi, U128::from_u128(0x1111111122222222333333334444));

        let mut lo = lo_start;
        let mut hi = hi_start;
        shr_in_place_wide(&mut lo, &mut hi, 68);
        assert_eq!(
            lo,
            U256::from_be_hex("23333333344444444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
        );
        assert_eq!(hi, U128::from_u128(0x111111112222222));
    }
}