Skip to main content

crypto_bigint/modular/bingcd/
div_mod_2k.rs

//! Compute `x / 2^k mod q` for an odd modulus `q`.
2
3use crate::{Choice, Limb, Odd, OddUint, Uint, primitives::u32_min, word};
4
5impl<const LIMBS: usize> Uint<LIMBS> {
6    /// Compute `self / 2^k mod q`.
7    ///
8    /// Requires that `k_bound ≥ k`.
9    ///
10    /// Executes in variable time w.r.t. `k_bound` only; executes in constant time w.r.t. `k`.
11    #[inline]
12    pub(super) const fn bounded_div2k_mod_q(
13        self,
14        k: u32,
15        k_upper_bound: u32,
16        q: &Odd<Self>,
17    ) -> Self {
18        let one_half_mod_q = OddUint::half_mod(q);
19
20        // Invariant: x = self / 2^e mod q.
21        let (mut x, mut e) = (self, 0);
22
23        let max_iters_per_round = Limb::BITS - 1;
24        let rounds = k_upper_bound.div_ceil(max_iters_per_round);
25
26        let mut r = 0;
27        while r < rounds {
28            let f_upper_bound =
29                u32_min(k_upper_bound - r * max_iters_per_round, max_iters_per_round);
30            let f = u32_min(k - e, f_upper_bound);
31
32            // Find `s` s.t. qs + x = 0 mod 2^f
33            let (_, s) =
34                x.limbs[0].bounded_div2k_mod_q(f, f_upper_bound, one_half_mod_q.as_ref().limbs[0]);
35
36            // Set x <- (x + qs) / 2^f
37            x = q.mul_add_div2k(s, &x, f);
38            e += f;
39
40            r += 1;
41        }
42
43        x
44    }
45
46    /// Computes `(self * b) + addend + carry`, returning the result along with the new carry.
47    #[inline]
48    const fn carrying_mul_add_limb(
49        mut self,
50        b: Limb,
51        addend: &Self,
52        mut carry: Limb,
53    ) -> (Self, Limb) {
54        let mut i = 0;
55        while i < LIMBS {
56            (self.limbs[i], carry) = self.limbs[i].carrying_mul_add(b, addend.limbs[i], carry);
57            i += 1;
58        }
59        (self, carry)
60    }
61}
62
63impl Limb {
64    /// Compute `self / 2^t mod q`, returning the result, as well as the minimal factor `f` such
65    /// that `2^t` divides `self + q·f`.
66    ///
67    /// Here, `q := 2·one_half_mod_q + 1` is assumed odd and `t := min(k, k_upper_bound)`.
68    ///
69    /// Executes in variable time w.r.t. `k_upper_bound` only; executes in constant time w.r.t `k`.
70    const fn bounded_div2k_mod_q(
71        mut self,
72        k: u32,
73        k_upper_bound: u32,
74        one_half_mod_q: Self,
75    ) -> (Self, Self) {
76        let mut factor = Limb::ZERO;
77        let mut i = 0;
78        while i < k_upper_bound {
79            let execute = Choice::from_u32_lt(i, k);
80
81            let (shifted, carry) = self.shr1();
82            self = Self::select(self, shifted, execute);
83
84            let overflow = word::choice_from_msb(carry.0);
85            let add_back_q = overflow.and(execute);
86            self = self.wrapping_add(Self::select(Self::ZERO, one_half_mod_q, add_back_q));
87            factor = factor.bitxor(Self::select(Self::ZERO, Self::ONE.shl(i), add_back_q));
88            i += 1;
89        }
90
91        (self, factor)
92    }
93}
94
95impl<const LIMBS: usize> OddUint<LIMBS> {
96    /// Compute `1/2 mod q`.
97    const fn half_mod(q: &Self) -> Self {
98        //        1  / 2      mod q
99        // = (q + 1) / 2      mod q
100        // = (q - 1) / 2  + 1 mod q
101        // = floor(q / 2) + 1 mod q, since q is odd.
102        Odd::new_unchecked(q.as_ref().shr1().wrapping_add(&Uint::ONE))
103    }
104
105    /// Compute `((self * b) + addend) / 2^k`
106    #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
107    const fn mul_add_div2k(&self, b: Limb, addend: &Uint<LIMBS>, k: u32) -> Uint<LIMBS> {
108        // Compute `self * b + addend`
109        let (lo, hi) = self.as_ref().carrying_mul_add_limb(b, addend, Limb::ZERO);
110        // Divide by 2^k
111        lo.shr_limb_with_carry(k, hi.unbounded_shl(Limb::BITS - k))
112            .0
113    }
114}
115
#[cfg(test)]
mod tests {
    use crate::{Limb, OddUint, U128, Uint};

    #[test]
    fn test_uint_bounded_div2k_mod_q() {
        let q = U128::from(3u64).to_odd().unwrap();

        // Do nothing
        let res = U128::ONE.shl_vartime(64).bounded_div2k_mod_q(0, 0, &q);
        assert_eq!(res, U128::ONE.shl_vartime(64));

        // Simply shift out 5 factors
        let res = U128::ONE.shl_vartime(64).bounded_div2k_mod_q(5, 5, &q);
        assert_eq!(res, U128::ONE.shl_vartime(59));

        // Add in one factor of q
        let res = U128::ONE.bounded_div2k_mod_q(1, 1, &q);
        assert_eq!(res, U128::from(2u64));

        // Add in many factors of q
        let res = U128::from(8u64).bounded_div2k_mod_q(17, 17, &q);
        assert_eq!(res, U128::ONE);

        // Larger q
        let q = U128::from(2864434311u64).to_odd().unwrap();
        let res = U128::from(8u64).bounded_div2k_mod_q(17, 17, &q);
        assert_eq!(res, U128::from(303681787u64));

        // Shift greater than Limb::BITS
        let q = U128::from_be_hex("0000AAAABBBB33330000AAAABBBB3333")
            .to_odd()
            .unwrap();
        let res = U128::MAX.bounded_div2k_mod_q(71, 71, &q);
        assert_eq!(res, U128::from_be_hex("00002D6F169DBBF300002D6F169DBBF3"));

        // Have k_bound restrict the number of shifts to 0
        let res = U128::MAX.bounded_div2k_mod_q(71, 0, &q);
        assert_eq!(res, U128::MAX);

        // Have k_bound < k
        let res = U128::MAX.bounded_div2k_mod_q(71, 30, &q);
        assert_eq!(res, U128::from_be_hex("000071EEB6013E76000071EEB6013E76"));

        // Have k_bound >> k
        let res = U128::MAX.bounded_div2k_mod_q(30, 127, &q);
        assert_eq!(res, U128::from_be_hex("000071EEB6013E76000071EEB6013E76"));
    }

    #[test]
    fn test_limb_bounded_div2k_mod_q() {
        let x = Limb::MAX.wrapping_sub(Limb::from(15u32));
        let q = Limb::from(55u32);
        let half_mod_q = q.shr1().0.wrapping_add(Limb::ONE);

        // Do nothing
        let (res, factor) = x.bounded_div2k_mod_q(0, 3, half_mod_q);
        assert_eq!(res, x);
        assert_eq!(factor, Limb::ZERO);

        // Divide by 2^4 without requiring the addition of q
        let (res, factor) = x.bounded_div2k_mod_q(4, 4, half_mod_q);
        assert_eq!(res, x.shr(4));
        assert_eq!(factor, Limb::ZERO);

        // Divide by 2^5, requiring a single addition of q * 2^4
        let (res, factor) = x.bounded_div2k_mod_q(5, 5, half_mod_q);
        assert_eq!(res, x.shr(5).wrapping_add(half_mod_q));
        assert_eq!(factor, Limb::ONE.shl(4));

        // Execute at most k_bound iterations
        let (res, factor) = x.bounded_div2k_mod_q(5, 4, half_mod_q);
        assert_eq!(res, x.shr(4));
        assert_eq!(factor, Limb::ZERO);
    }

    #[test]
    fn test_carrying_mul_add_limb() {
        // Do nothing
        let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let q = U128::MAX;
        let f = Limb::ZERO;
        let (res, carry) = q.carrying_mul_add_limb(f, &x, Limb::ZERO);
        assert_eq!(res, x);
        assert_eq!(carry, Limb::ZERO);

        // f = 1
        let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let q = U128::MAX;
        let f = Limb::ONE;
        let (res, carry) = q.carrying_mul_add_limb(f, &x, Limb::ZERO);
        assert_eq!(res, x.wrapping_add(&q));
        assert_eq!(carry, Limb::ONE);

        // f = max
        let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let q = U128::MAX;
        let f = Limb::MAX;
        let (res, mac_carry) = q.carrying_mul_add_limb(f, &x, Limb::ZERO);

        let (qf_lo, qf_hi) = q.widening_mul(&Uint::new([f; 1]));
        let (lo, carry) = qf_lo.carrying_add(&x, Limb::ZERO);
        let (hi, carry) = qf_hi.carrying_add(&Uint::ZERO, carry);
        assert_eq!(res, lo);
        assert_eq!(mac_carry, hi.limbs[0]);
        assert_eq!(carry, Limb::ZERO);
    }

    #[test]
    fn test_half_mod() {
        // 1/2 mod q == (q + 1) / 2; doubling it yields 1 mod q.
        let q = U128::from(3u64).to_odd().unwrap();
        let half_odd = OddUint::half_mod(&q);
        let half: &U128 = half_odd.as_ref();
        assert_eq!(half, &U128::from(2u64));

        let q = U128::from(55u64).to_odd().unwrap();
        let half_odd = OddUint::half_mod(&q);
        let half: &U128 = half_odd.as_ref();
        assert_eq!(half, &U128::from(28u64));

        let q = U128::from(2864434311u64).to_odd().unwrap();
        let half_odd = OddUint::half_mod(&q);
        let half: &U128 = half_odd.as_ref();
        assert_eq!(half, &U128::from(1432217156u64));
    }

    #[test]
    fn test_mul_add_div2k() {
        let q = U128::from(3u64).to_odd().unwrap();

        // 3 * 5 + 1 = 16
        let res = q.mul_add_div2k(Limb::from(5u32), &U128::ONE, 0);
        assert_eq!(res, U128::from(16u64));

        // 16 / 2^4 = 1
        let res = q.mul_add_div2k(Limb::from(5u32), &U128::ONE, 4);
        assert_eq!(res, U128::ONE);

        // MAX * 2 = 2^129 - 2 overflows into the high limb, which must be shifted back in:
        // (2^129 - 2) / 2 = 2^128 - 1 = MAX.
        let q = U128::MAX.to_odd().unwrap();
        let res = q.mul_add_div2k(Limb::from(2u32), &U128::ZERO, 1);
        assert_eq!(res, U128::MAX);
    }
}