Skip to main content

rsa/algorithms/
rsa.rs

1//! Generic RSA implementation
2
3use core::cmp::Ordering;
4
5use crypto_bigint::{
6    modular::{BoxedMontyForm, BoxedMontyParams},
7    BoxedUint, ConcatenatingMul, ConcatenatingSquare, Gcd, NonZero, Odd, RandomMod, Resize,
8};
9use rand_core::TryCryptoRng;
10use zeroize::Zeroize;
11
12use crate::errors::{Error, Result};
13use crate::traits::keys::{PrivateKeyParts, PublicKeyParts};
14
15/// ⚠️ Raw RSA encryption of m with the public key. No padding is performed.
16///
17/// # ☢️️ WARNING: HAZARDOUS API ☢️
18///
19/// Use this function with great care! Raw RSA should never be used without an appropriate padding
20/// or signature scheme. See the [module-level documentation][crate::hazmat] for more information.
21#[inline]
22pub fn rsa_encrypt<K: PublicKeyParts>(key: &K, m: &BoxedUint) -> Result<BoxedUint> {
23    let e = key.e();
24    let res = pow_mod_params_vartime_exp_bits(m, e, e.bits(), key.n_params());
25    Ok(res)
26}
27
28/// ⚠️ Performs raw RSA decryption with no padding or error checking.
29///
30/// Returns a plaintext `BoxedUint`. Performs RSA blinding if an `Rng` is passed.
31///
32/// # ☢️️ WARNING: HAZARDOUS API ☢️
33///
34/// Use this function with great care! Raw RSA should never be used without an appropriate padding
35/// or signature scheme. See the [module-level documentation][crate::hazmat] for more information.
36#[inline]
37pub fn rsa_decrypt<R: TryCryptoRng + ?Sized>(
38    rng: Option<&mut R>,
39    priv_key: &impl PrivateKeyParts,
40    c: &BoxedUint,
41) -> Result<BoxedUint> {
42    let n = priv_key.n();
43    let d = priv_key.d();
44
45    if c.bits_precision() != n.as_ref().bits_precision() {
46        return Err(Error::Decryption);
47    }
48
49    if c >= n.as_ref() {
50        return Err(Error::Decryption);
51    }
52
53    let mut ir = None;
54
55    let n_params = priv_key.n_params();
56    let bits = d.bits_precision();
57
58    let c = if let Some(rng) = rng {
59        let (blinded, unblinder) = blind(rng, priv_key, c, n_params)?;
60        ir = Some(unblinder);
61        blinded.try_resize(bits).ok_or(Error::Internal)?
62    } else {
63        c.try_resize(bits).ok_or(Error::Internal)?
64    };
65
66    let is_multiprime = priv_key.primes().len() > 2;
67
68    let m = match (
69        priv_key.dp(),
70        priv_key.dq(),
71        priv_key.qinv(),
72        priv_key.p_params(),
73        priv_key.q_params(),
74    ) {
75        (Some(dp), Some(dq), Some(qinv), Some(p_params), Some(q_params)) if !is_multiprime => {
76            // We have the precalculated values needed for the CRT.
77
78            let p = &priv_key.primes()[0];
79            let q = &priv_key.primes()[1];
80
81            // precomputed: dP = (1/e) mod (p-1) = d mod (p-1)
82            // precomputed: dQ = (1/e) mod (q-1) = d mod (q-1)
83
84            // TODO: it may be faster to convert to and from Montgomery with prepared parameters
85            // (modulo `p` and `q`) rather than calculating the remainder directly.
86
87            // m1 = c^dP mod p
88            let p_wide = p_params.modulus().resize_unchecked(c.bits_precision());
89            let c_mod_dp = (&c % p_wide.as_nz_ref()).resize_unchecked(dp.bits_precision());
90            let cp = BoxedMontyForm::new(c_mod_dp, p_params);
91            let mut m1 = cp.pow(dp);
92            // m2 = c^dQ mod q
93            let q_wide = q_params.modulus().resize_unchecked(c.bits_precision());
94            let c_mod_dq = (&c % q_wide.as_nz_ref()).resize_unchecked(dq.bits_precision());
95            let cq = BoxedMontyForm::new(c_mod_dq, q_params);
96            let m2 = cq.pow(dq).retrieve();
97
98            // Note that since `p` and `q` may have different `bits_precision`,
99            // it may be different for `m1` and `m2` as well.
100
101            // (m1 - m2) mod p = (m1 mod p) - (m2 mod p) mod p
102            let m2_mod_p = match p_params.bits_precision().cmp(&q_params.bits_precision()) {
103                Ordering::Less => {
104                    let p_wide = NonZero::new(p.clone())
105                        .expect("`p` is non-zero")
106                        .resize_unchecked(q_params.bits_precision());
107                    (&m2 % p_wide).resize_unchecked(p_params.bits_precision())
108                }
109                Ordering::Greater => (&m2).resize_unchecked(p_params.bits_precision()),
110                Ordering::Equal => m2.clone(),
111            };
112            let m2r = BoxedMontyForm::new(m2_mod_p, p_params);
113            m1 -= &m2r;
114
115            // precomputed: qInv = (1/q) mod p
116
117            // h = qInv.(m1 - m2) mod p
118            let h = (qinv * m1).retrieve();
119
120            // m = m2 + h.q
121            let m2 = m2.try_resize(n.bits_precision()).ok_or(Error::Internal)?;
122            let hq = h
123                .concatenating_mul(&q)
124                .try_resize(n.bits_precision())
125                .ok_or(Error::Internal)?;
126            m2.wrapping_add(&hq)
127        }
128        _ => {
129            // c^d (mod n)
130            pow_mod_params(&c, d, n_params)
131        }
132    };
133
134    match ir {
135        Some(ref ir) => {
136            // unblind
137            let res = unblind(&m, ir, n_params);
138            Ok(res)
139        }
140        None => Ok(m),
141    }
142}
143
144/// ⚠️ Performs raw RSA decryption with no padding.
145///
146/// Returns a plaintext `BoxedUint`. Performs RSA blinding if an `Rng` is passed.  This will also
147/// check for errors in the CRT computation.
148///
149/// `c` must have the same `bits_precision` as the RSA key modulus.
150///
151/// # ☢️️ WARNING: HAZARDOUS API ☢️
152///
153/// Use this function with great care! Raw RSA should never be used without an appropriate padding
154/// or signature scheme. See the [module-level documentation][crate::hazmat] for more information.
155#[inline]
156pub fn rsa_decrypt_and_check<R: TryCryptoRng + ?Sized>(
157    priv_key: &impl PrivateKeyParts,
158    rng: Option<&mut R>,
159    c: &BoxedUint,
160) -> Result<BoxedUint> {
161    let m = rsa_decrypt(rng, priv_key, c)?;
162
163    // In order to defend against errors in the CRT computation, m^e is
164    // calculated, which should match the original ciphertext.
165    let check = rsa_encrypt(priv_key, &m)?;
166
167    if c != &check {
168        return Err(Error::Internal);
169    }
170
171    Ok(m)
172}
173
174/// Returns the blinded c, along with the unblinding factor.
175fn blind<R: TryCryptoRng + ?Sized, K: PublicKeyParts>(
176    rng: &mut R,
177    key: &K,
178    c: &BoxedUint,
179    n_params: &BoxedMontyParams,
180) -> Result<(BoxedUint, BoxedUint)> {
181    // Blinding involves multiplying c by r^e.
182    // Then the decryption operation performs (m^e * r^e)^d mod n
183    // which equals mr mod n. The factor of r can then be removed
184    // by multiplying by the multiplicative inverse of r.
185    debug_assert_eq!(&key.n().clone().get(), n_params.modulus());
186    let bits = key.n_bits_precision();
187
188    let mut r: BoxedUint = BoxedUint::zero_with_precision(bits);
189    let mut ir: Option<BoxedUint> = None;
190    while ir.is_none() {
191        r = BoxedUint::try_random_mod_vartime(rng, key.n()).map_err(|_| Error::Rng)?;
192
193        // r^-1 (mod n)
194        ir = r.invert_mod(key.n()).into();
195    }
196
197    let blinded = {
198        // r^e (mod n)
199        let e = key.e();
200        let mut rpowe = pow_mod_params_vartime_exp_bits(&r, e, e.bits(), n_params);
201        // c * r^e (mod n)
202        let c = c.mul_mod(&rpowe, n_params.modulus().as_nz_ref());
203        rpowe.zeroize();
204
205        c
206    };
207
208    let ir = ir.expect("loop exited");
209    debug_assert_eq!(blinded.bits_precision(), bits);
210    debug_assert_eq!(ir.bits_precision(), bits);
211
212    Ok((blinded, ir))
213}
214
215/// Given an m and unblinding factor, unblind the m.
216fn unblind(m: &BoxedUint, unblinder: &BoxedUint, n_params: &BoxedMontyParams) -> BoxedUint {
217    // m * r^-1 (mod n)
218    debug_assert_eq!(
219        m.bits_precision(),
220        unblinder.bits_precision(),
221        "invalid unblinder"
222    );
223
224    debug_assert_eq!(
225        m.bits_precision(),
226        n_params.bits_precision(),
227        "invalid n_params"
228    );
229
230    m.mul_mod(unblinder, n_params.modulus().as_nz_ref())
231}
232
233/// Computes `base.pow_mod(exp, n)` with precomputed `n_params`.
234fn pow_mod_params(base: &BoxedUint, exp: &BoxedUint, n_params: &BoxedMontyParams) -> BoxedUint {
235    let base = reduce_vartime(base, n_params);
236    base.pow(exp).retrieve()
237}
238
239/// Computes `base.pow_mod(exp, n)` with a bounded exponent and precomputed `n_params`.
240///
241/// The exponent bit length `exp_bits` may be leaked in the time pattern.
242fn pow_mod_params_vartime_exp_bits(
243    base: &BoxedUint,
244    exp: &BoxedUint,
245    exp_bits: u32,
246    n_params: &BoxedMontyParams,
247) -> BoxedUint {
248    let base = reduce_vartime(base, n_params);
249    base.pow_bounded_exp(exp, exp_bits).retrieve()
250}
251
252fn reduce_vartime(n: &BoxedUint, p: &BoxedMontyParams) -> BoxedMontyForm {
253    let modulus = p.modulus().as_nz_ref().clone();
254    let n_reduced = n.rem_vartime(&modulus).resize_unchecked(p.bits_precision());
255    BoxedMontyForm::new(n_reduced, p)
256}
257
258/// The following (deterministic) algorithm also recovers the prime factors `p` and `q` of a modulus `n`, given the
259/// public exponent `e` and private exponent `d` using the method described in
260/// [NIST 800-56B Appendix C.2](https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Br2.pdf).
261pub fn recover_primes(
262    n: &NonZero<BoxedUint>,
263    e: &BoxedUint,
264    d: &BoxedUint,
265) -> Result<(BoxedUint, BoxedUint)> {
266    // Check precondition
267
268    // Note: because e is at most u64::MAX, it is already
269    // known to be < 2**256
270    if e <= &BoxedUint::from(2u64.pow(16)) {
271        return Err(Error::InvalidArguments);
272    }
273
274    // 1. Let a = (de – 1) × GCD(n – 1, de – 1).
275    let bits = d.bits_precision() * 2;
276    let one = BoxedUint::one_with_precision(bits);
277    let e = e.resize_unchecked(d.bits_precision());
278    let d = d.resize_unchecked(d.bits_precision());
279    let n = n.resize_unchecked(bits);
280
281    let a1 = d.concatenating_mul(&e) - &one;
282    let a2 = (n.as_ref() - &one).gcd(&a1);
283    let a = a1.concatenating_mul(&a2);
284    let n = n.resize_unchecked(a.bits_precision());
285
286    // 2. Let m = floor(a /n) and r = a – m n, so that a = m n + r and 0 ≤ r < n.
287    let m = &a / &n;
288    let r = a - m.concatenating_mul(&*n);
289    let n = n.get();
290
291    // 3. Let b = ( (n – r)/(m + 1) ) + 1; if b is not an integer or b^2 ≤ 4n, then output an error indicator,
292    //    and exit without further processing.
293    let modulus_check = (&n - &r) % NonZero::new(&m + &one).expect("adding 1");
294    if (!modulus_check.is_zero()).into() {
295        return Err(Error::InvalidArguments);
296    }
297    let b = ((&n - &r) / NonZero::new(&m + &one).expect("adding one")) + one;
298
299    let four = BoxedUint::from(4u32);
300    let four_n = n.concatenating_mul(&four);
301    let b_squared = b.concatenating_square();
302
303    if b_squared <= four_n {
304        return Err(Error::InvalidArguments);
305    }
306    let b_squared_minus_four_n = b_squared - four_n;
307
308    // 4. Let ϒ be the positive square root of b^2 – 4n; if ϒ is not an integer,
309    //    then output an error indicator, and exit without further processing.
310    let y = b_squared_minus_four_n.floor_sqrt();
311
312    let y_squared = y.concatenating_square();
313    let sqrt_is_whole_number = y_squared == b_squared_minus_four_n;
314    if !sqrt_is_whole_number {
315        return Err(Error::InvalidArguments);
316    }
317
318    let bits = core::cmp::max(b.bits_precision(), y.bits_precision());
319    let two = NonZero::new(BoxedUint::from(2u64))
320        .expect("2 is non zero")
321        .resize_unchecked(bits);
322    let p = (&b + &y) / &two;
323    let q = (b - y) / two;
324
325    Ok((p, q))
326}
327
328/// Compute the modulus of a key from its primes.
329pub(crate) fn compute_modulus(primes: &[BoxedUint]) -> Odd<BoxedUint> {
330    let mut primes = primes.iter();
331    let mut out = primes.next().expect("must at least be one prime").clone();
332    for p in primes {
333        out = out.concatenating_mul(&p);
334    }
335    Odd::new(out).expect("modulus must be odd")
336}
337
338/// Compute the private exponent from its primes (p and q) and public exponent
339/// This uses Euler's totient function
340#[inline]
341pub(crate) fn compute_private_exponent_euler_totient(
342    primes: &[BoxedUint],
343    exp: &BoxedUint,
344) -> Result<BoxedUint> {
345    if primes.len() < 2 {
346        return Err(Error::InvalidPrime);
347    }
348    let bits = primes[0].bits_precision();
349    let mut totient = BoxedUint::one_with_precision(bits);
350
351    for prime in primes {
352        totient = totient.concatenating_mul(&(prime - &BoxedUint::one()));
353    }
354    let exp = exp.resize_unchecked(totient.bits_precision());
355
356    // NOTE: `mod_inverse` checks if `exp` evenly divides `totient` and returns `None` if so.
357    // This ensures that `exp` is not a factor of any `(prime - 1)`.
358    let totient = NonZero::new(totient).expect("known");
359    match exp.invert_mod(&totient).into_option() {
360        Some(res) => Ok(res),
361        None => Err(Error::InvalidPrime),
362    }
363}
364
365/// Compute the private exponent from its primes (p and q) and public exponent
366///
367/// This is using the method defined by
368/// [NIST 800-56B Section 6.2.1](https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Br2.pdf#page=47).
369/// (Carmichael function)
370///
371/// FIPS 186-4 **requires** the private exponent to be less than λ(n), which would
372/// make Euler's totiem unreliable.
373#[inline]
374pub(crate) fn compute_private_exponent_carmicheal(
375    p: &BoxedUint,
376    q: &BoxedUint,
377    exp: &BoxedUint,
378) -> Result<BoxedUint> {
379    let one = BoxedUint::one();
380    let p1 = p - &one;
381    let q1 = q - &one;
382
383    // LCM inlined
384    let gcd = p1.gcd(&q1);
385    let lcm = (p1 / NonZero::new(gcd).expect("gcd is non zero")).concatenating_mul(&q1);
386    let exp = exp.resize_unchecked(lcm.bits_precision());
387    if let Some(d) = exp.invert_mod(&NonZero::new(lcm).expect("non zero")).into() {
388        Ok(d)
389    } else {
390        // `exp` evenly divides `lcm`
391        Err(Error::InvalidPrime)
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Recovering the primes of a known 2048-bit key from `(n, e, d)` must yield its `p` and `q`.
    #[test]
    fn recover_primes_works() {
        // Modulus of the fixture key (big-endian hex).
        const N_HEX: &str = concat!(
            "d397b84d98a4c26138ed1b695a8106ead91d553bf06041b62d3fdc50a041e222",
            "b8f4529689c1b82c5e71554f5dd69fa2f4b6158cf0dbeb57811a0fc327e1f28e",
            "74fe74d3bc166c1eabdc1b8b57b934ca8be5b00b4f29975bcc99acaf415b59bb",
            "28a6782bb41a2c3c2976b3c18dbadef62f00c6bb226640095096c0cc60d22fe7",
            "ef987d75c6a81b10d96bf292028af110dc7cc1bbc43d22adab379a0cd5d8078c",
            "c780ff5cd6209dea34c922cf784f7717e428d75b5aec8ff30e5f0141510766e2",
            "e0ab8d473c84e8710b2b98227c3db095337ad3452f19e2b9bfbccdd8148abf67",
            "76fa552775e6e75956e45229ae5a9c46949bab1e622f0e48f56524a84ed3483b"
        );
        // Private exponent of the fixture key.
        const D_HEX: &str = concat!(
            "c4e70c689162c94c660828191b52b4d8392115df486a9adbe831e458d7395832",
            "0dc1b755456e93701e9702d76fb0b92f90e01d1fe248153281fe79aa9763a92f",
            "ae69d8d7ecd144de29fa135bd14f9573e349e45031e3b76982f583003826c552",
            "e89a397c1a06bd2163488630d92e8c2bb643d7abef700da95d685c941489a46f",
            "54b5316f62b5d2c3a7f1bbd134cb37353a44683fdc9d95d36458de22f6c44057",
            "fe74a0a436c4308f73f4da42f35c47ac16a7138d483afc91e41dc3a1127382e0",
            "c0f5119b0221b4fc639d6b9c38177a6de9b526ebd88c38d7982c07f98a0efd87",
            "7d508aae275b946915c02e2e1106d175d74ec6777f5e80d12c053d9c7be1e341"
        );
        // Expected larger prime factor.
        const P_HEX: &str = concat!(
            "f827bbf3a41877c7cc59aebf42ed4b29c32defcb8ed96863d5b090a05a8930dd",
            "624a21c9dcf9838568fdfa0df65b8462a5f2ac913d6c56f975532bd8e78fb07b",
            "d405ca99a484bcf59f019bbddcb3933f2bce706300b4f7b110120c5df9018159",
            "067c35da3061a56c8635a52b54273b31271b4311f0795df6021e6355e1a42e61"
        );
        // Expected smaller prime factor.
        const Q_HEX: &str = concat!(
            "da4817ce0089dd36f2ade6a3ff410c73ec34bf1b4f6bda38431bfede11cef1f7",
            "f6efa70e5f8063a3b1f6e17296ffb15feefa0912a0325b8d1fd65a559e717b5b",
            "961ec345072e0ec5203d03441d29af4d64054a04507410cf1da78e7b6119d909",
            "ec66e6ad625bf995b279a4b3c5be7d895cd7c5b9c4c497fde730916fcdb4e41b"
        );

        let bits = 2048;

        let n = BoxedUint::from_be_hex(N_HEX, bits).unwrap();
        let e = BoxedUint::from(65_537u64);
        let d = BoxedUint::from_be_hex(D_HEX, bits).unwrap();
        let p = BoxedUint::from_be_hex(P_HEX, bits / 2).unwrap();
        let q = BoxedUint::from_be_hex(Q_HEX, bits / 2).unwrap();

        let (p1, q1) = recover_primes(&NonZero::new(n).unwrap(), &e, &d).unwrap();

        // Normalise the order so that the larger factor is compared against `p`.
        let (p1, q1) = if p1 < q1 { (q1, p1) } else { (p1, q1) };
        assert_eq!(p, p1);
        assert_eq!(q, q1);
    }
}