Skip to main content

crypto_bigint/modular/
sqrt.rs

1use crate::{
2    Choice, CtOption, Uint,
3    modular::{FixedMontyForm, FixedMontyParams, prime_params::PrimeParams},
4};
5
6/// Compute a modular square root (if it exists) given [`MontyParams`]
7/// and [`PrimeParams`] corresponding to `monty_value`, in Montgomery form.
8#[must_use]
9pub const fn sqrt_montgomery_form<const LIMBS: usize>(
10    monty_value: &Uint<LIMBS>,
11    monty_params: &FixedMontyParams<LIMBS>,
12    prime_params: &PrimeParams<LIMBS>,
13) -> CtOption<Uint<LIMBS>> {
14    let value = FixedMontyForm::from_montgomery(*monty_value, monty_params);
15    let b = value.pow_vartime(&prime_params.sqrt_exp);
16
17    // Constant-time versions of modular square root algorithms based on:
18    // Koo, N., Cho, G.H. and Kwon, S. (2013), "Square root algorithm in 𝔽q for q ≡ 2s + 1 (mod 2s+1)".
19    // Electron. Lett., 49: 467-469. https://doi.org/10.1049/el.2012.4239
20
21    let x = match prime_params.s.get() {
22        1 => {
23            // Shanks algorithm: sqrt = x^((p+1)/4) = x^(t+1)
24            b
25        }
26        2 => {
27            // Algorithm 3: p = 5 mod 8 (Atkins variant)
28            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
29            let cb = value.mul(&b);
30            let zeta = cb.mul(&b);
31            let is_one = Uint::eq(zeta.as_montgomery(), monty_params.one());
32            monty_select(&cb.mul(&ru), &cb, is_one)
33        }
34        3 => {
35            // Algorithm 4: p = 9 mod 16
36            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
37            let ru_2 =
38                FixedMontyForm::from_montgomery(prime_params.monty_root_unity_p2, monty_params);
39            let ru_3 = ru.mul(&ru_2);
40            let cb = value.mul(&b);
41            let zeta = cb.mul(&b);
42
43            let mut m = monty_select(
44                &ru,
45                &FixedMontyForm::one(ru.params()),
46                Uint::eq(zeta.as_montgomery(), monty_params.one()),
47            );
48            // m = ru^2 if zeta = -1
49            m = monty_select(
50                &m,
51                &ru_2,
52                Uint::eq(zeta.neg().as_montgomery(), monty_params.one()),
53            );
54            // m = ru^3 if zeta = ru^2
55            m = monty_select(&m, &ru_3, monty_eq(&zeta, &ru_2));
56
57            cb.mul(&m)
58        }
59        4 => {
60            // Algorithm 5: p = 17 mod 32
61            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
62            let ru_2 =
63                FixedMontyForm::from_montgomery(prime_params.monty_root_unity_p2, monty_params);
64            let ru_4 = ru_2.square();
65            let ru_6 = ru_2.mul(&ru_4);
66            let cb = value.mul(&b);
67            let zeta = cb.mul(&b);
68
69            let neg_zeta = zeta.neg();
70            let zeta_b = monty_eq(&zeta, &ru_2);
71            let neg_zeta_b = monty_eq(&neg_zeta, &ru_2);
72            let zeta_d = monty_eq(&zeta, &ru_6);
73
74            // m = B if -zeta in {B, C}, else 1
75            let mut m = monty_select(
76                &FixedMontyForm::one(ru.params()),
77                &ru_2,
78                neg_zeta_b.or(monty_eq(&neg_zeta, &ru_4)),
79            );
80            // m = C if zeta in {-1, D}
81            m = monty_select(
82                &m,
83                &ru_4,
84                Uint::eq(neg_zeta.as_montgomery(), monty_params.one()).or(zeta_d),
85            );
86            // m = D if zeta in {B, C}
87            m = monty_select(&m, &ru_6, zeta_b.or(monty_eq(&zeta, &ru_4)));
88            // m = m•ru if zeta or -zeta in {B, D}
89            m = monty_select(
90                &m,
91                &m.mul(&ru),
92                zeta_b
93                    .or(zeta_d)
94                    .or(neg_zeta_b)
95                    .or(monty_eq(&neg_zeta, &ru_6)),
96            );
97
98            cb.mul(&m)
99        }
100        _ => {
101            // Tonelli-Shanks
102            let mut x = value.mul(&b);
103            let mut d = x.mul(&b);
104            let mut z =
105                FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
106            let mut v = prime_params.s.get();
107            let mut max_v = v;
108
109            while max_v >= 1 {
110                let mut k = 1;
111                let mut tmp = d.square();
112                let mut j_less_than_v = Choice::TRUE;
113
114                let mut j = 2;
115                while j < max_v {
116                    let tmp_is_one = Uint::eq(tmp.as_montgomery(), monty_params.one());
117                    let squared = monty_select(&tmp, &z, tmp_is_one).square();
118                    tmp = monty_select(&squared, &tmp, tmp_is_one);
119                    j_less_than_v = j_less_than_v.and(Choice::from_u32_eq(j, v).not());
120                    z = monty_select(&z, &squared, tmp_is_one.and(j_less_than_v));
121                    k = tmp_is_one.select_u32(j, k);
122                    j += 1;
123                }
124
125                let b_is_one = Uint::eq(d.as_montgomery(), monty_params.one());
126                x = monty_select(&x.mul(&z), &x, b_is_one);
127                z = z.square();
128                d = d.mul(&z);
129                v = k;
130                max_v -= 1;
131            }
132
133            x
134        }
135    };
136
137    CtOption::new(x.to_montgomery(), monty_eq(&x.square(), &value))
138}
139
140const fn monty_eq<const LIMBS: usize>(
141    a: &FixedMontyForm<LIMBS>,
142    b: &FixedMontyForm<LIMBS>,
143) -> Choice {
144    Uint::eq(a.as_montgomery(), b.as_montgomery())
145}
146
147const fn monty_select<const LIMBS: usize>(
148    a: &FixedMontyForm<LIMBS>,
149    b: &FixedMontyForm<LIMBS>,
150    c: Choice,
151) -> FixedMontyForm<LIMBS> {
152    FixedMontyForm::from_montgomery(
153        Uint::select(a.as_montgomery(), b.as_montgomery(), c),
154        a.params(),
155    )
156}
157
#[cfg(test)]
mod tests {
    use super::sqrt_montgomery_form;
    use crate::{
        Odd, U256, U576, Uint,
        modular::{FixedMontyForm, FixedMontyParams, PrimeParams},
    };

    /// Returns `prime_params.monty_root_unity` converted out of Montgomery form.
    fn root_of_unity<const LIMBS: usize>(
        monty_params: &FixedMontyParams<LIMBS>,
        prime_params: &PrimeParams<LIMBS>,
    ) -> Uint<LIMBS> {
        let root = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
        root.retrieve()
    }

    /// Checks that small perfect squares yield one of their two roots and that
    /// the generator recorded in `prime_params` has no square root.
    fn test_monty_sqrt<const LIMBS: usize>(
        monty_params: FixedMontyParams<LIMBS>,
        prime_params: PrimeParams<LIMBS>,
    ) {
        let modulus = monty_params.modulus.get();
        // Only a couple of rounds are run under Miri.
        let bases = if cfg!(miri) { 1..=2 } else { 0..=256 };

        for base in bases {
            let square = FixedMontyForm::new(&Uint::from_u32(base * base), &monty_params);
            let root_monty =
                sqrt_montgomery_form(square.as_montgomery(), &monty_params, &prime_params)
                    .expect("no sqrt found");
            let root = FixedMontyForm::from_montgomery(root_monty, &monty_params).retrieve();

            // Either `base` or `modulus - base` is an acceptable answer.
            let base = Uint::from_u32(base);
            let is_base = Uint::eq(&root, &base);
            let is_negated_base = Uint::eq(&root, &modulus.wrapping_sub(&base));
            assert!(is_base.or(is_negated_base).to_bool_vartime());
        }

        // generator must be non-residue
        let generator = Uint::from_u32(prime_params.generator.get());
        let generator_monty = FixedMontyForm::new(&generator, &monty_params);
        let generator_root = sqrt_montgomery_form(
            generator_monty.as_montgomery(),
            &monty_params,
            &prime_params,
        );
        assert!(generator_root.is_none().to_bool_vartime());
    }

    #[test]
    fn mod_sqrt_s_1() {
        // P-256 field modulus: p = 3 mod 4, s = 1
        let modulus = Odd::<U256>::from_be_hex(
            "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff",
        );
        let params = FixedMontyParams::new_vartime(modulus);
        let primes = PrimeParams::new_vartime(&params, 6);
        assert_eq!(primes.s.get(), 1);
        assert_eq!(primes.generator.get(), 6);

        let expected_root =
            U256::from_be_hex("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFE");
        assert_eq!(root_of_unity(&params, &primes), expected_root);

        test_monty_sqrt(params, primes);
    }

    #[test]
    fn mod_sqrt_s_2() {
        // ed25519 base field: p = 5 mod 8, s = 2
        let modulus = Odd::<U256>::from_be_hex(
            "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed",
        );
        let params = FixedMontyParams::new_vartime(modulus);
        let primes = PrimeParams::new_vartime(&params, 2);
        assert_eq!(primes.s.get(), 2);
        assert_eq!(primes.generator.get(), 2);

        let expected_root =
            U256::from_be_hex("2B8324804FC1DF0B2B4D00993DFBD7A72F431806AD2FE478C4EE1B274A0EA0B0");
        assert_eq!(root_of_unity(&params, &primes), expected_root);

        test_monty_sqrt(params, primes);
    }

    #[test]
    fn mod_sqrt_s_3() {
        // P-521 scalar field (521-bit group order, stored in a U576): p = 9 mod 16, s = 3
        let modulus = Odd::<U576>::from_be_hex(
            "00000000000001fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409",
        );
        let params = FixedMontyParams::new_vartime(modulus);
        let primes = PrimeParams::new_vartime(&params, 3);
        assert_eq!(primes.s.get(), 3);
        assert_eq!(primes.generator.get(), 3);

        let expected_root = U576::from_be_hex(
            "000000000000009a0a650d44b28c17f3d708ad2fa8c4fbc7e6000d7c12dafa92fcc5673a3055276d535f79ff391dcdbcd998b7836647d3a72472b3da861ac810a7f9c7b7b63e2205",
        );
        assert_eq!(root_of_unity(&params, &primes), expected_root);

        test_monty_sqrt(params, primes);
    }

    #[test]
    fn mod_sqrt_s_4() {
        // P-256 scalar field: p = 17 mod 32, s = 4
        let modulus = Odd::<U256>::from_be_hex(
            "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551",
        );
        let params = FixedMontyParams::new_vartime(modulus);
        let primes = PrimeParams::new_vartime(&params, 7);
        assert_eq!(primes.s.get(), 4);
        assert_eq!(primes.generator.get(), 7);

        let expected_root =
            U256::from_be_hex("ffc97f062a770992ba807ace842a3dfc1546cad004378daf0592d7fbb41e6602");
        assert_eq!(root_of_unity(&params, &primes), expected_root);

        test_monty_sqrt(params, primes);
    }

    #[test]
    fn mod_sqrt_s_6() {
        // K-256 scalar field: s = 6, handled by the Tonelli-Shanks fallback
        let modulus = Odd::<U256>::from_be_hex(
            "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
        );
        let params = FixedMontyParams::new_vartime(modulus);
        let primes = PrimeParams::new_vartime(&params, 7);
        assert_eq!(primes.s.get(), 6);
        assert_eq!(primes.generator.get(), 7);

        let expected_root =
            U256::from_be_hex("0C1DC060E7A91986DF9879A3FBC483A898BDEAB680756045992F4B5402B052F2");
        assert_eq!(root_of_unity(&params, &primes), expected_root);

        test_monty_sqrt(params, primes);
    }
}