1use crate::{
4 Choice, Gcd, Int, NonZero, NonZeroUint, Odd, OddUint, Uint, Xgcd,
5 modular::{bingcd::xgcd::PatternXgcdOutput, safegcd},
6 primitives::u32_min,
7};
8
9impl<const LIMBS: usize> Uint<LIMBS> {
10 #[must_use]
12 pub const fn gcd(&self, rhs: &Self) -> Self {
13 let self_is_nz = self.is_nonzero();
14 let self_nz = NonZero::new_unchecked(Uint::select(&Uint::ONE, self, self_is_nz));
16 Uint::select(rhs, self_nz.gcd_unsigned(rhs).as_ref(), self_is_nz)
17 }
18
19 #[must_use]
23 pub const fn gcd_vartime(&self, rhs: &Self) -> Self {
24 if self.is_zero_vartime() {
25 return *rhs;
26 }
27 NonZero::new_unchecked(*self)
28 .gcd_unsigned_vartime(rhs)
29 .get_copy()
30 }
31
32 #[must_use]
36 pub const fn xgcd(&self, rhs: &Self) -> UintXgcdOutput<LIMBS> {
37 let self_is_zero = self.is_nonzero().not();
39 let self_nz = NonZero::new_unchecked(Uint::select(self, &Uint::ONE, self_is_zero));
40 let rhs_is_zero = rhs.is_nonzero().not();
41 let rhs_nz = NonZero::new_unchecked(Uint::select(rhs, &Uint::ONE, rhs_is_zero));
42
43 let NonZeroUintXgcdOutput {
44 gcd,
45 mut x,
46 mut y,
47 mut lhs_on_gcd,
48 mut rhs_on_gcd,
49 } = self_nz.xgcd(&rhs_nz);
50
51 let mut gcd = *gcd.as_ref();
53 gcd = Uint::select(&gcd, rhs, self_is_zero);
54 gcd = Uint::select(&gcd, self, rhs_is_zero);
55
56 x = Int::select(&x, &Int::ZERO, self_is_zero);
58 y = Int::select(&y, &Int::ONE, self_is_zero);
59 x = Int::select(&x, &Int::ONE, rhs_is_zero);
60 y = Int::select(&y, &Int::ZERO, rhs_is_zero);
61
62 lhs_on_gcd = Uint::select(&lhs_on_gcd, &Uint::ZERO, self_is_zero);
64 rhs_on_gcd = Uint::select(&rhs_on_gcd, &Uint::ONE, self_is_zero);
65 lhs_on_gcd = Uint::select(&lhs_on_gcd, &Uint::ONE, rhs_is_zero);
66 rhs_on_gcd = Uint::select(&rhs_on_gcd, &Uint::ZERO, rhs_is_zero);
67
68 UintXgcdOutput {
69 gcd,
70 x,
71 y,
72 lhs_on_gcd,
73 rhs_on_gcd,
74 }
75 }
76}
77
78impl<const LIMBS: usize> NonZeroUint<LIMBS> {
79 #[must_use]
81 pub const fn gcd_unsigned(&self, rhs: &Uint<LIMBS>) -> Self {
82 let lhs = self.as_ref();
83
84 let i = lhs.trailing_zeros();
95 let j = rhs.trailing_zeros();
96 let k = u32_min(i, j);
97
98 let odd_lhs = Odd::new_unchecked(lhs.shr(i));
99 let gcd_div_2k = odd_lhs.gcd_unsigned(rhs);
100 NonZero::new_unchecked(gcd_div_2k.as_ref().shl(k))
101 }
102
103 #[must_use]
107 pub const fn gcd_unsigned_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
108 let lhs = self.as_ref();
109
110 let i = lhs.trailing_zeros_vartime();
111 let j = rhs.trailing_zeros_vartime();
112 let k = u32_min(i, j);
113
114 let odd_lhs = Odd::new_unchecked(lhs.shr_vartime(i));
115 let gcd_div_2k = odd_lhs.gcd_unsigned_vartime(rhs);
116 NonZero::new_unchecked(gcd_div_2k.as_ref().shl_vartime(k))
117 }
118
119 #[must_use]
123 pub const fn xgcd(&self, rhs: &Self) -> NonZeroUintXgcdOutput<LIMBS> {
124 let (mut lhs, mut rhs) = (*self.as_ref(), *rhs.as_ref());
125
126 let i = lhs.trailing_zeros();
128 let j = rhs.trailing_zeros();
129 let k = u32_min(i, j);
130 lhs = lhs.shr(k);
131 rhs = rhs.shr(k);
132
133 let swap = Choice::from_u32_lt(j, i);
135 Uint::conditional_swap(&mut lhs, &mut rhs, swap);
136 let lhs = lhs.to_odd().expect_copied("odd by construction");
137 let rhs = rhs.to_nz().expect_copied("non-zero by construction");
138
139 let odd_output = OddUintXgcdOutput::from_pattern_output(lhs.binxgcd_nz(&rhs));
140 odd_output.to_nz_output(k, swap)
141 }
142}
143
144impl<const LIMBS: usize> OddUint<LIMBS> {
145 #[inline(always)]
147 #[must_use]
148 pub const fn gcd_unsigned(&self, rhs: &Uint<LIMBS>) -> Self {
149 if LIMBS == 1 {
150 Self::classic_bingcd(self, rhs)
151 } else {
152 Self::safegcd(self, rhs)
153 }
154 }
155
156 #[inline(always)]
160 #[must_use]
161 pub const fn gcd_unsigned_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
162 if LIMBS == 1 {
163 Self::classic_bingcd_vartime(self, rhs)
164 } else {
165 Self::safegcd_vartime(self, rhs)
166 }
167 }
168
169 #[doc(hidden)]
175 #[inline(always)]
176 #[must_use]
177 pub const fn bingcd(&self, rhs: &Uint<LIMBS>) -> Self {
178 if LIMBS < 4 {
179 self.classic_bingcd(rhs)
180 } else {
181 self.optimized_bingcd(rhs)
182 }
183 }
184
185 #[doc(hidden)]
193 #[inline(always)]
194 #[must_use]
195 pub const fn bingcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
196 if LIMBS < 4 {
197 self.classic_bingcd_vartime(rhs)
198 } else {
199 self.optimized_bingcd_vartime(rhs)
200 }
201 }
202
203 #[doc(hidden)]
205 #[inline]
206 #[must_use]
207 pub const fn safegcd(&self, rhs: &Uint<LIMBS>) -> Self {
208 safegcd::gcd_odd::<LIMBS, false>(self, rhs)
209 }
210
211 #[doc(hidden)]
215 #[inline]
216 #[must_use]
217 pub const fn safegcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self {
218 safegcd::gcd_odd::<LIMBS, true>(self, rhs)
219 }
220
221 #[inline]
225 #[must_use]
226 pub const fn xgcd(&self, rhs: &Self) -> OddUintXgcdOutput<LIMBS> {
227 OddUintXgcdOutput::from_pattern_output(self.binxgcd_odd(rhs))
228 }
229}
230
231pub type UintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, Uint<LIMBS>>;
232pub type NonZeroUintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, NonZeroUint<LIMBS>>;
233pub type OddUintXgcdOutput<const LIMBS: usize> = XgcdOutput<LIMBS, OddUint<LIMBS>>;
234
235#[derive(Debug, Copy, Clone)]
237pub struct XgcdOutput<const LIMBS: usize, GCD: Copy> {
238 pub gcd: GCD,
240 pub x: Int<LIMBS>,
242 pub y: Int<LIMBS>,
244 pub lhs_on_gcd: Uint<LIMBS>,
246 pub rhs_on_gcd: Uint<LIMBS>,
248}
249
250impl<const LIMBS: usize, GCD: Copy> XgcdOutput<LIMBS, GCD> {
251 pub const fn gcd(&self) -> GCD {
253 self.gcd
254 }
255
256 pub const fn bezout_coefficients(&self) -> (Int<LIMBS>, Int<LIMBS>) {
258 (self.x, self.y)
259 }
260
261 pub const fn quotients(&self) -> (Uint<LIMBS>, Uint<LIMBS>) {
263 (self.lhs_on_gcd, self.rhs_on_gcd)
264 }
265}
266
267impl<const LIMBS: usize> OddUintXgcdOutput<LIMBS> {
268 pub(crate) const fn from_pattern_output(output: PatternXgcdOutput<LIMBS>) -> Self {
269 let gcd = output.gcd();
270 let (x, y) = output.bezout_coefficients();
271 let (lhs_on_gcd, rhs_on_gcd) = output.quotients();
272
273 OddUintXgcdOutput {
274 gcd,
275 x,
276 y,
277 lhs_on_gcd,
278 rhs_on_gcd,
279 }
280 }
281
282 pub(crate) const fn to_nz_output(self, k: u32, swap: Choice) -> NonZeroUintXgcdOutput<LIMBS> {
283 let Self {
284 ref gcd,
285 mut x,
286 mut y,
287 mut lhs_on_gcd,
288 mut rhs_on_gcd,
289 } = self;
290
291 let gcd = gcd
293 .as_ref()
294 .shl(k)
295 .to_nz()
296 .expect_copied("is non-zero by construction");
297 Int::conditional_swap(&mut x, &mut y, swap);
298 Uint::conditional_swap(&mut lhs_on_gcd, &mut rhs_on_gcd, swap);
299
300 NonZeroUintXgcdOutput {
301 gcd,
302 x,
303 y,
304 lhs_on_gcd,
305 rhs_on_gcd,
306 }
307 }
308}
309
// Implements `Gcd<$rhs> for $lhs` by delegating to `rhs.gcd(self)`: because gcd is
// symmetric, the right-hand operand's implementation is reused. The list form sets the
// output type to each right-hand type; the three-argument form takes it explicitly.
macro_rules! impl_gcd {
    ($lhs:ty, [$($rhs:ty),+]) => {
        $( impl_gcd!($lhs, $rhs, $rhs); )+
    };
    ($lhs:ty, $rhs:ty, $output:ty) => {
        impl<const LIMBS: usize> Gcd<$rhs> for $lhs {
            type Output = $output;

            #[inline]
            fn gcd(&self, rhs: &$rhs) -> Self::Output {
                rhs.gcd(self)
            }

            #[inline]
            fn gcd_vartime(&self, rhs: &$rhs) -> Self::Output {
                rhs.gcd_vartime(self)
            }
        }
    };
}
332
// Implements `Gcd<$rhs> for $slf` through `$slf`'s own `gcd_unsigned` /
// `gcd_unsigned_vartime`. `rhs` is passed as-is and deref-coerces to the expected
// `&Uint<LIMBS>`, matching how `impl_gcd_unsigned_rhs` passes its operand. The list form
// sets the output type to `$slf`; the three-argument form takes it explicitly.
macro_rules! impl_gcd_unsigned_lhs {
    ($slf:ty, [$($rhs:ty),+]) => {
        $(
            impl_gcd_unsigned_lhs!($slf, $rhs, $slf);
        )+
    };
    ($slf:ty, $rhs:ty, $out:ty) => {
        impl<const LIMBS: usize> Gcd<$rhs> for $slf {
            type Output = $out;

            #[inline]
            fn gcd(&self, rhs: &$rhs) -> Self::Output {
                self.gcd_unsigned(rhs)
            }

            #[inline]
            fn gcd_vartime(&self, rhs: &$rhs) -> Self::Output {
                self.gcd_unsigned_vartime(rhs)
            }
        }
    };
}
355
// Implements `Gcd<$other> for $this` by calling the right-hand operand's `gcd_unsigned` /
// `gcd_unsigned_vartime` with `self` as the argument. The list form sets the output type
// to each right-hand type; the three-argument form takes it explicitly.
macro_rules! impl_gcd_unsigned_rhs {
    ($this:ty, [$($other:ty),+]) => {
        $( impl_gcd_unsigned_rhs!($this, $other, $other); )+
    };
    ($this:ty, $other:ty, $output:ty) => {
        impl<const LIMBS: usize> Gcd<$other> for $this {
            type Output = $output;

            #[inline]
            fn gcd(&self, rhs: &$other) -> Self::Output {
                rhs.gcd_unsigned(self)
            }

            #[inline]
            fn gcd_vartime(&self, rhs: &$other) -> Self::Output {
                rhs.gcd_unsigned_vartime(self)
            }
        }
    };
}
378
379pub(crate) use impl_gcd_unsigned_lhs;
380pub(crate) use impl_gcd_unsigned_rhs;
381
382impl_gcd!(
383 Uint<LIMBS>,
384 [Uint<LIMBS>, NonZeroUint<LIMBS>, OddUint<LIMBS>]
385);
386impl_gcd_unsigned_lhs!(NonZeroUint<LIMBS>, [Uint<LIMBS>]);
387impl_gcd_unsigned_rhs!(
388 NonZeroUint<LIMBS>,
389 [NonZeroUint<LIMBS>, OddUint<LIMBS>]
390);
391impl_gcd_unsigned_lhs!(OddUint<LIMBS>, [Uint<LIMBS>, NonZeroUint<LIMBS>, OddUint<LIMBS>]);
392
393impl<const LIMBS: usize> Xgcd for Uint<LIMBS> {
394 type Output = UintXgcdOutput<LIMBS>;
395
396 fn xgcd(&self, rhs: &Uint<LIMBS>) -> Self::Output {
397 self.xgcd(rhs)
398 }
399
400 fn xgcd_vartime(&self, rhs: &Uint<LIMBS>) -> Self::Output {
401 self.xgcd(rhs)
403 }
404}
405
406impl<const LIMBS: usize> Xgcd for NonZeroUint<LIMBS> {
407 type Output = NonZeroUintXgcdOutput<LIMBS>;
408
409 fn xgcd(&self, rhs: &NonZeroUint<LIMBS>) -> Self::Output {
410 self.xgcd(rhs)
411 }
412
413 fn xgcd_vartime(&self, rhs: &NonZeroUint<LIMBS>) -> Self::Output {
414 self.xgcd(rhs)
416 }
417}
418
419impl<const LIMBS: usize> Xgcd for OddUint<LIMBS> {
420 type Output = OddUintXgcdOutput<LIMBS>;
421
422 fn xgcd(&self, rhs: &OddUint<LIMBS>) -> Self::Output {
423 self.xgcd(rhs)
424 }
425
426 fn xgcd_vartime(&self, rhs: &OddUint<LIMBS>) -> Self::Output {
427 self.xgcd(rhs)
429 }
430}
431
#[cfg(all(test, not(miri)))]
mod tests {
    mod gcd {
        use crate::{U64, U128, U256, U512, U1024, U2048, U4096, Uint};

        /// Assert that both the constant-time and variable-time gcd equal `expected`.
        fn check<const LIMBS: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>, expected: Uint<LIMBS>) {
            assert_eq!(lhs.gcd(&rhs), expected);
            assert_eq!(lhs.gcd_vartime(&rhs), expected);
        }

        /// Exercise every ordered pairing of `0`, `1` and `MAX` at one width.
        fn run_tests<const LIMBS: usize>() {
            let zero = Uint::<LIMBS>::ZERO;
            let one = Uint::<LIMBS>::ONE;
            let max = Uint::<LIMBS>::MAX;

            let cases = [
                (zero, zero, zero),
                (zero, one, one),
                (zero, max, max),
                (one, zero, one),
                (one, one, one),
                (one, max, one),
                (max, zero, max),
                (max, one, one),
                (max, max, max),
            ];
            for (lhs, rhs, expected) in cases {
                check(lhs, rhs, expected);
            }
        }

        #[test]
        fn gcd_sizes() {
            run_tests::<{ U64::LIMBS }>();
            run_tests::<{ U128::LIMBS }>();
            run_tests::<{ U256::LIMBS }>();
            run_tests::<{ U512::LIMBS }>();
            run_tests::<{ U1024::LIMBS }>();
            run_tests::<{ U2048::LIMBS }>();
            run_tests::<{ U4096::LIMBS }>();
        }
    }

    mod xgcd {
        use crate::{Concat, Int, U64, U128, U256, U512, U1024, U2048, U4096, U8192, U16384, Uint};
        use core::ops::Div;

        /// Check every invariant of `lhs.xgcd(&rhs)`:
        /// - the gcd agrees with `gcd`;
        /// - the quotients are exact when the gcd is non-zero;
        /// - the Bézout identity holds at double width.
        fn check<const LIMBS: usize, const DOUBLE: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>)
        where
            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
        {
            let output = lhs.xgcd(&rhs);
            assert_eq!(output.gcd, lhs.gcd(&rhs));

            if output.gcd > Uint::ZERO {
                let divisor = output.gcd.to_nz().unwrap();
                assert_eq!(output.lhs_on_gcd, lhs.div(divisor));
                assert_eq!(output.rhs_on_gcd, rhs.div(divisor));
            }

            let (x, y) = output.bezout_coefficients();
            assert_eq!(
                x.concatenating_mul_unsigned(&lhs) + y.concatenating_mul_unsigned(&rhs),
                *output.gcd.resize().as_int()
            );
        }

        /// Run `check` on every ordered pair of the edge-case samples
        /// `0`, `1`, `|Int::MIN|` and `MAX`.
        fn run_tests<const LIMBS: usize, const DOUBLE: usize>()
        where
            Uint<LIMBS>: Concat<LIMBS, Output = Uint<DOUBLE>>,
        {
            let samples: [Uint<LIMBS>; 4] = [Uint::ZERO, Uint::ONE, Int::MIN.abs(), Uint::MAX];
            for lhs in samples {
                for rhs in samples {
                    check(lhs, rhs);
                }
            }
        }

        #[test]
        fn binxgcd() {
            run_tests::<{ U64::LIMBS }, { U128::LIMBS }>();
            run_tests::<{ U128::LIMBS }, { U256::LIMBS }>();
            run_tests::<{ U256::LIMBS }, { U512::LIMBS }>();
            run_tests::<{ U512::LIMBS }, { U1024::LIMBS }>();
            run_tests::<{ U1024::LIMBS }, { U2048::LIMBS }>();
            run_tests::<{ U2048::LIMBS }, { U4096::LIMBS }>();
            run_tests::<{ U4096::LIMBS }, { U8192::LIMBS }>();
            run_tests::<{ U8192::LIMBS }, { U16384::LIMBS }>();
        }

        #[test]
        fn regression_tests() {
            let u256_cases = [
                (
                    "000000000000000000000000000000000000001B5DFB3BA1D549DFAF611B8D4C",
                    "000000000000345EAEDFA8CA03C1F0F5B578A787FE2D23B82A807F178B37FD8E",
                ),
                (
                    "000000000000000000000000000000000000001A0DEEF6F3AC2566149D925044",
                    "000000000000072B69C9DD0AA15F135675EA9C5180CF8FF0A59298CFC92E87FA",
                ),
            ];
            for (lhs_hex, rhs_hex) in u256_cases {
                check(U256::from_be_hex(lhs_hex), U256::from_be_hex(rhs_hex));
            }

            let lhs = U512::from_be_hex(concat![
                "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364142",
                "4EB38E6AC0E34DE2F34BFAF22DE683E1F4B92847B6871C780488D797042229E1"
            ]);
            let rhs = U512::from_be_hex(concat![
                "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD755DB9CD5E9140777FA4BD19A06C8283",
                "9D671CD581C69BC5E697F5E45BCD07C52EC373A8BDC598B4493F50A1380E1281"
            ]);
            check(lhs, rhs);
        }
    }

    mod traits {
        use crate::{Gcd, I256, U256};

        #[test]
        fn gcd_relatively_prime() {
            // 59·67 and 61·71 share no prime factor.
            let lhs = U256::from(59u32 * 67);
            let rhs = U256::from(61u32 * 71);
            assert_eq!(lhs.gcd(&rhs), U256::ONE);
        }

        #[test]
        fn gcd_nonprime() {
            let lhs = U256::from(4391633u32);
            let rhs = U256::from(2022161u32);
            assert_eq!(lhs.gcd(&rhs), U256::from(1763u32));
        }

        #[test]
        fn gcd_zero() {
            let (zero, one) = (U256::ZERO, U256::ONE);
            assert_eq!(zero.gcd(&zero), zero);
            assert_eq!(zero.gcd(&one), one);
            assert_eq!(one.gcd(&zero), one);
        }

        #[test]
        fn gcd_one() {
            let one = U256::ONE;
            assert_eq!(U256::ONE, one.gcd(&U256::ONE));
            assert_eq!(U256::ONE, one.gcd(&U256::from(2u8)));
        }

        #[test]
        fn gcd_two() {
            let two = U256::from_u8(2);
            assert_eq!(two, two.gcd(&two));

            let four = U256::from_u8(4);
            assert_eq!(two, two.gcd(&four));
            assert_eq!(two, four.gcd(&two));
        }

        #[test]
        fn gcd_unsigned_int() {
            // 61·71 and ±59·61 share exactly the factor 61, whatever the sign of the `Int`.
            let lhs = U256::from(61u32 * 71);
            let rhs = I256::from(59i32 * 61);

            let sixty_one = U256::from(61u32);
            assert_eq!(sixty_one, <U256 as Gcd<I256>>::gcd(&lhs, &rhs));
            assert_eq!(sixty_one, <U256 as Gcd<I256>>::gcd(&lhs, &rhs.wrapping_neg()));
        }

        #[test]
        fn xgcd_expected() {
            // 61·71 and 59·61 share exactly the factor 61; 5·71 - 6·59 = 1.
            let lhs = U256::from(61u32 * 71);
            let rhs = U256::from(59u32 * 61);

            let output = lhs.xgcd(&rhs);
            let (x, y) = output.bezout_coefficients();
            assert_eq!(U256::from(61u32), output.gcd);
            assert_eq!(I256::from(5i32), x);
            assert_eq!(I256::from(-6i32), y);
        }
    }
}