Skip to main content

crypto_bigint/modular/bingcd/
div_mod_2k.rs

//! Compute `x / 2^k mod q` for an odd modulus `q`.
2
3use crate::{Choice, Limb, Odd, OddUint, Uint, primitives::u32_min, word};
4
5impl<const LIMBS: usize> Uint<LIMBS> {
6    /// Compute `self / 2^k mod q`.
7    ///
8    /// Requires that `k_bound ≥ k`.
9    ///
10    /// Executes in variable time w.r.t. `k_bound` only; executes in constant time w.r.t. `k`.
11    #[inline]
12    pub(super) const fn bounded_div2k_mod_q(
13        self,
14        k: u32,
15        k_upper_bound: u32,
16        q: &Odd<Self>,
17    ) -> Self {
18        let one_half_mod_q = OddUint::half_mod(q).0;
19
20        // Invariant: x = self / 2^e mod q.
21        let (mut x, mut e) = (self, 0);
22
23        let max_iters_per_round = Limb::BITS - 1;
24        let rounds = k_upper_bound.div_ceil(max_iters_per_round);
25
26        let mut r = 0;
27        while r < rounds {
28            let f_upper_bound =
29                u32_min(k_upper_bound - r * max_iters_per_round, max_iters_per_round);
30            let f = u32_min(k - e, f_upper_bound);
31
32            // Find `s` s.t. qs + x = 0 mod 2^f
33            let (_, s) = x.limbs[0].bounded_div2k_mod_q(f, f_upper_bound, one_half_mod_q.limbs[0]);
34
35            // Set x <- (x + qs) / 2^f
36            x = q.mul_add_div2k(s, &x, f);
37            e += f;
38
39            r += 1;
40        }
41
42        x
43    }
44
45    /// Computes `(self * b) + addend + carry`, returning the result along with the new carry.
46    #[inline]
47    const fn carrying_mul_add_limb(
48        mut self,
49        b: Limb,
50        addend: &Self,
51        mut carry: Limb,
52    ) -> (Self, Limb) {
53        let mut i = 0;
54        while i < LIMBS {
55            (self.limbs[i], carry) = self.limbs[i].carrying_mul_add(b, addend.limbs[i], carry);
56            i += 1;
57        }
58        (self, carry)
59    }
60}
61
62impl Limb {
63    /// Compute `self / 2^t mod q`, returning the result, as well as the minimal factor `f` such
64    /// that `2^t` divides `self + q·f`.
65    ///
66    /// Here, `q := 2·one_half_mod_q + 1` is assumed odd and `t := min(k, k_upper_bound)`.
67    ///
68    /// Executes in variable time w.r.t. `k_upper_bound` only; executes in constant time w.r.t `k`.
69    const fn bounded_div2k_mod_q(
70        mut self,
71        k: u32,
72        k_upper_bound: u32,
73        one_half_mod_q: Self,
74    ) -> (Self, Self) {
75        let mut factor = Limb::ZERO;
76        let mut i = 0;
77        while i < k_upper_bound {
78            let execute = Choice::from_u32_lt(i, k);
79
80            let (shifted, carry) = self.shr1();
81            self = Self::select(self, shifted, execute);
82
83            let overflow = word::choice_from_msb(carry.0);
84            let add_back_q = overflow.and(execute);
85            self = self.wrapping_add(Self::select(Self::ZERO, one_half_mod_q, add_back_q));
86            factor = factor.bitxor(Self::select(Self::ZERO, Self::ONE.shl(i), add_back_q));
87            i += 1;
88        }
89
90        (self, factor)
91    }
92}
93
94impl<const LIMBS: usize> OddUint<LIMBS> {
95    /// Compute `1/2 mod q`.
96    const fn half_mod(q: &Self) -> Self {
97        //        1  / 2      mod q
98        // = (q + 1) / 2      mod q
99        // = (q - 1) / 2  + 1 mod q
100        // = floor(q / 2) + 1 mod q, since q is odd.
101        Odd(q.as_ref().shr1().wrapping_add(&Uint::ONE))
102    }
103
104    /// Compute `((self * b) + addend) / 2^k`
105    #[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
106    const fn mul_add_div2k(&self, b: Limb, addend: &Uint<LIMBS>, k: u32) -> Uint<LIMBS> {
107        // Compute `self * b + addend`
108        let (lo, hi) = self.as_ref().carrying_mul_add_limb(b, addend, Limb::ZERO);
109        // Divide by 2^k
110        lo.shr_limb_with_carry(k, hi.unbounded_shl(Limb::BITS - k))
111            .0
112    }
113}
114
#[cfg(test)]
mod tests {
    use crate::{Limb, OddUint, U128, Uint};

    #[test]
    fn test_uint_bounded_div2k_mod_q() {
        let q = U128::from(3u64).to_odd().unwrap();

        // Do nothing
        let res = U128::ONE.shl_vartime(64).bounded_div2k_mod_q(0, 0, &q);
        assert_eq!(res, U128::ONE.shl_vartime(64));

        // Simply shift out 5 factors
        let res = U128::ONE.shl_vartime(64).bounded_div2k_mod_q(5, 5, &q);
        assert_eq!(res, U128::ONE.shl_vartime(59));

        // Add in one factor of q
        let res = U128::ONE.bounded_div2k_mod_q(1, 1, &q);
        assert_eq!(res, U128::from(2u64));

        // Add in many factors of q
        let res = U128::from(8u64).bounded_div2k_mod_q(17, 17, &q);
        assert_eq!(res, U128::ONE);

        // Larger q
        let q = U128::from(2864434311u64).to_odd().unwrap();
        let res = U128::from(8u64).bounded_div2k_mod_q(17, 17, &q);
        assert_eq!(res, U128::from(303681787u64));

        // Shift greater than Limb::BITS
        let q = U128::from_be_hex("0000AAAABBBB33330000AAAABBBB3333")
            .to_odd()
            .unwrap();
        let res = U128::MAX.bounded_div2k_mod_q(71, 71, &q);
        assert_eq!(res, U128::from_be_hex("00002D6F169DBBF300002D6F169DBBF3"));

        // Have k_bound restrict the number of shifts to 0
        let res = U128::MAX.bounded_div2k_mod_q(71, 0, &q);
        assert_eq!(res, U128::MAX);

        // Have k_bound < k
        let res = U128::MAX.bounded_div2k_mod_q(71, 30, &q);
        assert_eq!(res, U128::from_be_hex("000071EEB6013E76000071EEB6013E76"));

        // Have k_bound >> k
        let res = U128::MAX.bounded_div2k_mod_q(30, 127, &q);
        assert_eq!(res, U128::from_be_hex("000071EEB6013E76000071EEB6013E76"));
    }

    #[test]
    fn test_limb_bounded_div2k_mod_q() {
        let x = Limb::MAX.wrapping_sub(Limb::from(15u32));
        let q = Limb::from(55u32);
        let half_mod_q = q.shr1().0.wrapping_add(Limb::ONE);

        // Do nothing
        let (res, factor) = x.bounded_div2k_mod_q(0, 3, half_mod_q);
        assert_eq!(res, x);
        assert_eq!(factor, Limb::ZERO);

        // Divide by 2^4 without requiring the addition of q
        let (res, factor) = x.bounded_div2k_mod_q(4, 4, half_mod_q);
        assert_eq!(res, x.shr(4));
        assert_eq!(factor, Limb::ZERO);

        // Divide by 2^5, requiring a single addition of q * 2^4
        let (res, factor) = x.bounded_div2k_mod_q(5, 5, half_mod_q);
        assert_eq!(res, x.shr(5).wrapping_add(half_mod_q));
        assert_eq!(factor, Limb::ONE.shl(4));

        // Execute at most k_bound iterations
        let (res, factor) = x.bounded_div2k_mod_q(5, 4, half_mod_q);
        assert_eq!(res, x.shr(4));
        assert_eq!(factor, Limb::ZERO);
    }

    #[test]
    fn test_half_mod() {
        // 1/2 mod 3 = 2: the value is even despite the `Odd` wrapper.
        let q = U128::from(3u64).to_odd().unwrap();
        assert_eq!(OddUint::half_mod(&q).0, U128::from(2u64));

        // (2864434311 + 1) / 2
        let q = U128::from(2864434311u64).to_odd().unwrap();
        assert_eq!(OddUint::half_mod(&q).0, U128::from(1432217156u64));

        // 1/2 mod (2^128 - 1) = 2^127
        let q = U128::MAX.to_odd().unwrap();
        assert_eq!(OddUint::half_mod(&q).0, U128::ONE.shl_vartime(127));
    }

    #[test]
    fn test_mul_add_div2k() {
        // (3 * 5 + 1) / 2^4 = 1
        let q = U128::from(3u64).to_odd().unwrap();
        let res = q.mul_add_div2k(Limb::from(5u32), &U128::ONE, 4);
        assert_eq!(res, U128::ONE);

        // ((2^128 - 1) * 1 + 1) / 2 = 2^127; the sum carries into the extra high limb.
        let q = U128::MAX.to_odd().unwrap();
        let res = q.mul_add_div2k(Limb::ONE, &U128::ONE, 1);
        assert_eq!(res, U128::ONE.shl_vartime(127));
    }

    #[test]
    fn test_carrying_mul_add_limb() {
        // Do nothing
        let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let q = U128::MAX;
        let f = Limb::ZERO;
        let (res, carry) = q.carrying_mul_add_limb(f, &x, Limb::ZERO);
        assert_eq!(res, x);
        assert_eq!(carry, Limb::ZERO);

        // f = 1
        let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let q = U128::MAX;
        let f = Limb::ONE;
        let (res, carry) = q.carrying_mul_add_limb(f, &x, Limb::ZERO);
        assert_eq!(res, x.wrapping_add(&q));
        assert_eq!(carry, Limb::ONE);

        // f = max
        let x = U128::from_be_hex("ABCDEF98765432100123456789FEDCBA");
        let q = U128::MAX;
        let f = Limb::MAX;
        let (res, mac_carry) = q.carrying_mul_add_limb(f, &x, Limb::ZERO);

        let (qf_lo, qf_hi) = q.widening_mul(&Uint::new([f; 1]));
        let (lo, carry) = qf_lo.carrying_add(&x, Limb::ZERO);
        let (hi, carry) = qf_hi.carrying_add(&Uint::ZERO, carry);
        assert_eq!(res, lo);
        assert_eq!(mac_carry, hi.limbs[0]);
        assert_eq!(carry, Limb::ZERO);
    }
}