1use super::schoolbook;
31use crate::{Limb, Uint, UintRef};
32
/// Size threshold, in limbs, for the Karatsuba routines in this module.
///
/// Operands shorter than this many limbs are handled with schoolbook arithmetic
/// instead of being split further (`wrapping_square` also uses schoolbook squaring
/// when its operand is exactly this long).
pub const MIN_STARTING_LIMBS: usize = 16;
34
/// Computes the full (widening) product of `lhs` and `rhs`.
///
/// `lhs` is expected to have exactly `LHS` limbs and `rhs` exactly `RHS` limbs (checked only
/// by `debug_assert!`). The product is returned as `(lo, hi)` with
/// `lhs * rhs = lo + hi * 2^(LHS * Limb::BITS)`: `lo` holds the low `LHS` limbs and `hi` the
/// remaining `RHS` limbs.
///
/// Dispatch:
/// - either operand shorter than `MIN_STARTING_LIMBS`: schoolbook multiplication;
/// - `LHS == RHS` in {16, 32, 64, 128, 256}: one Karatsuba step (`reduce`), which recurses
///   through this function on the halves; other equal sizes go through `wrapping_mul` into a
///   zeroed double-width buffer;
/// - `LHS < RHS`: `rhs` is split after `LHS` limbs, the `LHS x LHS` product is computed
///   recursively and the remaining `lhs * y1` partial product is accumulated into `hi`;
/// - `LHS > RHS`: the operands are swapped and result limbs are moved from `hi` to `lo` to
///   match the requested split.
#[allow(clippy::cast_possible_truncation)]
pub const fn widening_mul_fixed<const LHS: usize, const RHS: usize>(
    lhs: &UintRef,
    rhs: &UintRef,
) -> (Uint<LHS>, Uint<RHS>) {
    debug_assert!(lhs.nlimbs() == LHS && rhs.nlimbs() == RHS);

    /// One Karatsuba step on two operands of `2 * HALF` limbs.
    ///
    /// With `B = 2^(HALF * Limb::BITS)`, `x = x0 + x1*B` and `y = y0 + y1*B`:
    /// `x*y = z0 + (x0*y1 + x1*y0)*B + z2*B^2`, where `z0 = x0*y0` and `z2 = x1*y1`. The
    /// middle coefficient is obtained as `(x0 + x1)(y0 + y1) - z0 - z2`, so only three
    /// half-size products are needed. Offsets in the comments below are relative to the
    /// full product.
    #[inline]
    const fn reduce<const LHS: usize, const RHS: usize, const HALF: usize>(
        x: &UintRef,
        y: &UintRef,
    ) -> (Uint<LHS>, Uint<RHS>) {
        assert!(LHS <= RHS && LHS == HALF * 2);
        let (x0, x1) = x.split_at(HALF);
        let (y0, y1) = y.split_at(HALF);

        // z0 = x0 * y0 and z2 = x1 * y1, each as a (lo, hi) pair of HALF-limb values.
        let z0 = widening_mul_fixed(x0, y0);
        let z2 = widening_mul_fixed(x1, y1);

        // l0 = x0 + x1 and l1 = y0 + y1 truncated to HALF limbs; l0c / l1c hold the carry
        // out of the top limb of each sum.
        let (mut l0, mut l1) = (Uint::<HALF>::ZERO, Uint::<HALF>::ZERO);
        let (mut l0c, mut l1c) = (Limb::ZERO, Limb::ZERO);
        let mut i = 0;
        while i < HALF {
            (l0.limbs[i], l0c) = x0.limbs[i].carrying_add(x1.limbs[i], l0c);
            (l1.limbs[i], l1c) = y0.limbs[i].carrying_add(y1.limbs[i], l1c);
            i += 1;
        }
        // z1 = l0 * l1. The untruncated (x0 + x1)(y0 + y1) equals
        // z1 + (l1c*l0 + l0c*l1)*B + l0c*l1c*B^2; the extra terms are added below.
        let z1 = widening_mul_fixed(l0.as_uint_ref(), l1.as_uint_ref());

        // s0 and s1 are the result digits at offsets B and B^2. They start from the parts
        // of z0 and z2 that share those positions; `carry` collects what overflows into
        // the top digit (offset B^3).
        let (mut s0, mut s1) = (z0.1, z2.0);
        let (mut c, mut carry);

        // Add z1 at offset B.
        (s0, c) = s0.carrying_add(&z1.0, Limb::ZERO);
        (s1, c) = s1.carrying_add(&z1.1, c);
        carry = c;
        // Add l1c*l0 and l0c*l1 at offset B^2, choosing the addend with `Uint::select`
        // instead of branching on the carry.
        (s1, c) = s1.carrying_add(
            &Uint::select(&Uint::ZERO, &l0, l1c.is_nonzero()),
            Limb::ZERO,
        );
        carry = carry.wrapping_add(c);
        (s1, c) = s1.carrying_add(
            &Uint::select(&Uint::ZERO, &l1, l0c.is_nonzero()),
            Limb::ZERO,
        );
        carry = carry.wrapping_add(c);
        // Add l0c*l1c (bitwise AND of the two carry limbs) at offset B^3.
        carry = carry.wrapping_add(l0c.bitand(l1c));

        // Subtract z0 and z2 at offset B. NOTE(review): folding the borrows into `carry`
        // with `wrapping_add` assumes `borrowing_sub` reports a borrow as `Limb::MAX`
        // (i.e. -1) — confirm against `Uint::borrowing_sub`.
        (s0, c) = s0.borrowing_sub(&z0.0, Limb::ZERO);
        (s1, c) = s1.borrowing_sub(&z0.1, c);
        carry = carry.wrapping_add(c);
        (s0, c) = s0.borrowing_sub(&z2.0, Limb::ZERO);
        (s1, c) = s1.borrowing_sub(&z2.1, c);
        carry = carry.wrapping_add(c);

        // Result digits: lo = [z0.lo, s0], hi = [s1, z2.hi + carry].
        (
            concat_wide(&z0.0, &s0),
            concat_wide(&s1, &z2.1.wrapping_add_limb(carry)),
        )
    }

    if LHS < MIN_STARTING_LIMBS || RHS < MIN_STARTING_LIMBS {
        // Small operands: schoolbook multiplication.
        let (mut lo, mut hi) = (Uint::ZERO, Uint::ZERO);
        schoolbook::mul_wide(
            lhs.as_limbs(),
            rhs.as_limbs(),
            lo.as_mut_limbs(),
            hi.as_mut_limbs(),
        );
        (lo, hi)
    } else if LHS == RHS {
        // Equal sizes: a Karatsuba step for the supported power-of-two sizes.
        match LHS {
            16 => reduce::<LHS, RHS, 8>(lhs, rhs),
            32 => reduce::<LHS, RHS, 16>(lhs, rhs),
            64 => reduce::<LHS, RHS, 32>(lhs, rhs),
            128 => reduce::<LHS, RHS, 64>(lhs, rhs),
            256 => reduce::<LHS, RHS, 128>(lhs, rhs),
            _ => {
                // Other sizes: compute into a zeroed buffer of 2 * LHS limbs.
                let mut lo_hi = [[Limb::ZERO; LHS]; 2];
                wrapping_mul(lhs, rhs, UintRef::new_flattened_mut(&mut lo_hi), false);
                (Uint::new(lo_hi[0]), Uint::new(lo_hi[1]).resize::<RHS>())
            }
        }
    } else if LHS < RHS {
        // rhs = y0 + y1 * 2^(LHS * Limb::BITS) with y0 of LHS limbs, so
        // lhs * rhs = lhs * y0 + (lhs * y1) * 2^(LHS * Limb::BITS); the second term only
        // touches `hi`.
        let (y0, y1) = rhs.split_at(LHS);
        let (lo, mut hi) = resize_wide(widening_mul_fixed::<LHS, LHS>(lhs, y0));
        wrapping_mul(lhs, y1, hi.as_mut_uint_ref(), true);
        (lo, hi)
    } else {
        // LHS > RHS: compute with the operands swapped (RHS low limbs, LHS high limbs),
        // then move the lowest LHS - RHS limbs of `hi` to the top of `lo`.
        let (lo, hi) = widening_mul_fixed::<RHS, LHS>(rhs, lhs);
        let mut lo = lo.resize::<LHS>();
        lo.as_mut_uint_ref()
            .trailing_mut(RHS)
            .copy_from(hi.as_uint_ref().leading(LHS - RHS));
        (
            lo,
            hi.unbounded_shr_by_limbs_vartime((LHS - RHS) as u32)
                .resize::<RHS>(),
        )
    }
}
159
160#[inline]
165pub const fn wrapping_mul_fixed<const LHS: usize>(
166 lhs: &UintRef,
167 rhs: &UintRef,
168) -> (Uint<LHS>, Limb) {
169 debug_assert!(lhs.nlimbs() == LHS);
170
171 #[inline]
178 const fn reduce<const LHS: usize, const HALF: usize>(
179 x: &UintRef,
180 y: &UintRef,
181 ) -> (Uint<LHS>, Limb) {
182 debug_assert!(LHS == HALF * 2);
183 let (x0, x1) = x.split_at(HALF);
184 let (y0, y1) = y.leading(LHS).split_at(HALF);
185
186 let z0 = widening_mul_fixed::<HALF, HALF>(x0, y0);
188 let (z1, z1c) = wrapping_mul_fixed::<HALF>(x0, y1);
190 let (z2, z2c) = wrapping_mul_fixed::<HALF>(x1, y0);
192
193 let (hi, c1) = z0.1.carrying_add(&z1, Limb::ZERO);
194 let (hi, c2) = hi.carrying_add(&z2, Limb::ZERO);
195 let carry = z1c.wrapping_add(z2c).wrapping_add(c1).wrapping_add(c2);
196
197 (concat_wide(&z0.0, &hi), carry)
198 }
199
200 if LHS < MIN_STARTING_LIMBS || rhs.nlimbs() < MIN_STARTING_LIMBS {
202 let mut lo = Uint::ZERO;
203 let carry = schoolbook::wrapping_mul_add(lhs.as_limbs(), rhs.as_limbs(), lo.as_mut_limbs());
204 return (lo, carry);
205 }
206 else if LHS <= rhs.nlimbs() {
209 match LHS {
210 16 => return reduce::<LHS, 8>(lhs, rhs),
211 32 => return reduce::<LHS, 16>(lhs, rhs),
212 64 => return reduce::<LHS, 32>(lhs, rhs),
213 128 => return reduce::<LHS, 64>(lhs, rhs),
214 256 => return reduce::<LHS, 128>(lhs, rhs),
215 _ => {}
216 }
217 }
218
219 let mut lo = Uint::ZERO;
221 let carry = wrapping_mul(lhs, rhs, lo.as_mut_uint_ref(), false);
222 (lo, carry)
223}
224
/// Computes the full (widening) square of `uint`.
///
/// `uint` is expected to have exactly `LIMBS` limbs (checked only by `debug_assert!`).
/// Returns `(lo, hi)` with `uint^2 = lo + hi * 2^(LIMBS * Limb::BITS)`.
///
/// Dispatch:
/// - operands shorter than `MIN_STARTING_LIMBS` are squared with the schoolbook method;
/// - the power-of-two sizes 16 to 256 take one Karatsuba-style step (`reduce`);
/// - any other size is computed by `wrapping_square` into a zeroed double-width buffer.
pub const fn widening_square_fixed<const LIMBS: usize>(
    uint: &UintRef,
) -> (Uint<LIMBS>, Uint<LIMBS>) {
    debug_assert!(
        uint.nlimbs() == LIMBS,
        "invalid arguments to widening_square_fixed"
    );

    /// One squaring step on a value of `2 * HALF` limbs.
    ///
    /// With `B = 2^(HALF * Limb::BITS)` and `x = x0 + x1*B`:
    /// `x^2 = x0^2 + 2*x0*x1*B + x1^2*B^2`, using two half-size squares and one half-size
    /// product. Offsets below are relative to the full square.
    #[inline]
    const fn reduce<const LIMBS: usize, const HALF: usize>(
        x: &UintRef,
    ) -> (Uint<LIMBS>, Uint<LIMBS>) {
        debug_assert!(LIMBS == HALF * 2);
        let (x0, x1) = x.split_at(HALF);

        // z0 = x0^2, z1 = x0*x1, z2 = x1^2, each as a (lo, hi) pair of HALF-limb values.
        let z0 = widening_square_fixed::<HALF>(x0);
        let mut z1 = widening_mul_fixed::<HALF, HALF>(x0, x1);
        let z2 = widening_square_fixed::<HALF>(x1);

        let (mut c, mut carry);
        // Double the cross product: z1 <<= 1, keeping the bit shifted out of the top in
        // `carry`.
        (z1.0, c) = z1.0.shl1_with_carry(Limb::ZERO);
        (z1.1, carry) = z1.1.shl1_with_carry(c);
        // Add the overlapping halves of the squares: z0.hi at offset B, z2.lo at offset B^2.
        (z1.0, c) = z1.0.carrying_add(&z0.1, Limb::ZERO);
        (z1.1, c) = z1.1.carrying_add(&z2.0, c);
        carry = carry.wrapping_add(c);

        // Result digits: lo = [z0.lo, z1.lo], hi = [z1.hi, z2.hi + carry].
        (
            concat_wide(&z0.0, &z1.0),
            concat_wide(&z1.1, &z2.1.wrapping_add_limb(carry)),
        )
    }

    if LIMBS < MIN_STARTING_LIMBS {
        // Small operand: schoolbook squaring.
        let (mut lo, mut hi) = (Uint::ZERO, Uint::ZERO);
        schoolbook::square_wide(uint.as_limbs(), lo.as_mut_limbs(), hi.as_mut_limbs());
        (lo, hi)
    } else {
        // A Karatsuba-style step for the supported power-of-two sizes.
        match LIMBS {
            16 => reduce::<LIMBS, 8>(uint),
            32 => reduce::<LIMBS, 16>(uint),
            64 => reduce::<LIMBS, 32>(uint),
            128 => reduce::<LIMBS, 64>(uint),
            256 => reduce::<LIMBS, 128>(uint),
            _ => {
                // Other sizes: square into a zeroed buffer of 2 * LIMBS limbs.
                let mut lo_hi = [[Limb::ZERO; LIMBS]; 2];
                wrapping_square(uint, UintRef::new_flattened_mut(&mut lo_hi));
                (Uint::new(lo_hi[0]), Uint::new(lo_hi[1]))
            }
        }
    }
}
291
292#[inline]
297pub const fn wrapping_square_fixed<const LIMBS: usize>(uint: &UintRef) -> (Uint<LIMBS>, Limb) {
298 let mut lo = Uint::ZERO;
299 let carry = wrapping_square(uint, lo.as_mut_uint_ref());
300 (lo, carry)
301}
302
303#[inline]
311pub const fn wrapping_mul(lhs: &UintRef, rhs: &UintRef, out: &mut UintRef, add: bool) -> Limb {
312 assert!(
313 lhs.nlimbs() + rhs.nlimbs() >= out.nlimbs(),
314 "invalid arguments to wrapping_mul"
315 );
316
317 const fn reduce<const LIMBS: usize>(
319 x: &UintRef,
320 y: &UintRef,
321 out: &mut UintRef,
322 add: bool,
323 ) -> Limb {
324 let out_len = out.nlimbs();
325
326 if out_len <= x.nlimbs() {
335 let (x0, x1) = x.leading(out_len).split_at(out_len - LIMBS);
336 let y0 = y.leading(LIMBS);
337
338 let (z1, mut carry) = wrapping_mul_fixed::<LIMBS>(x1, y0);
340 let assign = out.trailing_mut(out_len - LIMBS);
341 if add {
342 let c = assign.carrying_add_assign(z1.as_uint_ref(), Limb::ZERO);
343 carry = carry.wrapping_add(c);
344 } else {
345 assign.copy_from(z1.as_uint_ref());
346 }
347
348 if !x0.is_empty() {
350 let c = wrapping_mul(x0, y, out, true);
351 carry = carry.wrapping_add(c);
352 }
353 carry
354 }
355 else {
360 let (x0, x1) = x.split_at(LIMBS);
361 let y_len = if y.nlimbs() < out_len {
362 y.nlimbs()
363 } else {
364 out_len
365 };
366 let (y0, y1) = y.leading(y_len).split_at(LIMBS);
367
368 let (assign, tail) = out.split_at_mut(if out.nlimbs() < LIMBS * 2 {
369 out.nlimbs()
370 } else {
371 LIMBS * 2
372 });
373
374 let mut carry = if assign.nlimbs() < LIMBS * 2 {
375 if !add {
376 assign.fill(Limb::ZERO);
377 }
378 schoolbook::wrapping_mul_add(x0.as_limbs(), y0.as_limbs(), assign.as_mut_limbs())
379 } else {
380 let z0 = widening_mul_fixed::<LIMBS, LIMBS>(x0, y0);
381 let (lo, hi) = assign.split_at_mut(LIMBS);
382 if add {
383 let mut carry = lo.carrying_add_assign(z0.0.as_uint_ref(), Limb::ZERO);
384 carry = hi.carrying_add_assign(z0.1.as_uint_ref().leading(hi.nlimbs()), carry);
385 tail.add_assign_limb(carry)
386 } else {
387 lo.copy_from(z0.0.as_uint_ref());
388 hi.copy_from(z0.1.as_uint_ref().leading(hi.nlimbs()));
389 Limb::ZERO
390 }
391 };
392
393 if !x1.is_empty() {
395 let c = wrapping_mul(x1, y, out.trailing_mut(LIMBS), true);
396 carry = carry.wrapping_add(c);
397 }
398 if !y1.is_empty() {
400 let tail_len = out_len - LIMBS;
401 let assign_len = if y_len < tail_len { y_len } else { tail_len };
402 let (assign, tail) = out.trailing_mut(LIMBS).split_at_mut(assign_len);
403 let c = wrapping_mul(y1, x0, assign, true);
404 let c = tail.add_assign_limb(c);
405 carry = carry.wrapping_add(c);
406 }
407 carry
408 }
409 }
410
411 let overlap = if lhs.nlimbs() < rhs.nlimbs() {
412 lhs.nlimbs()
413 } else {
414 rhs.nlimbs()
415 };
416 let overlap = if overlap < out.nlimbs() {
417 overlap
418 } else {
419 out.nlimbs()
420 };
421 let split = previous_power_of_2(overlap);
422
423 if split < MIN_STARTING_LIMBS {
425 return schoolbook::wrapping_mul_add(lhs.as_limbs(), rhs.as_limbs(), out.as_mut_limbs());
426 }
427
428 match split {
430 16 => reduce::<16>(lhs, rhs, out, add),
431 32 => reduce::<32>(lhs, rhs, out, add),
432 64 => reduce::<64>(lhs, rhs, out, add),
433 128 => reduce::<128>(lhs, rhs, out, add),
434 _ => reduce::<256>(lhs, rhs, out, add),
435 }
436}
437
/// Computes the low `out.nlimbs()` limbs of `uint^2` into `out`.
///
/// Inputs of at most `MIN_STARTING_LIMBS` limbs (after truncation to the output length) are
/// delegated to `schoolbook::wrapping_square`; larger ones are split at a power-of-two limb
/// count, and the Karatsuba-style paths write `out` rather than accumulate into it.
///
/// Returns a carry limb. NOTE(review): its meaning differs by path (the schoolbook result,
/// the first truncated limb of a partial sum, or accumulated addition carries) — confirm
/// what callers rely on.
///
/// # Panics
/// If `out.nlimbs() > 2 * uint.nlimbs()`.
#[inline]
pub(crate) const fn wrapping_square(uint: &UintRef, out: &mut UintRef) -> Limb {
    assert!(
        out.nlimbs() <= uint.nlimbs() * 2,
        "invalid arguments to wrapping_square"
    );

    /// Squaring step with `x` split after `LIMBS` limbs: with `B = 2^(LIMBS * Limb::BITS)`
    /// and `x = x0 + x1*B`, `x^2 = x0^2 + 2*x0*x1*B + x1^2*B^2`, truncated to `out`.
    /// Offsets below are relative to the start of `out`.
    const fn reduce<const LIMBS: usize>(x: &UintRef, out: &mut UintRef) -> Limb {
        let (x0, x1) = x.split_at(LIMBS);
        let (lo, hi) = out.split_at_mut(LIMBS);

        // The low LIMBS limbs of the result come from x0^2 alone.
        let z0 = widening_square_fixed::<LIMBS>(x0);
        lo.copy_from(z0.0.as_uint_ref());

        if hi.nlimbs() <= LIMBS {
            // x1^2 * B^2 lies entirely above `out`: the upper part is
            // z0.hi + 2*x0*x1 (mod B), truncated to the length of `hi`.
            let (z1, _carry) = wrapping_mul_fixed::<LIMBS>(x0, x1);
            let z1 = z1.shl1();
            let z2 = z0.1.wrapping_add(&z1);
            let (z2, tail) = z2.as_uint_ref().split_at(hi.nlimbs());
            hi.copy_from(z2);
            // Report the first limb of that sum that did not fit in `out`, if any.
            if tail.is_empty() {
                Limb::ZERO
            } else {
                tail.limbs[0]
            }
        } else {
            // `out` extends past offset B^2: write z0.hi at offset B and x1^2 at offset B^2,
            // then accumulate the doubled cross product.
            let (z01, z2) = hi.split_at_mut(LIMBS);
            z01.copy_from(z0.1.as_uint_ref());
            // NOTE(review): the carry returned by this recursive call is discarded.
            wrapping_square(x1, z2);
            // dx0 = 2*x0 truncated to LIMBS limbs, with the bit shifted out in `dx0_hi`, so
            // that 2*x0*x1*B = dx0*x1*B + dx0_hi*x1*B^2.
            let mut dx0 = Uint::<LIMBS>::ZERO;
            dx0.as_mut_uint_ref().copy_from(x0);
            let (dx0, dx0_hi) = dx0.shl1_with_carry(Limb::ZERO);
            let z2_len = if z2.nlimbs() < x1.nlimbs() {
                z2.nlimbs()
            } else {
                x1.nlimbs()
            };
            // Add x1 at offset B^2 when the doubling overflowed.
            let mut carry = z2.leading_mut(z2_len).conditional_add_assign(
                x1.leading(z2_len),
                Limb::ZERO,
                dx0_hi.is_nonzero(),
            );
            // Add dx0*x1 at offset B. Both this sum and the one above end at offset
            // 2*LIMBS + z2_len, which is where `z1tail` starts, so their carries are
            // rippled into `z1tail` together.
            let (z1, z1tail) = hi.split_at_mut(LIMBS + z2_len);
            let c = wrapping_mul(dx0.as_uint_ref(), x1, z1, true);
            carry = carry.wrapping_add(c);
            z1tail.add_assign_limb(carry)
        }
    }

    // Only the low `out.nlimbs()` limbs of `uint` can affect the result.
    let x = if uint.nlimbs() >= out.nlimbs() {
        uint.leading(out.nlimbs())
    } else {
        uint
    };

    if x.nlimbs() <= MIN_STARTING_LIMBS {
        return schoolbook::wrapping_square(x.as_limbs(), out.as_mut_limbs());
    }

    // Start from the largest power of two not exceeding the output length and halve it when
    // it is longer than `x`, or when 2 * split >= out.nlimbs() + MIN_STARTING_LIMBS (the part
    // of `out` above the split would then be much shorter than the split itself).
    let mut split = previous_power_of_2(out.nlimbs());
    if split > x.nlimbs() || 2 * split >= out.nlimbs() + MIN_STARTING_LIMBS {
        split /= 2;
    }

    match split {
        16 => reduce::<16>(x, out),
        32 => reduce::<32>(x, out),
        64 => reduce::<64>(x, out),
        128 => reduce::<128>(x, out),
        // 256 and any larger power of two use a 256-limb split.
        _ => reduce::<256>(x, out),
    }
}
523
524#[inline]
526const fn concat_wide<const LIMBS: usize, const HALF: usize>(
527 lo: &Uint<HALF>,
528 hi: &Uint<HALF>,
529) -> Uint<LIMBS> {
530 assert!(LIMBS >= HALF * 2);
531 let mut res = Uint::<LIMBS>::ZERO;
532 let (lo_mut, hi_mut) = res
533 .as_mut_uint_ref()
534 .leading_mut(HALF * 2)
535 .split_at_mut(HALF);
536 lo_mut.copy_from_slice(lo.as_limbs());
537 hi_mut.copy_from_slice(hi.as_limbs());
538 res
539}
540
541#[inline(always)]
543const fn resize_wide<const LIMBS: usize, const LHS: usize, const RHS: usize>(
544 (lo, hi): (Uint<LIMBS>, Uint<LIMBS>),
545) -> (Uint<LHS>, Uint<RHS>) {
546 assert!(LHS == LIMBS && RHS >= LIMBS);
547 (lo.resize(), hi.resize())
548}
549
/// Returns the largest power of two that is less than or equal to `value`, or `0` when
/// `value` is `0`.
#[inline]
const fn previous_power_of_2(value: usize) -> usize {
    match value.checked_ilog2() {
        Some(exponent) => 1usize << exponent,
        None => 0,
    }
}
559
#[cfg(feature = "rand_core")]
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
    use super::*;
    use crate::Random;
    use crate::{Limb, Uint, UintRef};
    use rand_core::{Rng, SeedableRng};

    /// Multiplies a 16-limb value (only the top limb set) by a 17-limb value (only limb 1
    /// set) into a 17-limb output pre-filled with a non-zero pattern. Checks that the
    /// truncated product and the returned carry both match the schoolbook reference, in
    /// either accumulate (`add`) or overwrite mode.
    fn assert_sparse_truncated_wide_split16_matches_schoolbook(add: bool) {
        let mut lhs = [Limb::ZERO; 16];
        lhs[15] = Limb::MAX;
        let mut rhs = [Limb::ZERO; 17];
        rhs[1] = Limb::MAX;

        let pattern = [Limb::from(0x5au8); 17];
        let mut karatsuba_out = pattern;
        // In overwrite mode the reference starts from zero instead of the pattern.
        let mut schoolbook_out = if add { pattern } else { [Limb::ZERO; 17] };

        let karatsuba_carry = wrapping_mul(
            UintRef::new(&lhs),
            UintRef::new(&rhs),
            UintRef::new_mut(&mut karatsuba_out),
            add,
        );
        let schoolbook_carry = schoolbook::wrapping_mul_add(&lhs, &rhs, &mut schoolbook_out);

        assert_eq!(karatsuba_out, schoolbook_out);
        assert_eq!(karatsuba_carry, schoolbook_carry);
    }

    #[test]
    fn truncated_wide_split16_carry_matches_schoolbook_when_overwriting() {
        assert_sparse_truncated_wide_split16_matches_schoolbook(false);
    }

    #[test]
    fn truncated_wide_split16_carry_matches_schoolbook_when_adding() {
        assert_sparse_truncated_wide_split16_matches_schoolbook(true);
    }

    /// Every truncation of `b * a` must equal the same-length prefix of the full `a * b`.
    #[test]
    fn wrapping_mul_sizes() {
        const SIZE: usize = if cfg!(miri) { 10 } else { 40 };
        let mut rng = chacha20::ChaCha8Rng::seed_from_u64(1);
        for n in 0..100 {
            let a_full = Uint::<SIZE>::random_from_rng(&mut rng);
            let b_full = Uint::<SIZE>::random_from_rng(&mut rng);
            let size_a = rng.next_u32() as usize % SIZE;
            let size_b = rng.next_u32() as usize % SIZE;
            let a = a_full.as_uint_ref().leading(size_a);
            let b = b_full.as_uint_ref().leading(size_b);
            let full_len = size_a + size_b;

            let mut wide = [Limb::ZERO; SIZE * 2];
            wrapping_mul(a, b, UintRef::new_mut(&mut wide[..full_len]), false);

            for size in 0..full_len {
                let mut check = [Limb::ZERO; SIZE * 2];
                let wrapped = UintRef::new_mut(&mut check[..size]);
                wrapping_mul(b, a, wrapped, false);
                assert_eq!(
                    wrapped,
                    UintRef::new(&wide[..size]),
                    "comparison failed n={n}, a={a}, b={b}, size={size}"
                );
            }
        }
    }

    /// Every truncation of `a^2` via `wrapping_square` must equal the same-length prefix of
    /// the full product `a * a`.
    #[test]
    fn wrapping_square_sizes() {
        const SIZE: usize = if cfg!(miri) { 10 } else { 40 };
        let mut rng = chacha20::ChaCha8Rng::seed_from_u64(1);
        for n in 0..100 {
            let a_full = Uint::<SIZE>::random_from_rng(&mut rng);
            let size_a = rng.next_u32() as usize % SIZE;
            let a = a_full.as_uint_ref().leading(size_a);
            let full_len = size_a * 2;

            let mut wide = [Limb::ZERO; SIZE * 2];
            wrapping_mul(a, a, UintRef::new_mut(&mut wide[..full_len]), false);

            for size in 0..=full_len {
                let mut check = [Limb::ZERO; SIZE * 2];
                let wrapped = UintRef::new_mut(&mut check[..size]);
                wrapping_square(a, wrapped);
                assert_eq!(
                    wrapped,
                    UintRef::new(&wide[..size]),
                    "comparison failed n={n}, a={a}, size={size}"
                );
            }
        }
    }
}