crypto_bigint/modular/bingcd/
div_mod_2k.rs1use crate::{Choice, Limb, Odd, OddUint, Uint, primitives::u32_min, word};
4
5impl<const LIMBS: usize> Uint<LIMBS> {
6 #[inline]
12 pub(super) const fn bounded_div2k_mod_q(
13 self,
14 k: u32,
15 k_upper_bound: u32,
16 q: &Odd<Self>,
17 ) -> Self {
18 let one_half_mod_q = OddUint::half_mod(q);
19
20 let (mut x, mut e) = (self, 0);
22
23 let max_iters_per_round = Limb::BITS - 1;
24 let rounds = k_upper_bound.div_ceil(max_iters_per_round);
25
26 let mut r = 0;
27 while r < rounds {
28 let f_upper_bound =
29 u32_min(k_upper_bound - r * max_iters_per_round, max_iters_per_round);
30 let f = u32_min(k - e, f_upper_bound);
31
32 let (_, s) =
34 x.limbs[0].bounded_div2k_mod_q(f, f_upper_bound, one_half_mod_q.as_ref().limbs[0]);
35
36 x = q.mul_add_div2k(s, &x, f);
38 e += f;
39
40 r += 1;
41 }
42
43 x
44 }
45
46 #[inline]
48 const fn carrying_mul_add_limb(
49 mut self,
50 b: Limb,
51 addend: &Self,
52 mut carry: Limb,
53 ) -> (Self, Limb) {
54 let mut i = 0;
55 while i < LIMBS {
56 (self.limbs[i], carry) = self.limbs[i].carrying_mul_add(b, addend.limbs[i], carry);
57 i += 1;
58 }
59 (self, carry)
60 }
61}
62
63impl Limb {
64 const fn bounded_div2k_mod_q(
71 mut self,
72 k: u32,
73 k_upper_bound: u32,
74 one_half_mod_q: Self,
75 ) -> (Self, Self) {
76 let mut factor = Limb::ZERO;
77 let mut i = 0;
78 while i < k_upper_bound {
79 let execute = Choice::from_u32_lt(i, k);
80
81 let (shifted, carry) = self.shr1();
82 self = Self::select(self, shifted, execute);
83
84 let overflow = word::choice_from_msb(carry.0);
85 let add_back_q = overflow.and(execute);
86 self = self.wrapping_add(Self::select(Self::ZERO, one_half_mod_q, add_back_q));
87 factor = factor.bitxor(Self::select(Self::ZERO, Self::ONE.shl(i), add_back_q));
88 i += 1;
89 }
90
91 (self, factor)
92 }
93}
94
95impl<const LIMBS: usize> OddUint<LIMBS> {
96 const fn half_mod(q: &Self) -> Self {
98 Odd::new_unchecked(q.as_ref().shr1().wrapping_add(&Uint::ONE))
103 }
104
105 #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
107 const fn mul_add_div2k(&self, b: Limb, addend: &Uint<LIMBS>, k: u32) -> Uint<LIMBS> {
108 let (lo, hi) = self.as_ref().carrying_mul_add_limb(b, addend, Limb::ZERO);
110 lo.shr_limb_with_carry(k, hi.unbounded_shl(Limb::BITS - k))
112 .0
113 }
114}
115
#[cfg(test)]
mod tests {
    use crate::{Limb, U128, Uint};

    #[test]
    fn test_uint_bounded_div2k_mod_q() {
        let q_three = U128::from(3u64).to_odd().unwrap();
        let q_medium = U128::from(2864434311u64).to_odd().unwrap();
        let q_wide = U128::from_be_hex("0000AAAABBBB33330000AAAABBBB3333")
            .to_odd()
            .unwrap();

        // (modulus, input, k, k_upper_bound, expected result)
        let cases = [
            // k = 0 leaves the input unchanged.
            (&q_three, U128::ONE.shl_vartime(64), 0, 0, U128::ONE.shl_vartime(64)),
            // Exact division by a power of two needs no modular correction.
            (&q_three, U128::ONE.shl_vartime(64), 5, 5, U128::ONE.shl_vartime(59)),
            // 1 / 2 mod 3 = 2.
            (&q_three, U128::ONE, 1, 1, U128::from(2u64)),
            (&q_three, U128::from(8u64), 17, 17, U128::ONE),
            (&q_medium, U128::from(8u64), 17, 17, U128::from(303681787u64)),
            // Needs more than one round of `Limb::BITS - 1` halvings.
            (&q_wide, U128::MAX, 71, 71, U128::from_be_hex("00002D6F169DBBF300002D6F169DBBF3")),
            // A zero upper bound performs no halvings at all.
            (&q_wide, U128::MAX, 71, 0, U128::MAX),
            // Whichever of k / k_upper_bound is smaller wins: both divide by exactly 2^30.
            (&q_wide, U128::MAX, 71, 30, U128::from_be_hex("000071EEB6013E76000071EEB6013E76")),
            (&q_wide, U128::MAX, 30, 127, U128::from_be_hex("000071EEB6013E76000071EEB6013E76")),
        ];

        for (q, x, k, k_upper_bound, expected) in cases {
            assert_eq!(
                x.bounded_div2k_mod_q(k, k_upper_bound, q),
                expected,
                "k = {k}, k_upper_bound = {k_upper_bound}"
            );
        }
    }

    #[test]
    fn test_limb_bounded_div2k_mod_q() {
        let x = Limb::MAX.wrapping_sub(Limb::from(15u32));
        let q = Limb::from(55u32);
        let half_mod_q = q.shr1().0.wrapping_add(Limb::ONE);

        // (k, k_upper_bound, expected value, expected factor)
        let cases = [
            // No halvings requested: nothing changes.
            (0, 3, x, Limb::ZERO),
            // The low four bits of `x` are zero, so four halvings are plain shifts.
            (4, 4, x.shr(4), Limb::ZERO),
            // The fifth halving hits an odd value and adds (q + 1) / 2; bit 4 of the factor.
            (5, 5, x.shr(5).wrapping_add(half_mod_q), Limb::ONE.shl(4)),
            // The bound caps the number of halvings below k.
            (5, 4, x.shr(4), Limb::ZERO),
        ];

        for (k, k_upper_bound, expected_value, expected_factor) in cases {
            let (value, factor) = x.bounded_div2k_mod_q(k, k_upper_bound, half_mod_q);
            assert_eq!(value, expected_value);
            assert_eq!(factor, expected_factor);
        }
    }

    #[test]
    fn test_carrying_mul_add_limb() {
        let addend = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let q = U128::MAX;

        // Multiplying by zero leaves just the addend, with no carry.
        let (res, carry) = q.carrying_mul_add_limb(Limb::ZERO, &addend, Limb::ZERO);
        assert_eq!(res, addend);
        assert_eq!(carry, Limb::ZERO);

        // Multiplying by one adds `q` once; with q = MAX the sum wraps and carries one.
        let (res, carry) = q.carrying_mul_add_limb(Limb::ONE, &addend, Limb::ZERO);
        assert_eq!(res, addend.wrapping_add(&q));
        assert_eq!(carry, Limb::ONE);

        // Multiplying by MAX is cross-checked against widening_mul followed by carrying_add.
        let f = Limb::MAX;
        let (res, mac_carry) = q.carrying_mul_add_limb(f, &addend, Limb::ZERO);

        let (prod_lo, prod_hi) = q.widening_mul(&Uint::new([f; 1]));
        let (lo, carry) = prod_lo.carrying_add(&addend, Limb::ZERO);
        let (hi, carry) = prod_hi.carrying_add(&Uint::ZERO, carry);
        assert_eq!(res, lo);
        assert_eq!(mac_carry, hi.limbs[0]);
        assert_eq!(carry, Limb::ZERO);
    }
}