1#[cfg(feature = "alloc")]
13pub(crate) mod boxed;
14
15use crate::{Choice, CtOption, I64, Int, Limb, Odd, U64, Uint, bitlen, primitives::u32_min};
16use core::fmt;
17
/// Maximum number of single steps that `jump` folds into one transition matrix before the
/// wide operands are updated; each loop iteration uses `u32_min(steps, GCD_BATCH_SIZE)`.
const GCD_BATCH_SIZE: u32 = 62;
19
/// Modular inverter for a fixed odd modulus, built on the safegcd routines of this module.
///
/// It stores everything `invert_odd_mod_precomp` needs apart from the value being inverted,
/// so the per-modulus constant is computed once and reused for every inversion.
#[derive(Clone, Debug)]
pub(crate) struct SafeGcdInverter<const LIMBS: usize> {
    /// The odd modulus.
    pub(super) modulus: Odd<Uint<LIMBS>>,

    /// Lowest 64 bits of the inverse constant passed to `new_with_inverse`
    /// (`new` obtains it from the modulus via `invert_mod_u64`).
    inverse: u64,

    /// Initial value of the `e` accumulator in the inversion loop; `invert_odd_mod`
    /// passes `Uint::ONE` in this position.
    adjuster: Uint<LIMBS>,
}
55
/// 2x2 transition matrix accumulated by `jump`. It starts as the identity and its first row
/// is doubled on every step; `update_fg` / `update_de` shift the resulting linear
/// combinations right by the number of steps taken.
type Matrix = [[i64; 2]; 2];
58
59impl<const LIMBS: usize> SafeGcdInverter<LIMBS> {
60 #[cfg(test)]
62 pub(crate) const fn new(modulus: &Odd<Uint<LIMBS>>, adjuster: &Uint<LIMBS>) -> Self {
63 Self::new_with_inverse(
64 modulus,
65 U64::from_u64(modulus.as_uint_ref().invert_mod_u64()),
66 adjuster,
67 )
68 }
69
70 #[inline]
71 pub(crate) const fn new_with_inverse(
72 modulus: &Odd<Uint<LIMBS>>,
73 inverse: U64,
74 adjuster: &Uint<LIMBS>,
75 ) -> Self {
76 Self {
77 modulus: *modulus,
78 inverse: inverse.as_uint_ref().lowest_u64(),
79 adjuster: *adjuster,
80 }
81 }
82
83 pub const fn invert(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
86 invert_odd_mod_precomp::<LIMBS, false>(value, &self.modulus, self.inverse, &self.adjuster)
87 }
88
89 pub const fn invert_vartime(&self, value: &Uint<LIMBS>) -> CtOption<Uint<LIMBS>> {
94 invert_odd_mod_precomp::<LIMBS, true>(value, &self.modulus, self.inverse, &self.adjuster)
95 }
96}
97
98#[inline]
99pub const fn invert_odd_mod<const LIMBS: usize, const VARTIME: bool>(
100 a: &Uint<LIMBS>,
101 m: &Odd<Uint<LIMBS>>,
102) -> CtOption<Uint<LIMBS>> {
103 let mi = m.as_uint_ref().invert_mod_u64();
104 invert_odd_mod_precomp::<LIMBS, VARTIME>(a, m, mi, &Uint::ONE)
105}
106
107const fn invert_odd_mod_precomp<const LIMBS: usize, const VARTIME: bool>(
109 a: &Uint<LIMBS>,
110 m: &Odd<Uint<LIMBS>>,
111 mi: u64,
112 e: &Uint<LIMBS>,
113) -> CtOption<Uint<LIMBS>> {
114 let a_nonzero = a.is_nonzero();
115 let (mut f, mut g) = (SignedInt::from_uint(*m.as_ref()), SignedInt::from_uint(*a));
116 let (mut d, mut e) = (SignedInt::<LIMBS>::ZERO, SignedInt::from_uint(*e));
117 let mut steps = iterations(Uint::<LIMBS>::BITS);
118 let mut delta = 1;
119 let mut t;
120
121 while steps > 0 {
122 if VARTIME && g.is_zero_vartime() {
123 break;
124 }
125 let batch = u32_min(steps, GCD_BATCH_SIZE);
126 (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
127 (f, g) = update_fg(&f, &g, t, batch);
128 (d, e) = update_de(&d, &e, m.as_ref(), mi, t, batch);
129 steps -= batch;
130 }
131
132 let d = d.norm(f.is_negative(), m.as_ref());
133 CtOption::new(d, Uint::eq(&f.magnitude, &Uint::ONE).and(a_nonzero))
134}
135
136pub const fn gcd_odd<const LIMBS: usize, const VARTIME: bool>(
138 f: &Odd<Uint<LIMBS>>,
139 g: &Uint<LIMBS>,
140) -> Odd<Uint<LIMBS>> {
141 let (mut f, mut g) = (SignedInt::from_uint(*f.as_ref()), SignedInt::from_uint(*g));
142 let mut steps = iterations(Uint::<LIMBS>::BITS);
143 let mut delta = 1;
144 let mut t;
145
146 while steps > 0 {
147 if VARTIME && g.is_zero_vartime() {
148 break;
149 }
150 let batch = u32_min(steps, GCD_BATCH_SIZE);
151 (delta, t) = jump::<VARTIME>(f.lowest(), g.lowest(), delta, batch);
152 (f, g) = update_fg(&f, &g, t, batch);
153 steps -= batch;
154 }
155
156 f.magnitude().to_odd().expect_copied("odd by construction")
157}
158
159#[inline]
161const fn jump<const VARTIME: bool>(
162 mut f: i64,
163 mut g: i64,
164 mut delta: i64,
165 mut batch: u32,
166) -> (i64, Matrix) {
167 debug_assert!(f & 1 == 1, "f must be odd");
168 let mut t = [[1i64, 0], [0, 1]];
169 while batch > 0 {
170 (f, g, delta, t) = if VARTIME {
171 jump_step_vartime(f, g, delta, t)
172 } else {
173 jump_step(f, g, delta, t)
174 };
175 batch -= 1;
176 }
177 (delta, t)
178}
179
180#[inline(always)]
184#[allow(clippy::cast_sign_loss)]
185const fn jump_step(
186 mut f: i64,
187 mut g: i64,
188 mut delta: i64,
189 mut t: Matrix,
190) -> (i64, i64, i64, Matrix) {
191 let d_gtz = Choice::from_u64_nz((delta & !(delta >> 63)) as u64);
192 let g_odd = Choice::from_u64_lsb((g & 1) as u64);
193 let g_adj = g_odd.select_i64(0, f);
194 let swap = d_gtz.and(g_odd);
195 delta = swap.select_i64(2i64.wrapping_add(delta), 2i64.wrapping_sub(delta));
196 f = swap.select_i64(f, g);
197 g = swap.select_i64(g.wrapping_add(g_adj), g.wrapping_sub(g_adj)) >> 1;
198 t = [
199 [
200 swap.select_i64(t[0][0], t[1][0]) << 1,
201 swap.select_i64(t[0][1], t[1][1]) << 1,
202 ],
203 [
204 t[1][0].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][0], -t[0][0]))),
205 t[1][1].wrapping_add(g_odd.select_i64(0, d_gtz.select_i64(t[0][1], -t[0][1]))),
206 ],
207 ];
208 (f, g, delta, t)
209}
210
211#[inline(always)]
213const fn jump_step_vartime(
214 mut f: i64,
215 mut g: i64,
216 mut delta: i64,
217 mut t: Matrix,
218) -> (i64, i64, i64, Matrix) {
219 if (g & 1) != 0 {
220 (f, g, delta, t) = if delta > 0 {
221 (
222 g,
223 g.wrapping_sub(f),
224 2i64.wrapping_sub(delta),
225 [
226 t[1],
227 [t[1][0].wrapping_sub(t[0][0]), t[1][1].wrapping_sub(t[0][1])],
228 ],
229 )
230 } else {
231 (
232 f,
233 g.wrapping_add(f),
234 2i64.wrapping_add(delta),
235 [
236 t[0],
237 [t[1][0].wrapping_add(t[0][0]), t[1][1].wrapping_add(t[0][1])],
238 ],
239 )
240 };
241 } else {
242 delta = 2i64.wrapping_add(delta);
243 }
244 g >>= 1;
245 t[0][0] <<= 1;
246 t[0][1] <<= 1;
247 (f, g, delta, t)
248}
249
250#[inline]
251const fn update_fg<const LIMBS: usize>(
252 a: &SignedInt<LIMBS>,
253 b: &SignedInt<LIMBS>,
254 t: Matrix,
255 shift: u32,
256) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
257 (
258 SignedInt::lincomb_int_reduce_shift(
259 a,
260 b,
261 &I64::from_i64(t[0][0]),
262 &I64::from_i64(t[0][1]),
263 shift,
264 ),
265 SignedInt::lincomb_int_reduce_shift(
266 a,
267 b,
268 &I64::from_i64(t[1][0]),
269 &I64::from_i64(t[1][1]),
270 shift,
271 ),
272 )
273}
274
275#[inline]
276const fn update_de<const LIMBS: usize>(
277 d: &SignedInt<LIMBS>,
278 e: &SignedInt<LIMBS>,
279 m: &Uint<LIMBS>,
280 mi: u64,
281 t: Matrix,
282 shift: u32,
283) -> (SignedInt<LIMBS>, SignedInt<LIMBS>) {
284 (
285 SignedInt::lincomb_int_reduce_shift_mod(
286 d,
287 e,
288 &Int::from_i64(t[0][0]),
289 &Int::from_i64(t[0][1]),
290 shift,
291 m,
292 U64::from_u64(mi),
293 ),
294 SignedInt::lincomb_int_reduce_shift_mod(
295 d,
296 e,
297 &Int::from_i64(t[1][0]),
298 &Int::from_i64(t[1][1]),
299 shift,
300 m,
301 U64::from_u64(mi),
302 ),
303 )
304}
305
306#[inline]
308const fn conditional_negate_in_place_wide<const L: usize, const H: usize>(
309 lo: &mut Uint<L>,
310 hi: &mut Uint<H>,
311 flag: Choice,
312) {
313 let (neg, carry) = lo.carrying_neg();
314 let hi_neg = hi
315 .not()
316 .wrapping_add(&Uint::select(&Uint::ZERO, &Uint::ONE, carry));
317 *lo = Uint::select(lo, &neg, flag);
318 *hi = Uint::select(hi, &hi_neg, flag);
319}
320
321#[inline]
323const fn shr_in_place_wide<const L: usize, const H: usize>(
324 lo: &mut Uint<L>,
325 hi: &mut Uint<H>,
326 shift: u32,
327) {
328 debug_assert!(H <= L);
329 debug_assert!(shift < Uint::<H>::BITS);
330 let copy = hi.shl_vartime(Uint::<H>::BITS - shift);
331 *hi = hi.shr_vartime(shift);
332 *lo = lo.shr_vartime(shift);
333 let mut offs = bitlen::to_limbs(shift);
334 lo.limbs[L - offs] = lo.limbs[L - offs].bitor(copy.limbs[H - offs]);
335 loop {
336 offs -= 1;
337 if offs == 0 {
338 break;
339 }
340 lo.limbs[L - offs] = copy.limbs[H - offs];
341 }
342}
343
/// Upper bound on the number of single steps the GCD / inversion loops run for operands of
/// `bits` bits: `(45907 * bits + 30179) / 19929`, rounded down.
///
/// The arithmetic is carried out in `u64`. In `u32` the intermediate `45907 * bits + 30179`
/// overflows once `bits` exceeds 93,557, which panics in debug builds and silently wraps to
/// a far too small iteration count in release builds.
///
/// # Panics
/// If the bound itself does not fit in a `u32`, which only happens for `bits` above roughly
/// 1.86 billion.
#[inline]
#[allow(clippy::integer_division_remainder_used, reason = "public parameter")]
#[allow(
    clippy::cast_possible_truncation,
    reason = "the value is asserted to fit in u32 before the cast"
)]
const fn iterations(bits: u32) -> u32 {
    let bound = (45_907 * bits as u64 + 30_179) / 19_929;
    assert!(
        bound <= u32::MAX as u64,
        "iteration bound does not fit in u32"
    );
    bound as u32
}
353
/// Sign-magnitude integer holding the intermediate values of the GCD / inversion loops.
///
/// A zero magnitude may carry either sign; `SignedInt::eq` treats both as equal.
#[derive(Clone, Copy)]
struct SignedInt<const LIMBS: usize> {
    /// Sign flag, returned as-is by `is_negative`; `from_uint` leaves it cleared.
    sign: Choice,
    /// Absolute value.
    magnitude: Uint<LIMBS>,
}
360
impl<const LIMBS: usize> SignedInt<LIMBS> {
    /// Zero, with the sign flag cleared.
    pub const ZERO: Self = Self::from_uint(Uint::ZERO);

    /// Wraps an unsigned value; the sign flag is cleared.
    pub const fn from_uint(uint: Uint<LIMBS>) -> Self {
        Self {
            sign: Choice::FALSE,
            magnitude: uint,
        }
    }

    /// Builds a value from its magnitude and sign flag.
    pub const fn from_uint_sign(magnitude: Uint<LIMBS>, sign: Choice) -> Self {
        Self { sign, magnitude }
    }

    /// Returns a copy of the magnitude.
    pub const fn magnitude(&self) -> Uint<LIMBS> {
        self.magnitude
    }

    /// Whether the magnitude is nonzero.
    pub const fn is_nonzero(&self) -> Choice {
        self.magnitude.is_nonzero()
    }

    /// Whether the magnitude is zero, as a plain `bool` (variable-time variant).
    pub const fn is_zero_vartime(&self) -> bool {
        self.magnitude.is_zero_vartime()
    }

    /// Returns the stored sign flag.
    ///
    /// The flag is returned without looking at the magnitude, so a zero value whose flag is
    /// set is reported as negative.
    pub const fn is_negative(&self) -> Choice {
        self.sign
    }

    /// Returns the low 63 bits of the magnitude as an `i64`, negated (wrapping) when the
    /// sign flag is set.
    #[allow(clippy::cast_possible_wrap)]
    pub const fn lowest(&self) -> i64 {
        // Masking off the top bit keeps the cast to `i64` non-negative.
        let mag = (self.magnitude.as_uint_ref().lowest_u64() & (u64::MAX >> 1)) as i64;
        self.sign.select_i64(mag, mag.wrapping_neg())
    }

    /// Computes the linear combination `a * c + b * d` in sign-magnitude form.
    ///
    /// Returns `(lo, hi, sign)`: the low and high parts of the wide magnitude and the sign
    /// flag of the result.
    #[inline]
    pub(crate) const fn lincomb_int<const RHS: usize>(
        a: &SignedInt<LIMBS>,
        b: &SignedInt<LIMBS>,
        c: &Int<RHS>,
        d: &Int<RHS>,
    ) -> (Uint<LIMBS>, Uint<RHS>, Choice) {
        let (c, c_sign) = c.abs_sign();
        let (d, d_sign) = d.abs_sign();
        // Multiply the magnitudes; each product is negative when exactly one factor is.
        let (mut x, mut x_hi) = a.magnitude.widening_mul(&c);
        let x_neg = a.sign.xor(c_sign);
        let (mut y, mut y_hi) = b.magnitude.widening_mul(&d);
        let y_neg = b.sign.xor(d_sign);
        // Set when the two products have opposite signs.
        let odd_neg = x_neg.xor(y_neg);

        // With equal signs, negate `y` so that the subtraction below adds the magnitudes.
        conditional_negate_in_place_wide(&mut y, &mut y_hi, odd_neg.not());

        let mut borrow;
        (x, borrow) = x.borrowing_sub(&y, Limb::ZERO);
        (x_hi, borrow) = x_hi.borrowing_sub(&y_hi, borrow);
        // A borrow only matters when `y` was not negated: it then means |y| > |x|.
        let swap = borrow.is_nonzero().and(odd_neg);

        // In that case the difference is negative; flip it to obtain the magnitude.
        conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);

        // The result carries the sign of `x`, or the sign of `y` when `y` dominated.
        let sign = x_neg.and(swap.not()).or(y_neg.and(swap));
        (x, x_hi, sign)
    }

    /// Computes `a * c + b * d`, shifts the wide result right by `shift` bits and keeps the
    /// low `LIMBS` limbs.
    ///
    /// The high part remaining after the shift is discarded. Debug builds assert
    /// `shift < Uint::<S>::BITS`.
    pub(crate) const fn lincomb_int_reduce_shift<const S: usize>(
        a: &Self,
        b: &Self,
        c: &Int<S>,
        d: &Int<S>,
        shift: u32,
    ) -> Self {
        debug_assert!(shift < Uint::<S>::BITS);
        let (mut a, mut a_hi, a_sign) = Self::lincomb_int(a, b, c, d);
        shr_in_place_wide(&mut a, &mut a_hi, shift);
        SignedInt::from_uint_sign(a, a_sign)
    }

    /// Computes `a * c + b * d`, subtracts a multiple of `m` derived from `mi`, and shifts
    /// the result right by `shift` bits.
    ///
    /// Debug builds assert `shift < Uint::<S>::BITS` and that at most one bit is left in
    /// the high part after the shift.
    ///
    /// NOTE(review): presumably `mi` must be the inverse of `m` modulo a power of two of at
    /// least `shift` bits, so that the subtraction clears the low `shift` bits and the
    /// shift is an exact division - confirm against `invert_mod_u64`.
    pub(crate) const fn lincomb_int_reduce_shift_mod<const S: usize>(
        a: &Self,
        b: &Self,
        c: &Int<S>,
        d: &Int<S>,
        shift: u32,
        m: &Uint<LIMBS>,
        mi: Uint<S>,
    ) -> SignedInt<LIMBS> {
        debug_assert!(shift < Uint::<S>::BITS);
        let (mut x, mut x_hi, mut x_sign) = SignedInt::lincomb_int(a, b, c, d);

        // Multiplier for `m`: the low word of `x` times `mi`, masked to `shift` bits.
        let mut mf = x.resize::<S>().wrapping_mul(&mi);
        mf = mf.bitand(&Uint::MAX.shr_vartime(Uint::<S>::BITS - shift));
        let (xa, xa_hi) = m.widening_mul(&mf);

        // Subtract `m * mf` from the wide value, tracking the borrow out of the high part.
        let mut borrow;
        (x, borrow) = x.borrowing_sub(&xa, Limb::ZERO);
        (x_hi, borrow) = x_hi.borrowing_sub(&xa_hi, borrow);

        // A borrow means the difference went negative: negate it and flip the sign flag.
        let swap = borrow.is_nonzero();
        conditional_negate_in_place_wide(&mut x, &mut x_hi, swap);
        x_sign = x_sign.xor(swap);

        shr_in_place_wide(&mut x, &mut x_hi, shift);
        debug_assert!(
            x_hi.shr1().is_nonzero().not().to_bool_vartime(),
            "overflow was larger than one bit"
        );

        // Presumably a final conditional subtraction of `m`, with the leftover high limb
        // acting as the carry - TODO confirm against `try_sub_with_carry`.
        x = x.try_sub_with_carry(x_hi.limbs[0], m).0;

        SignedInt::from_uint_sign(x, x_sign)
    }

    /// Returns the magnitude, or `m - magnitude` when this value is nonzero and its sign
    /// flag differs from `f_sign`.
    const fn norm(&self, f_sign: Choice, m: &Uint<LIMBS>) -> Uint<LIMBS> {
        let swap = f_sign.xor(self.sign).and(self.is_nonzero());
        Uint::select(&self.magnitude, &m.wrapping_sub(&self.magnitude), swap)
    }

    /// Compares two values: the magnitudes must match, and the sign flags must match
    /// unless `a` is zero.
    pub const fn eq(a: &Self, b: &Self) -> Choice {
        Uint::eq(&a.magnitude, &b.magnitude).and(a.sign.eq(b.sign).or(a.is_nonzero().not()))
    }
}
510
511impl<const LIMBS: usize> fmt::Debug for SignedInt<LIMBS> {
512 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
513 f.write_fmt(format_args!(
514 "{}0x{}",
515 if self.sign.to_bool_vartime() {
516 "-"
517 } else {
518 "+"
519 },
520 &self.magnitude
521 ))
522 }
523}
524
525impl<const LIMBS: usize> PartialEq for SignedInt<LIMBS> {
526 fn eq(&self, other: &Self) -> bool {
527 Self::eq(self, other).to_bool_vartime()
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::SafeGcdInverter;
    use crate::{U128, U256, modular::safegcd::shr_in_place_wide};

    /// Shifts the fixture value (`U128` high part over an all-ones `U256` low part) right
    /// by `shift` bits and returns `(lo, hi)`.
    fn shifted_fixture(shift: u32) -> (U256, U128) {
        let mut hi = U128::from_u128(0x11111111222222223333333344444444);
        let mut lo = U256::MAX;
        shr_in_place_wide(&mut lo, &mut hi, shift);
        (lo, hi)
    }

    /// Inverting a known value under a 256-bit odd modulus yields the known inverse.
    #[test]
    fn invert() {
        let modulus =
            U256::from_be_hex("FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551")
                .to_odd()
                .unwrap();
        let g =
            U256::from_be_hex("00000000CBF9350842F498CE441FC2DC23C7BF47D3DE91C327B2157C5E4EED77");
        let expected =
            U256::from_be_hex("FB668F8F509790BC549B077098918604283D42901C92981062EB48BC723F617B");

        let inverter = SafeGcdInverter::new(&modulus, &U256::ONE);
        let result = inverter.invert(&g).unwrap();
        assert_eq!(expected, result);
    }

    /// Wide right shifts by less than one limb and by more than one limb.
    #[test]
    fn shr_wide() {
        let (a, a_hi) = shifted_fixture(16);
        assert_eq!(
            a,
            U256::from_be_hex("4444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
        );
        assert_eq!(a_hi, U128::from_u128(0x1111111122222222333333334444));

        let (b, b_hi) = shifted_fixture(68);
        assert_eq!(
            b,
            U256::from_be_hex("23333333344444444FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")
        );
        assert_eq!(b_hi, U128::from_u128(0x111111112222222));
    }
}