Skip to main content

crypto_bigint/uint/boxed/
sqrt.rs

1//! [`BoxedUint`] square root operations.
2
3use crate::{
4    BitOps, BoxedUint, CheckedSquareRoot, ConcatenatingSquare, CtAssign, CtEq, CtGt, CtOption,
5    FloorSquareRoot, Limb,
6};
7use core::mem;
8
9impl BoxedUint {
10    /// Computes `floor(√(self))` in constant time.
11    ///
12    /// Callers can check if `self` is a square by squaring the result.
13    #[deprecated(since = "0.7.0", note = "please use `floor_sqrt` instead")]
14    #[must_use]
15    pub fn sqrt(&self) -> Self {
16        self.floor_sqrt()
17    }
18
19    /// Computes √(`self`) in constant time.
20    ///
21    /// Callers can check if `self` is a square by squaring the result.
22    #[must_use]
23    pub fn floor_sqrt(&self) -> Self {
24        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
25        //
26        // See Hast, "Note on computation of integer square roots"
27        // for the proof of the sufficiency of the bound on iterations.
28        // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
29
30        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
31        // Will not overflow since `b <= BITS`.
32        let mut x = Self::one_with_precision(self.bits_precision());
33        x.unbounded_shl_assign_vartime((self.bits() + 1) >> 1); // ≥ √(`self`)
34
35        let mut nz_x = x.clone();
36        let mut quo = Self::zero_with_precision(self.bits_precision());
37        let mut rem = Self::zero_with_precision(self.bits_precision());
38        let mut i = 0;
39
40        // Repeat enough times to guarantee result has stabilized.
41        // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
42        while i < self.log2_bits() + 2 {
43            let x_nonzero = x.is_nonzero();
44            nz_x.ct_assign(&x, x_nonzero);
45
46            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
47            quo.limbs.copy_from_slice(&self.limbs);
48            rem.limbs.copy_from_slice(&nz_x.limbs);
49            quo.as_mut_uint_ref().div_rem(rem.as_mut_uint_ref());
50            x.conditional_carrying_add_assign(&quo, x_nonzero);
51            x.shr1_assign();
52
53            i += 1;
54        }
55
56        // At this point `x_prev == x_{n}` and `x == x_{n+1}`
57        // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`.
58        // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
59        x.ct_assign(&nz_x, x.ct_gt(&nz_x));
60        x
61    }
62
63    /// Computes `floor(√(self))`.
64    ///
65    /// Callers can check if `self` is a square by squaring the result.
66    ///
67    /// Variable time with respect to `self`.
68    #[deprecated(since = "0.7.0", note = "please use `floor_sqrt_vartime` instead")]
69    #[must_use]
70    pub fn sqrt_vartime(&self) -> Self {
71        self.floor_sqrt_vartime()
72    }
73
74    /// Computes √(`self`).
75    ///
76    /// Callers can check if `self` is a square by squaring the result.
77    ///
78    /// Variable time with respect to `self`.
79    #[must_use]
80    pub fn floor_sqrt_vartime(&self) -> Self {
81        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
82
83        if self.is_zero_vartime() {
84            return Self::zero_with_precision(self.bits_precision());
85        }
86
87        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
88        // Will not overflow since `b <= BITS`.
89        // The initial value of `x` is always greater than zero.
90        let mut x = Self::one_with_precision(self.bits_precision());
91        x.unbounded_shl_assign_vartime((self.bits() + 1) >> 1); // ≥ √(`self`)
92
93        let mut quo = Self::zero_with_precision(self.bits_precision());
94        let mut rem = Self::zero_with_precision(self.bits_precision());
95
96        loop {
97            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
98            quo.limbs.copy_from_slice(&self.limbs);
99            rem.limbs.copy_from_slice(&x.limbs);
100            quo.as_mut_uint_ref().div_rem_vartime(rem.as_mut_uint_ref());
101            quo.carrying_add_assign(&x, Limb::ZERO);
102            quo.shr1_assign();
103
104            // If `quo` is the same as `x` or greater, we reached convergence
105            // (`x` is guaranteed to either go down or oscillate between
106            // `sqrt(self)` and `sqrt(self) + 1`)
107            if !x.cmp_vartime(&quo).is_gt() {
108                break;
109            }
110            x.limbs.copy_from_slice(&quo.limbs);
111            if x.is_zero_vartime() {
112                break;
113            }
114        }
115
116        x
117    }
118
119    /// Wrapped sqrt is just `floor(√(self))`.
120    /// There’s no way wrapping could ever happen.
121    /// This function exists so that all operations are accounted for in the wrapping operations.
122    #[must_use]
123    pub fn wrapping_sqrt(&self) -> Self {
124        self.floor_sqrt()
125    }
126
127    /// Wrapped sqrt is just `floor(√(self))`.
128    /// There’s no way wrapping could ever happen.
129    /// This function exists so that all operations are accounted for in the wrapping operations.
130    ///
131    /// Variable time with respect to `self`.
132    #[must_use]
133    pub fn wrapping_sqrt_vartime(&self) -> Self {
134        self.floor_sqrt_vartime()
135    }
136
137    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
138    /// only if the square root is exact.
139    #[must_use]
140    pub fn checked_sqrt(&self) -> CtOption<Self> {
141        let r = self.floor_sqrt();
142        let s = r.wrapping_square();
143        CtOption::new(r, self.ct_eq(&s))
144    }
145
146    /// Perform checked sqrt, returning an [`Option`] which `is_some`
147    /// only if the square root is exact.
148    ///
149    /// Variable time with respect to `self`.
150    #[must_use]
151    pub fn checked_sqrt_vartime(&self) -> Option<Self> {
152        let r = self.floor_sqrt_vartime();
153        let s = r.wrapping_square();
154        if self.cmp_vartime(&s).is_eq() {
155            Some(r)
156        } else {
157            None
158        }
159    }
160
161    /// Assigns `floor(√(self))` to `self`, and returns a [`bool`]
162    /// indicating whether the square root is exact.
163    ///
164    /// Variable time with respect to `self`.
165    pub fn floor_sqrt_assign_vartime(&mut self) -> bool {
166        // TODO(tarcieri): more optimized implementation
167        let mut ret = self.floor_sqrt_vartime();
168        mem::swap(&mut ret, self);
169        self.concatenating_square() == ret
170    }
171}
172
173impl CheckedSquareRoot for BoxedUint {
174    type Output = Self;
175
176    fn checked_sqrt(&self) -> CtOption<Self::Output> {
177        self.checked_sqrt()
178    }
179
180    fn checked_sqrt_vartime(&self) -> Option<Self::Output> {
181        self.checked_sqrt_vartime()
182    }
183}
184
185impl FloorSquareRoot for BoxedUint {
186    fn floor_sqrt(&self) -> Self {
187        self.floor_sqrt()
188    }
189
190    fn floor_sqrt_vartime(&self) -> Self {
191        self.floor_sqrt_vartime()
192    }
193}
194
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
    use crate::{BoxedUint, Limb};

    #[cfg(feature = "rand_core")]
    use {
        crate::RandomBits,
        chacha20::ChaCha8Rng,
        rand_core::{Rng, SeedableRng},
    };

    #[test]
    fn edge() {
        assert_eq!(
            BoxedUint::zero_with_precision(256).floor_sqrt(),
            BoxedUint::zero_with_precision(256)
        );
        assert_eq!(
            BoxedUint::one_with_precision(256).floor_sqrt(),
            BoxedUint::one_with_precision(256)
        );
        let mut half = BoxedUint::zero_with_precision(256);
        for i in 0..half.limbs.len() / 2 {
            half.limbs[i] = Limb::MAX;
        }
        let u256_max = !BoxedUint::zero_with_precision(256);
        assert_eq!(u256_max.floor_sqrt(), half);

        // Test edge cases that use up the maximum number of iterations.

        // `x = (r + 1)^2 - 583`, where `r` is the expected square root.
        assert_eq!(
            BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192)
                .unwrap()
                .floor_sqrt(),
            BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192)
                .unwrap(),
        );
        assert_eq!(
            BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192)
                .unwrap()
                .floor_sqrt_vartime(),
            BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192)
                .unwrap()
        );

        // `x = (r + 1)^2 - 205`, where `r` is the expected square root.
        assert_eq!(
            BoxedUint::from_be_hex(
                "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597",
                256
            )
            .unwrap()
            .floor_sqrt(),
            BoxedUint::from_be_hex(
                "000000000000000000000000000000008b3956339e8315cff66eb6107b610075",
                256
            )
            .unwrap()
        );
        assert_eq!(
            BoxedUint::from_be_hex(
                "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597",
                256
            )
            .unwrap()
            .floor_sqrt_vartime(),
            BoxedUint::from_be_hex(
                "000000000000000000000000000000008b3956339e8315cff66eb6107b610075",
                256
            )
            .unwrap()
        );
    }

    #[test]
    fn edge_vartime() {
        assert_eq!(
            BoxedUint::zero_with_precision(256).floor_sqrt_vartime(),
            BoxedUint::zero_with_precision(256)
        );
        assert_eq!(
            BoxedUint::one_with_precision(256).floor_sqrt_vartime(),
            BoxedUint::one_with_precision(256)
        );
        let mut half = BoxedUint::zero_with_precision(256);
        for i in 0..half.limbs.len() / 2 {
            half.limbs[i] = Limb::MAX;
        }
        let u256_max = !BoxedUint::zero_with_precision(256);
        assert_eq!(u256_max.floor_sqrt_vartime(), half);
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for (a, e) in &tests {
            let l = BoxedUint::from(*a);
            let r = BoxedUint::from(*e);
            assert_eq!(l.floor_sqrt(), r);
            assert_eq!(l.floor_sqrt_vartime(), r);
            assert!(l.checked_sqrt().is_some().to_bool());
            assert!(l.checked_sqrt_vartime().is_some());
        }
    }

    #[test]
    fn nonsquares() {
        assert_eq!(BoxedUint::from(2u8).floor_sqrt(), BoxedUint::from(1u8));
        assert!(!BoxedUint::from(2u8).checked_sqrt().is_some().to_bool());
        assert_eq!(BoxedUint::from(3u8).floor_sqrt(), BoxedUint::from(1u8));
        assert!(!BoxedUint::from(3u8).checked_sqrt().is_some().to_bool());
        assert_eq!(BoxedUint::from(5u8).floor_sqrt(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(6u8).floor_sqrt(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(7u8).floor_sqrt(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(8u8).floor_sqrt(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(10u8).floor_sqrt(), BoxedUint::from(3u8));
    }

    #[test]
    fn nonsquares_vartime() {
        assert_eq!(
            BoxedUint::from(2u8).floor_sqrt_vartime(),
            BoxedUint::from(1u8)
        );
        assert!(BoxedUint::from(2u8).checked_sqrt_vartime().is_none());
        assert_eq!(
            BoxedUint::from(3u8).floor_sqrt_vartime(),
            BoxedUint::from(1u8)
        );
        assert!(BoxedUint::from(3u8).checked_sqrt_vartime().is_none());
        assert_eq!(
            BoxedUint::from(5u8).floor_sqrt_vartime(),
            BoxedUint::from(2u8)
        );
        assert_eq!(
            BoxedUint::from(6u8).floor_sqrt_vartime(),
            BoxedUint::from(2u8)
        );
        assert_eq!(
            BoxedUint::from(7u8).floor_sqrt_vartime(),
            BoxedUint::from(2u8)
        );
        assert_eq!(
            BoxedUint::from(8u8).floor_sqrt_vartime(),
            BoxedUint::from(2u8)
        );
        assert_eq!(
            BoxedUint::from(10u8).floor_sqrt_vartime(),
            BoxedUint::from(3u8)
        );
    }

    /// `floor_sqrt_assign_vartime` must store the floor root in place and report
    /// exactness consistently with the `checked_sqrt*` methods.
    #[test]
    fn assign_vartime() {
        for (value, root, exact) in [
            (0u8, 0u8, true),
            (1, 1, true),
            (2, 1, false),
            (10, 3, false),
            (16, 4, true),
            (255, 15, false),
        ] {
            let original = BoxedUint::from(value);
            assert_eq!(original.checked_sqrt_vartime().is_some(), exact);
            assert_eq!(original.checked_sqrt().is_some().to_bool(), exact);

            let mut x = original.clone();
            assert_eq!(x.floor_sqrt_assign_vartime(), exact, "value = {value}");
            assert_eq!(x, BoxedUint::from(root), "value = {value}");
        }
    }

    /// The wrapping variants are documented as plain `floor(√(self))`.
    #[test]
    fn wrapping() {
        for value in [0u8, 1, 2, 3, 4, 99, 100, 255] {
            let x = BoxedUint::from(value);
            assert_eq!(x.wrapping_sqrt(), x.floor_sqrt());
            assert_eq!(x.wrapping_sqrt_vartime(), x.floor_sqrt_vartime());
        }
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn fuzz() {
        use crate::{CheckedSquareRoot, ConcatenatingSquare};

        let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
        let rounds = if cfg!(miri) { 10 } else { 50 };
        for _ in 0..rounds {
            let t = u64::from(rng.next_u32());
            let s = BoxedUint::from(t);
            let s2 = s.checked_mul(&s).unwrap();
            assert_eq!(s2.floor_sqrt(), s);
            assert_eq!(s2.floor_sqrt_vartime(), s);
            assert!(CheckedSquareRoot::checked_sqrt(&s2).is_some().to_bool());
            assert!(CheckedSquareRoot::checked_sqrt_vartime(&s2).is_some());
        }

        for _ in 0..rounds {
            let s = BoxedUint::random_bits(&mut rng, 512);
            let mut s2 = BoxedUint::zero_with_precision(512);
            s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            assert_eq!(s.concatenating_square().floor_sqrt(), s2);
            assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2);
        }
    }
}