crypto_bigint/modular/bingcd/
div_mod_2k.rs1use crate::{Choice, Limb, Odd, OddUint, Uint, primitives::u32_min, word};
4
5impl<const LIMBS: usize> Uint<LIMBS> {
6 #[inline]
12 pub(super) const fn bounded_div2k_mod_q(
13 self,
14 k: u32,
15 k_upper_bound: u32,
16 q: &Odd<Self>,
17 ) -> Self {
18 let one_half_mod_q = OddUint::half_mod(q).0;
19
20 let (mut x, mut e) = (self, 0);
22
23 let max_iters_per_round = Limb::BITS - 1;
24 let rounds = k_upper_bound.div_ceil(max_iters_per_round);
25
26 let mut r = 0;
27 while r < rounds {
28 let f_upper_bound =
29 u32_min(k_upper_bound - r * max_iters_per_round, max_iters_per_round);
30 let f = u32_min(k - e, f_upper_bound);
31
32 let (_, s) = x.limbs[0].bounded_div2k_mod_q(f, f_upper_bound, one_half_mod_q.limbs[0]);
34
35 x = q.mul_add_div2k(s, &x, f);
37 e += f;
38
39 r += 1;
40 }
41
42 x
43 }
44
45 #[inline]
47 const fn carrying_mul_add_limb(
48 mut self,
49 b: Limb,
50 addend: &Self,
51 mut carry: Limb,
52 ) -> (Self, Limb) {
53 let mut i = 0;
54 while i < LIMBS {
55 (self.limbs[i], carry) = self.limbs[i].carrying_mul_add(b, addend.limbs[i], carry);
56 i += 1;
57 }
58 (self, carry)
59 }
60}
61
62impl Limb {
63 const fn bounded_div2k_mod_q(
70 mut self,
71 k: u32,
72 k_upper_bound: u32,
73 one_half_mod_q: Self,
74 ) -> (Self, Self) {
75 let mut factor = Limb::ZERO;
76 let mut i = 0;
77 while i < k_upper_bound {
78 let execute = Choice::from_u32_lt(i, k);
79
80 let (shifted, carry) = self.shr1();
81 self = Self::select(self, shifted, execute);
82
83 let overflow = word::choice_from_msb(carry.0);
84 let add_back_q = overflow.and(execute);
85 self = self.wrapping_add(Self::select(Self::ZERO, one_half_mod_q, add_back_q));
86 factor = factor.bitxor(Self::select(Self::ZERO, Self::ONE.shl(i), add_back_q));
87 i += 1;
88 }
89
90 (self, factor)
91 }
92}
93
94impl<const LIMBS: usize> OddUint<LIMBS> {
95 const fn half_mod(q: &Self) -> Self {
97 Odd(q.as_ref().shr1().wrapping_add(&Uint::ONE))
102 }
103
104 #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
106 const fn mul_add_div2k(&self, b: Limb, addend: &Uint<LIMBS>, k: u32) -> Uint<LIMBS> {
107 let (lo, hi) = self.as_ref().carrying_mul_add_limb(b, addend, Limb::ZERO);
109 lo.shr_limb_with_carry(k, hi.unbounded_shl(Limb::BITS - k))
111 .0
112 }
113}
114
#[cfg(test)]
mod tests {
    use crate::{Limb, U128, Uint};

    #[test]
    fn test_uint_bounded_div2k_mod_q() {
        // Small modulus: 2 ≡ -1 (mod 3), so 2^-k ≡ (-1)^k (mod 3).
        let three = U128::from(3u64).to_odd().unwrap();
        let two_pow_64 = U128::ONE.shl_vartime(64);

        // Dividing by 2^0 is the identity.
        assert_eq!(two_pow_64.bounded_div2k_mod_q(0, 0, &three), two_pow_64);
        // When 2^k divides the input exactly, the result is a plain right shift.
        assert_eq!(
            two_pow_64.bounded_div2k_mod_q(5, 5, &three),
            U128::ONE.shl_vartime(59)
        );
        // 1 / 2 ≡ 2 (mod 3).
        assert_eq!(U128::ONE.bounded_div2k_mod_q(1, 1, &three), U128::from(2u64));
        // 8 / 2^17 = 2^-14 ≡ 1 (mod 3).
        assert_eq!(U128::from(8u64).bounded_div2k_mod_q(17, 17, &three), U128::ONE);

        // A larger single-limb modulus.
        let q_small = U128::from(2864434311u64).to_odd().unwrap();
        assert_eq!(
            U128::from(8u64).bounded_div2k_mod_q(17, 17, &q_small),
            U128::from(303681787u64)
        );

        // Two-limb modulus. k = 71 needs more than one round of at most `Limb::BITS - 1` bits.
        let q_wide = U128::from_be_hex("0000AAAABBBB33330000AAAABBBB3333")
            .to_odd()
            .unwrap();
        assert_eq!(
            U128::MAX.bounded_div2k_mod_q(71, 71, &q_wide),
            U128::from_be_hex("00002D6F169DBBF300002D6F169DBBF3")
        );

        // The effective exponent is min(k, k_upper_bound). A zero bound divides by nothing...
        assert_eq!(U128::MAX.bounded_div2k_mod_q(71, 0, &q_wide), U128::MAX);
        // ...and a bound of 30 divides by exactly 2^30, matching k = 30 with a generous bound.
        let max_over_2_pow_30 = U128::from_be_hex("000071EEB6013E76000071EEB6013E76");
        assert_eq!(U128::MAX.bounded_div2k_mod_q(71, 30, &q_wide), max_over_2_pow_30);
        assert_eq!(U128::MAX.bounded_div2k_mod_q(30, 127, &q_wide), max_over_2_pow_30);
    }

    #[test]
    fn test_limb_bounded_div2k_mod_q() {
        // The low four bits are zero and bit 4 is set.
        let x = Limb::MAX.wrapping_sub(Limb::from(15u32));
        let q = Limb::from(55u32);
        let half = q.shr1().0.wrapping_add(Limb::ONE);

        // k = 0: the bound only sets the number of no-op iterations.
        assert_eq!(x.bounded_div2k_mod_q(0, 3, half), (x, Limb::ZERO));
        // The four trailing zero bits are shifted out without adding q.
        assert_eq!(x.bounded_div2k_mod_q(4, 4, half), (x.shr(4), Limb::ZERO));
        // The fifth step sees an odd value, so (q + 1) / 2 is added and factor bit 4 is set.
        assert_eq!(
            x.bounded_div2k_mod_q(5, 5, half),
            (x.shr(5).wrapping_add(half), Limb::ONE.shl(4))
        );
        // k is clamped to the iteration bound.
        assert_eq!(x.bounded_div2k_mod_q(5, 4, half), (x.shr(4), Limb::ZERO));
    }

    #[test]
    fn test_carrying_mul_add_limb() {
        let addend = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let q = U128::MAX;

        // q * 0 + addend = addend, with no carry out.
        assert_eq!(
            q.carrying_mul_add_limb(Limb::ZERO, &addend, Limb::ZERO),
            (addend, Limb::ZERO)
        );

        // q * 1 + addend wraps around 128 bits and carries out a one.
        assert_eq!(
            q.carrying_mul_add_limb(Limb::ONE, &addend, Limb::ZERO),
            (addend.wrapping_add(&q), Limb::ONE)
        );

        // Cross-check against a widening multiply followed by a carrying add.
        let factor = Limb::MAX;
        let (res, mac_carry) = q.carrying_mul_add_limb(factor, &addend, Limb::ZERO);
        let (prod_lo, prod_hi) = q.widening_mul(&Uint::new([factor; 1]));
        let (sum_lo, carry) = prod_lo.carrying_add(&addend, Limb::ZERO);
        let (sum_hi, carry) = prod_hi.carrying_add(&Uint::ZERO, carry);
        assert_eq!(res, sum_lo);
        assert_eq!(mac_carry, sum_hi.limbs[0]);
        assert_eq!(carry, Limb::ZERO);
    }
}