1use core::cmp::Ordering;
4
5use crypto_bigint::{
6 modular::{BoxedMontyForm, BoxedMontyParams},
7 BoxedUint, ConcatenatingMul, ConcatenatingSquare, Gcd, NonZero, Odd, RandomMod, Resize,
8};
9use rand_core::TryCryptoRng;
10use zeroize::Zeroize;
11
12use crate::errors::{Error, Result};
13use crate::traits::keys::{PrivateKeyParts, PublicKeyParts};
14
15#[inline]
22pub fn rsa_encrypt<K: PublicKeyParts>(key: &K, m: &BoxedUint) -> Result<BoxedUint> {
23 let e = key.e();
24 let res = pow_mod_params_vartime_exp_bits(m, e, e.bits(), key.n_params());
25 Ok(res)
26}
27
/// Raw (unpadded) RSA decryption: recovers `m = c^d mod n` from the ciphertext `c`.
///
/// * `rng`: when `Some`, `c` is blinded with a random factor before the private-key
///   operation and the result is unblinded afterwards (see `blind` / `unblind`). When
///   `None`, no blinding is performed.
/// * `priv_key`: the private key. If it has at most two primes and provides `dp`, `dq`,
///   `qinv`, `p_params` and `q_params`, the CRT path is used. Otherwise the result is
///   computed directly as `c^d mod n`.
/// * `c`: the ciphertext. It must have the same `bits_precision` as `n` and be `< n`.
///
/// # Errors
///
/// * [`Error::Decryption`] if `c` has a different precision than `n`, or if `c >= n`.
/// * [`Error::Internal`] if an intermediate value does not fit the precision it is resized to.
/// * Any error returned by `blind` (e.g. [`Error::Rng`] if sampling the blinding factor fails).
#[inline]
pub fn rsa_decrypt<R: TryCryptoRng + ?Sized>(
    rng: Option<&mut R>,
    priv_key: &impl PrivateKeyParts,
    c: &BoxedUint,
) -> Result<BoxedUint> {
    let n = priv_key.n();
    let d = priv_key.d();

    // Reject ciphertexts whose width does not match the modulus.
    if c.bits_precision() != n.as_ref().bits_precision() {
        return Err(Error::Decryption);
    }

    // The ciphertext must be a residue modulo `n`.
    if c >= n.as_ref() {
        return Err(Error::Decryption);
    }

    // Unblinding factor; only set when blinding is performed below.
    let mut ir = None;

    let n_params = priv_key.n_params();
    let bits = d.bits_precision();

    // Optionally blind `c`, then resize it to the precision of `d`.
    let c = if let Some(rng) = rng {
        let (blinded, unblinder) = blind(rng, priv_key, c, n_params)?;
        ir = Some(unblinder);
        blinded.try_resize(bits).ok_or(Error::Internal)?
    } else {
        c.try_resize(bits).ok_or(Error::Internal)?
    };

    // The CRT shortcut below only handles two-prime keys.
    let is_multiprime = priv_key.primes().len() > 2;

    let m = match (
        priv_key.dp(),
        priv_key.dq(),
        priv_key.qinv(),
        priv_key.p_params(),
        priv_key.q_params(),
    ) {
        (Some(dp), Some(dq), Some(qinv), Some(p_params), Some(q_params)) if !is_multiprime => {
            // CRT path. NOTE(review): assumes the standard precomputed values
            // dp = d mod (p-1), dq = d mod (q-1), qinv = q^-1 mod p; confirm in the key type.
            let p = &priv_key.primes()[0];
            let q = &priv_key.primes()[1];

            // m1 = (c mod p)^dp, kept in Montgomery form modulo p.
            // Despite its name, `c_mod_dp` holds `c mod p` (resized to dp's precision).
            let p_wide = p_params.modulus().resize_unchecked(c.bits_precision());
            let c_mod_dp = (&c % p_wide.as_nz_ref()).resize_unchecked(dp.bits_precision());
            let cp = BoxedMontyForm::new(c_mod_dp, p_params);
            let mut m1 = cp.pow(dp);
            // m2 = (c mod q)^dq mod q, as a plain integer.
            // Despite its name, `c_mod_dq` holds `c mod q` (resized to dq's precision).
            let q_wide = q_params.modulus().resize_unchecked(c.bits_precision());
            let c_mod_dq = (&c % q_wide.as_nz_ref()).resize_unchecked(dq.bits_precision());
            let cq = BoxedMontyForm::new(c_mod_dq, q_params);
            let m2 = cq.pow(dq).retrieve();

            // `p` and `q` may have different precisions, so bring m2 to p's precision
            // before subtracting it from m1 modulo p. Only the `Less` arm reduces m2
            // explicitly. NOTE(review): the other arms pass m2 through unreduced and rely
            // on `BoxedMontyForm::new` accepting values >= p; confirm.
            let m2_mod_p = match p_params.bits_precision().cmp(&q_params.bits_precision()) {
                Ordering::Less => {
                    let p_wide = NonZero::new(p.clone())
                        .expect("`p` is non-zero")
                        .resize_unchecked(q_params.bits_precision());
                    (&m2 % p_wide).resize_unchecked(p_params.bits_precision())
                }
                Ordering::Greater => (&m2).resize_unchecked(p_params.bits_precision()),
                Ordering::Equal => m2.clone(),
            };
            let m2r = BoxedMontyForm::new(m2_mod_p, p_params);
            // m1 <- (m1 - m2) mod p
            m1 -= &m2r;

            // h = qinv * (m1 - m2) mod p
            let h = (qinv * m1).retrieve();

            // Recombine: m = m2 + h * q, computed at n's precision.
            let m2 = m2.try_resize(n.bits_precision()).ok_or(Error::Internal)?;
            let hq = h
                .concatenating_mul(&q)
                .try_resize(n.bits_precision())
                .ok_or(Error::Internal)?;
            m2.wrapping_add(&hq)
        }
        _ => {
            // No usable CRT parameters (or more than two primes): m = c^d mod n.
            pow_mod_params(&c, d, n_params)
        }
    };

    // Remove the blinding factor if one was applied.
    match ir {
        Some(ref ir) => {
            let res = unblind(&m, ir, n_params);
            Ok(res)
        }
        None => Ok(m),
    }
}
143
144#[inline]
156pub fn rsa_decrypt_and_check<R: TryCryptoRng + ?Sized>(
157 priv_key: &impl PrivateKeyParts,
158 rng: Option<&mut R>,
159 c: &BoxedUint,
160) -> Result<BoxedUint> {
161 let m = rsa_decrypt(rng, priv_key, c)?;
162
163 let check = rsa_encrypt(priv_key, &m)?;
166
167 if c != &check {
168 return Err(Error::Internal);
169 }
170
171 Ok(m)
172}
173
174fn blind<R: TryCryptoRng + ?Sized, K: PublicKeyParts>(
176 rng: &mut R,
177 key: &K,
178 c: &BoxedUint,
179 n_params: &BoxedMontyParams,
180) -> Result<(BoxedUint, BoxedUint)> {
181 debug_assert_eq!(&key.n().clone().get(), n_params.modulus());
186 let bits = key.n_bits_precision();
187
188 let mut r: BoxedUint = BoxedUint::zero_with_precision(bits);
189 let mut ir: Option<BoxedUint> = None;
190 while ir.is_none() {
191 r = BoxedUint::try_random_mod_vartime(rng, key.n()).map_err(|_| Error::Rng)?;
192
193 ir = r.invert_mod(key.n()).into();
195 }
196
197 let blinded = {
198 let e = key.e();
200 let mut rpowe = pow_mod_params_vartime_exp_bits(&r, e, e.bits(), n_params);
201 let c = c.mul_mod(&rpowe, n_params.modulus().as_nz_ref());
203 rpowe.zeroize();
204
205 c
206 };
207
208 let ir = ir.expect("loop exited");
209 debug_assert_eq!(blinded.bits_precision(), bits);
210 debug_assert_eq!(ir.bits_precision(), bits);
211
212 Ok((blinded, ir))
213}
214
215fn unblind(m: &BoxedUint, unblinder: &BoxedUint, n_params: &BoxedMontyParams) -> BoxedUint {
217 debug_assert_eq!(
219 m.bits_precision(),
220 unblinder.bits_precision(),
221 "invalid unblinder"
222 );
223
224 debug_assert_eq!(
225 m.bits_precision(),
226 n_params.bits_precision(),
227 "invalid n_params"
228 );
229
230 m.mul_mod(unblinder, n_params.modulus().as_nz_ref())
231}
232
233fn pow_mod_params(base: &BoxedUint, exp: &BoxedUint, n_params: &BoxedMontyParams) -> BoxedUint {
235 let base = reduce_vartime(base, n_params);
236 base.pow(exp).retrieve()
237}
238
239fn pow_mod_params_vartime_exp_bits(
243 base: &BoxedUint,
244 exp: &BoxedUint,
245 exp_bits: u32,
246 n_params: &BoxedMontyParams,
247) -> BoxedUint {
248 let base = reduce_vartime(base, n_params);
249 base.pow_bounded_exp(exp, exp_bits).retrieve()
250}
251
252fn reduce_vartime(n: &BoxedUint, p: &BoxedMontyParams) -> BoxedMontyForm {
253 let modulus = p.modulus().as_nz_ref().clone();
254 let n_reduced = n.rem_vartime(&modulus).resize_unchecked(p.bits_precision());
255 BoxedMontyForm::new(n_reduced, p)
256}
257
/// Recovers the prime factors `(p, q)` of a two-prime modulus `n` from the public
/// exponent `e` and the private exponent `d`.
///
/// The numbered steps below appear to follow the deterministic method of
/// NIST SP 800-56B Rev. 2, Appendix C.2. Since `y >= 0`, the returned `p` is never
/// smaller than `q`.
///
/// # Errors
///
/// Returns [`Error::InvalidArguments`] in these cases:
/// * `e <= 2^16`;
/// * `m + 1` does not divide `n - r`;
/// * `b^2 <= 4n`;
/// * `b^2 - 4n` is not a perfect square.
pub fn recover_primes(
    n: &NonZero<BoxedUint>,
    e: &BoxedUint,
    d: &BoxedUint,
) -> Result<(BoxedUint, BoxedUint)> {
    // Precondition: the method requires e > 2^16.
    if e <= &BoxedUint::from(2u64.pow(16)) {
        return Err(Error::InvalidArguments);
    }

    // `bits` is twice the precision of `d`; `one` and `n` are widened to it.
    let bits = d.bits_precision() * 2;
    let one = BoxedUint::one_with_precision(bits);
    let e = e.resize_unchecked(d.bits_precision());
    let d = d.resize_unchecked(d.bits_precision());
    let n = n.resize_unchecked(bits);

    // Step 1: a = (d*e - 1) * gcd(n - 1, d*e - 1).
    let a1 = d.concatenating_mul(&e) - &one;
    let a2 = (n.as_ref() - &one).gcd(&a1);
    let a = a1.concatenating_mul(&a2);
    let n = n.resize_unchecked(a.bits_precision());

    // Step 2: m = floor(a / n) and r = a - m*n, so a = m*n + r with 0 <= r < n.
    let m = &a / &n;
    let r = a - m.concatenating_mul(&*n);
    let n = n.get();

    // Step 3: b = (n - r) / (m + 1) + 1. Fail unless the division is exact and b^2 > 4n.
    let modulus_check = (&n - &r) % NonZero::new(&m + &one).expect("adding 1");
    if (!modulus_check.is_zero()).into() {
        return Err(Error::InvalidArguments);
    }
    let b = ((&n - &r) / NonZero::new(&m + &one).expect("adding one")) + one;

    let four = BoxedUint::from(4u32);
    let four_n = n.concatenating_mul(&four);
    let b_squared = b.concatenating_square();

    if b_squared <= four_n {
        return Err(Error::InvalidArguments);
    }
    let b_squared_minus_four_n = b_squared - four_n;

    // Step 4: y = sqrt(b^2 - 4n). Fail unless the square root is exact.
    let y = b_squared_minus_four_n.floor_sqrt();

    let y_squared = y.concatenating_square();
    let sqrt_is_whole_number = y_squared == b_squared_minus_four_n;
    if !sqrt_is_whole_number {
        return Err(Error::InvalidArguments);
    }

    // Step 5: p = (b + y) / 2 and q = (b - y) / 2, computed at a common precision.
    let bits = core::cmp::max(b.bits_precision(), y.bits_precision());
    let two = NonZero::new(BoxedUint::from(2u64))
        .expect("2 is non zero")
        .resize_unchecked(bits);
    let p = (&b + &y) / &two;
    let q = (b - y) / two;

    Ok((p, q))
}
327
328pub(crate) fn compute_modulus(primes: &[BoxedUint]) -> Odd<BoxedUint> {
330 let mut primes = primes.iter();
331 let mut out = primes.next().expect("must at least be one prime").clone();
332 for p in primes {
333 out = out.concatenating_mul(&p);
334 }
335 Odd::new(out).expect("modulus must be odd")
336}
337
338#[inline]
341pub(crate) fn compute_private_exponent_euler_totient(
342 primes: &[BoxedUint],
343 exp: &BoxedUint,
344) -> Result<BoxedUint> {
345 if primes.len() < 2 {
346 return Err(Error::InvalidPrime);
347 }
348 let bits = primes[0].bits_precision();
349 let mut totient = BoxedUint::one_with_precision(bits);
350
351 for prime in primes {
352 totient = totient.concatenating_mul(&(prime - &BoxedUint::one()));
353 }
354 let exp = exp.resize_unchecked(totient.bits_precision());
355
356 let totient = NonZero::new(totient).expect("known");
359 match exp.invert_mod(&totient).into_option() {
360 Some(res) => Ok(res),
361 None => Err(Error::InvalidPrime),
362 }
363}
364
365#[inline]
374pub(crate) fn compute_private_exponent_carmicheal(
375 p: &BoxedUint,
376 q: &BoxedUint,
377 exp: &BoxedUint,
378) -> Result<BoxedUint> {
379 let one = BoxedUint::one();
380 let p1 = p - &one;
381 let q1 = q - &one;
382
383 let gcd = p1.gcd(&q1);
385 let lcm = (p1 / NonZero::new(gcd).expect("gcd is non zero")).concatenating_mul(&q1);
386 let exp = exp.resize_unchecked(lcm.bits_precision());
387 if let Some(d) = exp.invert_mod(&NonZero::new(lcm).expect("non zero")).into() {
388 Ok(d)
389 } else {
390 Err(Error::InvalidPrime)
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a big-endian hex string into a `BoxedUint` with `precision` bits.
    fn parse_hex(digits: &str, precision: u32) -> BoxedUint {
        BoxedUint::from_be_hex(digits, precision).unwrap()
    }

    /// `recover_primes` must reconstruct both factors of a known 2048-bit key from
    /// `(n, e, d)` alone.
    #[test]
    fn recover_primes_works() {
        const MODULUS_BITS: u32 = 2048;

        let n = parse_hex(
            concat!(
                "d397b84d98a4c26138ed1b695a8106ead91d553bf06041b62d3fdc50a041e222",
                "b8f4529689c1b82c5e71554f5dd69fa2f4b6158cf0dbeb57811a0fc327e1f28e",
                "74fe74d3bc166c1eabdc1b8b57b934ca8be5b00b4f29975bcc99acaf415b59bb",
                "28a6782bb41a2c3c2976b3c18dbadef62f00c6bb226640095096c0cc60d22fe7",
                "ef987d75c6a81b10d96bf292028af110dc7cc1bbc43d22adab379a0cd5d8078c",
                "c780ff5cd6209dea34c922cf784f7717e428d75b5aec8ff30e5f0141510766e2",
                "e0ab8d473c84e8710b2b98227c3db095337ad3452f19e2b9bfbccdd8148abf67",
                "76fa552775e6e75956e45229ae5a9c46949bab1e622f0e48f56524a84ed3483b"
            ),
            MODULUS_BITS,
        );
        let e = BoxedUint::from(65_537u64);
        let d = parse_hex(
            concat!(
                "c4e70c689162c94c660828191b52b4d8392115df486a9adbe831e458d7395832",
                "0dc1b755456e93701e9702d76fb0b92f90e01d1fe248153281fe79aa9763a92f",
                "ae69d8d7ecd144de29fa135bd14f9573e349e45031e3b76982f583003826c552",
                "e89a397c1a06bd2163488630d92e8c2bb643d7abef700da95d685c941489a46f",
                "54b5316f62b5d2c3a7f1bbd134cb37353a44683fdc9d95d36458de22f6c44057",
                "fe74a0a436c4308f73f4da42f35c47ac16a7138d483afc91e41dc3a1127382e0",
                "c0f5119b0221b4fc639d6b9c38177a6de9b526ebd88c38d7982c07f98a0efd87",
                "7d508aae275b946915c02e2e1106d175d74ec6777f5e80d12c053d9c7be1e341"
            ),
            MODULUS_BITS,
        );
        let expected_p = parse_hex(
            concat!(
                "f827bbf3a41877c7cc59aebf42ed4b29c32defcb8ed96863d5b090a05a8930dd",
                "624a21c9dcf9838568fdfa0df65b8462a5f2ac913d6c56f975532bd8e78fb07b",
                "d405ca99a484bcf59f019bbddcb3933f2bce706300b4f7b110120c5df9018159",
                "067c35da3061a56c8635a52b54273b31271b4311f0795df6021e6355e1a42e61"
            ),
            MODULUS_BITS / 2,
        );
        let expected_q = parse_hex(
            concat!(
                "da4817ce0089dd36f2ade6a3ff410c73ec34bf1b4f6bda38431bfede11cef1f7",
                "f6efa70e5f8063a3b1f6e17296ffb15feefa0912a0325b8d1fd65a559e717b5b",
                "961ec345072e0ec5203d03441d29af4d64054a04507410cf1da78e7b6119d909",
                "ec66e6ad625bf995b279a4b3c5be7d895cd7c5b9c4c497fde730916fcdb4e41b"
            ),
            MODULUS_BITS / 2,
        );

        let (first, second) = recover_primes(&NonZero::new(n).unwrap(), &e, &d).unwrap();
        // Compare the larger recovered factor against `p` and the smaller against `q`.
        let (larger, smaller) = if first < second {
            (second, first)
        } else {
            (first, second)
        };

        assert_eq!(expected_p, larger);
        assert_eq!(expected_q, smaller);
    }
}