1use crate::{
2 Choice, CtOption, Uint,
3 modular::{FixedMontyForm, FixedMontyParams, prime_params::PrimeParams},
4};
5
/// Compute a modular square root of a value given in Montgomery form.
///
/// `monty_value` is interpreted in Montgomery form with respect to `monty_params`.
/// The returned value (when present) is also in Montgomery form; it is produced via
/// `to_montgomery`.
///
/// The result is `Some` only when the computed candidate `x` satisfies `x^2 == value`. So a
/// non-residue input, or any wrong candidate, yields a `None` [`CtOption`] rather than an
/// incorrect root.
///
/// The algorithm is selected by `prime_params.s`. Presumably this is the 2-adicity of `p - 1`
/// (`p - 1 = 2^s * t`, `t` odd); the tests below pair `s` with moduli accordingly.
/// - Dedicated branches handle `s = 1..=4`, using fixed root-of-unity corrections.
/// - Larger `s` uses a constant-time Tonelli–Shanks loop.
///
/// Only the branch choice depends on `s`. Within a branch, input-dependent choices are made
/// with `Choice`/select rather than `if`.
///
/// NOTE(review): `value.pow_vartime` is variable-time in its exponent `prime_params.sqrt_exp`.
/// That is presumably acceptable because the exponent is derived from the public modulus —
/// confirm in `PrimeParams`.
#[must_use]
pub const fn sqrt_montgomery_form<const LIMBS: usize>(
    monty_value: &Uint<LIMBS>,
    monty_params: &FixedMontyParams<LIMBS>,
    prime_params: &PrimeParams<LIMBS>,
) -> CtOption<Uint<LIMBS>> {
    let value = FixedMontyForm::from_montgomery(*monty_value, monty_params);
    // b = value^sqrt_exp; every branch below builds its candidate root from this power.
    let b = value.pow_vartime(&prime_params.sqrt_exp);

    let x = match prime_params.s.get() {
        // s = 1: `b` itself is returned as the candidate root.
        // NOTE(review): this is only a square root if `sqrt_exp` is `(p + 1) / 4` when s = 1.
        // The other branches instead treat `value * b` as the candidate, so `PrimeParams` must
        // define `sqrt_exp` differently for s = 1 — confirm there. The final `CtOption`
        // check rejects a wrong candidate either way.
        1 => {
            b
        }
        // s = 2. Candidate cb = value * b and zeta = cb * b, so cb^2 == value * zeta.
        // Hence cb is a root when zeta == 1.
        // Otherwise, assuming `ru` is a primitive 4th root of unity (ru^2 == -1),
        // a residue has zeta == -1 and (cb * ru)^2 == value.
        2 => {
            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
            let cb = value.mul(&b);
            let zeta = cb.mul(&b);
            let is_one = Uint::eq(zeta.as_montgomery(), monty_params.one());
            // Picks `cb` when zeta == 1, else `cb * ru`.
            monty_select(&cb.mul(&ru), &cb, is_one)
        }
        // s = 3. As above cb^2 == value * zeta, so we look for a correction m with
        // m^2 == zeta^-1.
        // Assumes `ru` = w is a primitive 8th root of unity and `monty_root_unity_p2` is w^2.
        // For a residue zeta is one of {1, w^2, -1, w^6}, mapped to m:
        //   zeta == 1    -> m = 1
        //   zeta == -1   -> m = w^2
        //   zeta == w^2  -> m = w^3
        //   zeta == w^6  -> m = w  (the default when no other case matches)
        3 => {
            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
            let ru_2 =
                FixedMontyForm::from_montgomery(prime_params.monty_root_unity_p2, monty_params);
            let ru_3 = ru.mul(&ru_2);
            let cb = value.mul(&b);
            let zeta = cb.mul(&b);

            let mut m = monty_select(
                &ru,
                &FixedMontyForm::one(ru.params()),
                Uint::eq(zeta.as_montgomery(), monty_params.one()),
            );
            m = monty_select(
                &m,
                &ru_2,
                Uint::eq(zeta.neg().as_montgomery(), monty_params.one()),
            );
            m = monty_select(&m, &ru_3, monty_eq(&zeta, &ru_2));

            cb.mul(&m)
        }
        // s = 4. Same idea with w = `ru` assumed to be a primitive 16th root of unity and
        // `monty_root_unity_p2` = w^2 (so ru_4 = w^4, ru_6 = w^6).
        // Required m with m^2 == zeta^-1:
        //   1 -> 1,     w^2 -> w^7,    w^4 -> w^6,    w^6 -> w^5,
        //   -1 -> w^4,  -w^2 -> w^3,   -w^4 -> w^2,   -w^6 -> w
        // The first three selects pick the even power (1, w^2, w^4 or w^6).
        // The last select multiplies by w for the four zeta values that need an odd power.
        4 => {
            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
            let ru_2 =
                FixedMontyForm::from_montgomery(prime_params.monty_root_unity_p2, monty_params);
            let ru_4 = ru_2.square();
            let ru_6 = ru_2.mul(&ru_4);
            let cb = value.mul(&b);
            let zeta = cb.mul(&b);

            let neg_zeta = zeta.neg();
            let zeta_b = monty_eq(&zeta, &ru_2);
            let neg_zeta_b = monty_eq(&neg_zeta, &ru_2);
            let zeta_d = monty_eq(&zeta, &ru_6);

            // zeta in {-w^2, -w^4} -> w^2
            let mut m = monty_select(
                &FixedMontyForm::one(ru.params()),
                &ru_2,
                neg_zeta_b.or(monty_eq(&neg_zeta, &ru_4)),
            );
            // zeta in {-1, w^6} -> w^4
            m = monty_select(
                &m,
                &ru_4,
                Uint::eq(neg_zeta.as_montgomery(), monty_params.one()).or(zeta_d),
            );
            // zeta in {w^2, w^4} -> w^6
            m = monty_select(&m, &ru_6, zeta_b.or(monty_eq(&zeta, &ru_4)));
            // zeta in {w^2, w^6, -w^2, -w^6} need an odd power: multiply by w.
            m = monty_select(
                &m,
                &m.mul(&ru),
                zeta_b
                    .or(zeta_d)
                    .or(neg_zeta_b)
                    .or(monty_eq(&neg_zeta, &ru_6)),
            );

            cb.mul(&m)
        }
        // General case: constant-time Tonelli–Shanks. This has the same structure as the `ff`
        // crate's `sqrt_tonelli_shanks` helper. The outer loop runs exactly `s` times and the
        // inner loop runs `max_v - 2` times per outer step, independent of the input value.
        _ => {
            // x: current root candidate; d: error term with x^2 == value * d.
            // z: starts at the stored root of unity.
            let mut x = value.mul(&b);
            let mut d = x.mul(&b);
            let mut z =
                FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
            let mut v = prime_params.s.get();
            let mut max_v = v;

            while max_v >= 1 {
                // `k` ends as the last `j` at which `tmp` was not yet one, i.e. it tracks
                // how many squarings `d` needs to reach one.
                let mut k = 1;
                let mut tmp = d.square();
                let mut j_less_than_v = Choice::TRUE;

                let mut j = 2;
                while j < max_v {
                    let tmp_is_one = Uint::eq(tmp.as_montgomery(), monty_params.one());
                    // Every iteration performs exactly one squaring:
                    //   - while `tmp != 1` it squares `tmp`;
                    //   - once `tmp == 1` it squares `z` instead.
                    let squared = monty_select(&tmp, &z, tmp_is_one).square();
                    tmp = monty_select(&squared, &tmp, tmp_is_one);
                    j_less_than_v = j_less_than_v.and(Choice::from_u32_eq(j, v).not());
                    // `z` only absorbs the squaring while `j` is still below the previous `v`.
                    z = monty_select(&z, &squared, tmp_is_one.and(j_less_than_v));
                    k = tmp_is_one.select_u32(j, k);
                    j += 1;
                }

                // Named after the reference algorithm's `b`; here it tests the error term `d`.
                // If `d` is already one, `x` is left unchanged.
                let b_is_one = Uint::eq(d.as_montgomery(), monty_params.one());
                x = monty_select(&x.mul(&z), &x, b_is_one);
                z = z.square();
                d = d.mul(&z);
                v = k;
                max_v -= 1;
            }

            x
        }
    };

    // Only report success when the candidate actually squares back to the input.
    CtOption::new(x.to_montgomery(), monty_eq(&x.square(), &value))
}
139
140const fn monty_eq<const LIMBS: usize>(
141 a: &FixedMontyForm<LIMBS>,
142 b: &FixedMontyForm<LIMBS>,
143) -> Choice {
144 Uint::eq(a.as_montgomery(), b.as_montgomery())
145}
146
147const fn monty_select<const LIMBS: usize>(
148 a: &FixedMontyForm<LIMBS>,
149 b: &FixedMontyForm<LIMBS>,
150 c: Choice,
151) -> FixedMontyForm<LIMBS> {
152 FixedMontyForm::from_montgomery(
153 Uint::select(a.as_montgomery(), b.as_montgomery(), c),
154 a.params(),
155 )
156}
157
#[cfg(test)]
mod tests {
    use super::sqrt_montgomery_form;
    use crate::{
        Odd, U256, U576, Uint,
        modular::{FixedMontyForm, FixedMontyParams, PrimeParams},
    };

    /// Converts the stored root of unity out of Montgomery form so it can be compared
    /// against a known constant.
    fn root_of_unity<const LIMBS: usize>(
        monty_params: &FixedMontyParams<LIMBS>,
        prime_params: &PrimeParams<LIMBS>,
    ) -> Uint<LIMBS> {
        let root = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
        root.retrieve()
    }

    /// Checks the values derived by `PrimeParams::new_vartime` against known answers:
    /// the 2-adicity `s`, the generator, and the (retrieved) root of unity.
    fn assert_prime_params<const LIMBS: usize>(
        monty_params: &FixedMontyParams<LIMBS>,
        prime_params: &PrimeParams<LIMBS>,
        expected_s: u32,
        expected_generator: u32,
        expected_root: Uint<LIMBS>,
    ) {
        assert_eq!(prime_params.s.get(), expected_s);
        assert_eq!(prime_params.generator.get(), expected_generator);
        assert_eq!(root_of_unity(monty_params, prime_params), expected_root);
    }

    /// Round-trips small perfect squares `n^2` through `sqrt_montgomery_form`.
    /// Each recovered root must be `n` or `p - n`.
    ///
    /// Also checks that the generator, which these tests expect to be a quadratic
    /// non-residue, is reported as having no square root.
    fn test_monty_sqrt<const LIMBS: usize>(
        monty_params: FixedMontyParams<LIMBS>,
        prime_params: PrimeParams<LIMBS>,
    ) {
        let p = monty_params.modulus.get();
        // Use a much smaller range under Miri.
        let range = if cfg!(miri) { 1..=2 } else { 0..=256 };

        for n in range {
            let square = FixedMontyForm::new(&Uint::from_u32(n * n), &monty_params);
            let root_monty =
                sqrt_montgomery_form(square.as_montgomery(), &monty_params, &prime_params)
                    .expect("no sqrt found");
            let root = FixedMontyForm::from_montgomery(root_monty, &monty_params).retrieve();

            let expected = Uint::from_u32(n);
            let neg_expected = p.wrapping_sub(&expected);
            let matches = Uint::eq(&root, &expected).or(Uint::eq(&root, &neg_expected));
            assert!(matches.to_bool_vartime());
        }

        let non_residue = FixedMontyForm::new(
            &Uint::from_u32(prime_params.generator.get()),
            &monty_params,
        );
        let rejected =
            sqrt_montgomery_form(non_residue.as_montgomery(), &monty_params, &prime_params)
                .is_none();
        assert!(rejected.to_bool_vartime());
    }

    #[test]
    fn mod_sqrt_s_1() {
        // NIST P-256 base-field prime; s = 1 means p ≡ 3 (mod 4).
        let modulus = Odd::<U256>::from_be_hex(
            "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff",
        );
        let monty_params = FixedMontyParams::new_vartime(modulus);
        let prime_params = PrimeParams::new_vartime(&monty_params, 6);
        assert_prime_params(
            &monty_params,
            &prime_params,
            1,
            6,
            U256::from_be_hex("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFE"),
        );

        test_monty_sqrt(monty_params, prime_params);
    }

    #[test]
    fn mod_sqrt_s_2() {
        // Curve25519 base-field prime 2^255 - 19; s = 2 means p ≡ 5 (mod 8).
        let modulus = Odd::<U256>::from_be_hex(
            "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed",
        );
        let monty_params = FixedMontyParams::new_vartime(modulus);
        let prime_params = PrimeParams::new_vartime(&monty_params, 2);
        assert_prime_params(
            &monty_params,
            &prime_params,
            2,
            2,
            U256::from_be_hex("2B8324804FC1DF0B2B4D00993DFBD7A72F431806AD2FE478C4EE1B274A0EA0B0"),
        );

        test_monty_sqrt(monty_params, prime_params);
    }

    #[test]
    fn mod_sqrt_s_3() {
        // NIST P-521 group order, held in a 576-bit integer.
        let modulus = Odd::<U576>::from_be_hex(
            "00000000000001fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409",
        );
        let monty_params = FixedMontyParams::new_vartime(modulus);
        let prime_params = PrimeParams::new_vartime(&monty_params, 3);
        assert_prime_params(
            &monty_params,
            &prime_params,
            3,
            3,
            U576::from_be_hex(
                "000000000000009a0a650d44b28c17f3d708ad2fa8c4fbc7e6000d7c12dafa92fcc5673a3055276d535f79ff391dcdbcd998b7836647d3a72472b3da861ac810a7f9c7b7b63e2205",
            ),
        );

        test_monty_sqrt(monty_params, prime_params);
    }

    #[test]
    fn mod_sqrt_s_4() {
        // NIST P-256 group order.
        let modulus = Odd::<U256>::from_be_hex(
            "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551",
        );
        let monty_params = FixedMontyParams::new_vartime(modulus);
        let prime_params = PrimeParams::new_vartime(&monty_params, 7);
        assert_prime_params(
            &monty_params,
            &prime_params,
            4,
            7,
            U256::from_be_hex("ffc97f062a770992ba807ace842a3dfc1546cad004378daf0592d7fbb41e6602"),
        );

        test_monty_sqrt(monty_params, prime_params);
    }

    #[test]
    fn mod_sqrt_s_6() {
        // secp256k1 group order; exercises the general Tonelli–Shanks branch.
        let modulus = Odd::<U256>::from_be_hex(
            "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
        );
        let monty_params = FixedMontyParams::new_vartime(modulus);
        let prime_params = PrimeParams::new_vartime(&monty_params, 7);
        assert_prime_params(
            &monty_params,
            &prime_params,
            6,
            7,
            U256::from_be_hex("0C1DC060E7A91986DF9879A3FBC483A898BDEAB680756045992F4B5402B052F2"),
        );

        test_monty_sqrt(monty_params, prime_params);
    }
}