Skip to main content

rsa/algorithms/
oaep.rs

//! Encryption and decryption using [OAEP padding](https://datatracker.ietf.org/doc/html/rfc8017#section-7.1).
2//!
3use alloc::boxed::Box;
4use alloc::vec::Vec;
5
6use crypto_bigint::{Choice, CtAssign, CtEq, CtOption};
7use digest::{Digest, FixedOutputReset};
8use rand_core::TryCryptoRng;
9use zeroize::Zeroizing;
10
11use super::mgf::{mgf1_xor, mgf1_xor_digest};
12use crate::errors::{Error, Result};
13
/// Exclusive upper bound on the OAEP label length, in bytes.
///
/// 2^61 bytes equals 2^64 bits, the maximum input size of the SHA-1 and
/// SHA-256 hash functions. Other hashes (e.g. SHA-512 and SHA-3) accept longer
/// inputs. However, labels of that size cannot practically be processed on a
/// single machine, so this one limit is applied whatever hash is used.
const MAX_LABEL_LEN: u64 = 2u64.pow(61);
20
21#[inline]
22fn encrypt_internal<R: TryCryptoRng + ?Sized, MGF: FnMut(&mut [u8], &mut [u8])>(
23    rng: &mut R,
24    msg: &[u8],
25    p_hash: &[u8],
26    h_size: usize,
27    k: usize,
28    mut mgf: MGF,
29) -> Result<Zeroizing<Vec<u8>>> {
30    if msg.len() + 2 * h_size + 2 > k {
31        return Err(Error::MessageTooLong);
32    }
33
34    let mut em = Zeroizing::new(vec![0u8; k]);
35
36    let (_, payload) = em.split_at_mut(1);
37    let (seed, db) = payload.split_at_mut(h_size);
38    rng.try_fill_bytes(seed).map_err(|_| Error::Rng)?;
39
40    // Data block DB =  pHash || PS || 01 || M
41    let db_len = k - h_size - 1;
42
43    db[0..h_size].copy_from_slice(p_hash);
44    db[db_len - msg.len() - 1] = 1;
45    db[db_len - msg.len()..].copy_from_slice(msg);
46
47    mgf(seed, db);
48
49    Ok(em)
50}
51
52/// Encrypts the given message with RSA and the padding scheme from
53/// [PKCS#1 OAEP].
54///
55/// The message must be no longer than the length of the public modulus minus
56/// `2 + (2 * hash.size())`.
57///
58/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1
59#[inline]
60pub(crate) fn oaep_encrypt<R, D, MGD>(
61    rng: &mut R,
62    msg: &[u8],
63    digest: &mut D,
64    mgf_digest: &mut MGD,
65    label: Option<Box<[u8]>>,
66    k: usize,
67) -> Result<Zeroizing<Vec<u8>>>
68where
69    R: TryCryptoRng + ?Sized,
70    D: Digest + FixedOutputReset,
71    MGD: Digest + FixedOutputReset,
72{
73    let h_size = <D as Digest>::output_size();
74
75    let label = label.unwrap_or_default();
76    if label.len() as u64 >= MAX_LABEL_LEN {
77        return Err(Error::LabelTooLong);
78    }
79
80    Digest::update(digest, &label);
81    let p_hash = digest.finalize_reset();
82
83    encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| {
84        mgf1_xor(db, mgf_digest, seed);
85        mgf1_xor(seed, mgf_digest, db);
86    })
87}
88
89/// Encrypts the given message with RSA and the padding scheme from
90/// [PKCS#1 OAEP].
91///
92/// The message must be no longer than the length of the public modulus minus
93/// `2 + (2 * hash.size())`.
94///
95/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1
96#[inline]
97pub(crate) fn oaep_encrypt_digest<R, D, MGD>(
98    rng: &mut R,
99    msg: &[u8],
100    label: Option<Box<[u8]>>,
101    k: usize,
102) -> Result<Zeroizing<Vec<u8>>>
103where
104    R: TryCryptoRng + ?Sized,
105    D: Digest,
106    MGD: Digest + FixedOutputReset,
107{
108    let h_size = <D as Digest>::output_size();
109
110    let label = label.unwrap_or_default();
111    if label.len() as u64 >= MAX_LABEL_LEN {
112        return Err(Error::LabelTooLong);
113    }
114
115    let p_hash = D::digest(&label);
116
117    encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| {
118        let mut mgf_digest = MGD::new();
119        mgf1_xor_digest(db, &mut mgf_digest, seed);
120        mgf1_xor_digest(seed, &mut mgf_digest, db);
121    })
122}
123
124///Decrypts OAEP padding.
125///
126/// Note that whether this function returns an error or not discloses secret
127/// information. If an attacker can cause this function to run repeatedly and
128/// learn whether each instance returned an error then they can decrypt and
129/// forge signatures as if they had the private key.
130///
131/// See `decrypt_session_key` for a way of solving this problem.
132///
133/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1
134#[inline]
135pub(crate) fn oaep_decrypt<D, MGD>(
136    em: &mut [u8],
137    digest: &mut D,
138    mgf_digest: &mut MGD,
139    label: Option<Box<[u8]>>,
140    k: usize,
141) -> Result<Vec<u8>>
142where
143    D: Digest + FixedOutputReset,
144    MGD: Digest + FixedOutputReset,
145{
146    let h_size = <D as Digest>::output_size();
147
148    let label = label.unwrap_or_default();
149    if label.len() as u64 >= MAX_LABEL_LEN {
150        return Err(Error::Decryption);
151    }
152
153    Digest::update(digest, &label);
154
155    let expected_p_hash = digest.finalize_reset();
156
157    let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
158        mgf1_xor(seed, mgf_digest, db);
159        mgf1_xor(db, mgf_digest, seed);
160    })?;
161    if res.is_none().into() {
162        return Err(Error::Decryption);
163    }
164
165    let index = res.unwrap();
166
167    Ok(em[index as usize..].to_vec())
168}
169
170///Decrypts OAEP padding.
171///
172/// Note that whether this function returns an error or not discloses secret
173/// information. If an attacker can cause this function to run repeatedly and
174/// learn whether each instance returned an error then they can decrypt and
175/// forge signatures as if they had the private key.
176///
177/// See `decrypt_session_key` for a way of solving this problem.
178///
179/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1
180#[inline]
181pub(crate) fn oaep_decrypt_digest<D, MGD>(
182    em: &mut [u8],
183    label: Option<Box<[u8]>>,
184    k: usize,
185) -> Result<Vec<u8>>
186where
187    D: Digest,
188    MGD: Digest + FixedOutputReset,
189{
190    let h_size = <D as Digest>::output_size();
191
192    let label = label.unwrap_or_default();
193    if label.len() as u64 >= MAX_LABEL_LEN {
194        return Err(Error::LabelTooLong);
195    }
196
197    let expected_p_hash = D::digest(&label);
198
199    let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
200        let mut mgf_digest = MGD::new();
201        mgf1_xor_digest(seed, &mut mgf_digest, db);
202        mgf1_xor_digest(db, &mut mgf_digest, seed);
203    })?;
204    if res.is_none().into() {
205        return Err(Error::Decryption);
206    }
207
208    let index = res.unwrap();
209
210    Ok(em[index as usize..].to_vec())
211}
212
213/// Decrypts OAEP padding. It returns one or zero in valid that indicates whether the
214/// plaintext was correctly structured.
215#[inline]
216fn decrypt_inner<MGF: FnMut(&mut [u8], &mut [u8])>(
217    em: &mut [u8],
218    h_size: usize,
219    expected_p_hash: &[u8],
220    k: usize,
221    mut mgf: MGF,
222) -> Result<CtOption<u32>> {
223    if k < 11 {
224        return Err(Error::Decryption);
225    }
226
227    if k < h_size * 2 + 2 {
228        return Err(Error::Decryption);
229    }
230
231    let first_byte_is_zero = em[0].ct_eq(&0u8);
232
233    let (_, payload) = em.split_at_mut(1);
234    let (seed, db) = payload.split_at_mut(h_size);
235
236    mgf(seed, db);
237
238    let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash);
239
240    // The remainder of the plaintext must be zero or more 0x00, followed
241    // by 0x01, followed by the message.
242    //   looking_for_index: 1 if we are still looking for the 0x01
243    //   index: the offset of the first 0x01 byte
244    //   zero_before_one: 1 if we saw a non-zero byte before the 1
245    let mut looking_for_index = Choice::TRUE;
246    let mut index = 0u32;
247    let mut nonzero_before_one = Choice::FALSE;
248
249    for (i, el) in db.iter().skip(h_size).enumerate() {
250        let equals0 = el.ct_eq(&0u8);
251        let equals1 = el.ct_eq(&1u8);
252        index.ct_assign(&(i as u32), looking_for_index & equals1);
253        looking_for_index &= !equals1;
254        nonzero_before_one |= looking_for_index & !equals0;
255    }
256
257    let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index;
258
259    Ok(CtOption::new(index + 2 + (h_size * 2) as u32, valid))
260}