1use alloc::boxed::Box;
4use alloc::vec::Vec;
5
6use crypto_bigint::{Choice, CtAssign, CtEq, CtOption};
7use digest::{Digest, FixedOutputReset};
8use rand_core::TryCryptoRng;
9use zeroize::Zeroizing;
10
11use super::mgf::{mgf1_xor, mgf1_xor_digest};
12use crate::errors::{Error, Result};
13
/// Exclusive upper bound, in bytes, on the length of an OAEP label.
///
/// The OAEP encode/decode functions reject any label whose length is
/// `>= MAX_LABEL_LEN`, i.e. `2^61` bytes or more. RFC 8017 bounds the label by
/// the input limitation of the hash function (`2^61 - 1` octets for SHA-1).
const MAX_LABEL_LEN: u64 = 1u64 << 61;
20
21#[inline]
22fn encrypt_internal<R: TryCryptoRng + ?Sized, MGF: FnMut(&mut [u8], &mut [u8])>(
23 rng: &mut R,
24 msg: &[u8],
25 p_hash: &[u8],
26 h_size: usize,
27 k: usize,
28 mut mgf: MGF,
29) -> Result<Zeroizing<Vec<u8>>> {
30 if msg.len() + 2 * h_size + 2 > k {
31 return Err(Error::MessageTooLong);
32 }
33
34 let mut em = Zeroizing::new(vec![0u8; k]);
35
36 let (_, payload) = em.split_at_mut(1);
37 let (seed, db) = payload.split_at_mut(h_size);
38 rng.try_fill_bytes(seed).map_err(|_| Error::Rng)?;
39
40 let db_len = k - h_size - 1;
42
43 db[0..h_size].copy_from_slice(p_hash);
44 db[db_len - msg.len() - 1] = 1;
45 db[db_len - msg.len()..].copy_from_slice(msg);
46
47 mgf(seed, db);
48
49 Ok(em)
50}
51
52#[inline]
60pub(crate) fn oaep_encrypt<R, D, MGD>(
61 rng: &mut R,
62 msg: &[u8],
63 digest: &mut D,
64 mgf_digest: &mut MGD,
65 label: Option<Box<[u8]>>,
66 k: usize,
67) -> Result<Zeroizing<Vec<u8>>>
68where
69 R: TryCryptoRng + ?Sized,
70 D: Digest + FixedOutputReset,
71 MGD: Digest + FixedOutputReset,
72{
73 let h_size = <D as Digest>::output_size();
74
75 let label = label.unwrap_or_default();
76 if label.len() as u64 >= MAX_LABEL_LEN {
77 return Err(Error::LabelTooLong);
78 }
79
80 Digest::update(digest, &label);
81 let p_hash = digest.finalize_reset();
82
83 encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| {
84 mgf1_xor(db, mgf_digest, seed);
85 mgf1_xor(seed, mgf_digest, db);
86 })
87}
88
89#[inline]
97pub(crate) fn oaep_encrypt_digest<R, D, MGD>(
98 rng: &mut R,
99 msg: &[u8],
100 label: Option<Box<[u8]>>,
101 k: usize,
102) -> Result<Zeroizing<Vec<u8>>>
103where
104 R: TryCryptoRng + ?Sized,
105 D: Digest,
106 MGD: Digest + FixedOutputReset,
107{
108 let h_size = <D as Digest>::output_size();
109
110 let label = label.unwrap_or_default();
111 if label.len() as u64 >= MAX_LABEL_LEN {
112 return Err(Error::LabelTooLong);
113 }
114
115 let p_hash = D::digest(&label);
116
117 encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| {
118 let mut mgf_digest = MGD::new();
119 mgf1_xor_digest(db, &mut mgf_digest, seed);
120 mgf1_xor_digest(seed, &mut mgf_digest, db);
121 })
122}
123
124#[inline]
135pub(crate) fn oaep_decrypt<D, MGD>(
136 em: &mut [u8],
137 digest: &mut D,
138 mgf_digest: &mut MGD,
139 label: Option<Box<[u8]>>,
140 k: usize,
141) -> Result<Vec<u8>>
142where
143 D: Digest + FixedOutputReset,
144 MGD: Digest + FixedOutputReset,
145{
146 let h_size = <D as Digest>::output_size();
147
148 let label = label.unwrap_or_default();
149 if label.len() as u64 >= MAX_LABEL_LEN {
150 return Err(Error::Decryption);
151 }
152
153 Digest::update(digest, &label);
154
155 let expected_p_hash = digest.finalize_reset();
156
157 let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
158 mgf1_xor(seed, mgf_digest, db);
159 mgf1_xor(db, mgf_digest, seed);
160 })?;
161 if res.is_none().into() {
162 return Err(Error::Decryption);
163 }
164
165 let index = res.unwrap();
166
167 Ok(em[index as usize..].to_vec())
168}
169
170#[inline]
181pub(crate) fn oaep_decrypt_digest<D, MGD>(
182 em: &mut [u8],
183 label: Option<Box<[u8]>>,
184 k: usize,
185) -> Result<Vec<u8>>
186where
187 D: Digest,
188 MGD: Digest + FixedOutputReset,
189{
190 let h_size = <D as Digest>::output_size();
191
192 let label = label.unwrap_or_default();
193 if label.len() as u64 >= MAX_LABEL_LEN {
194 return Err(Error::LabelTooLong);
195 }
196
197 let expected_p_hash = D::digest(&label);
198
199 let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
200 let mut mgf_digest = MGD::new();
201 mgf1_xor_digest(seed, &mut mgf_digest, db);
202 mgf1_xor_digest(db, &mut mgf_digest, seed);
203 })?;
204 if res.is_none().into() {
205 return Err(Error::Decryption);
206 }
207
208 let index = res.unwrap();
209
210 Ok(em[index as usize..].to_vec())
211}
212
/// Unmasks an OAEP-encoded block in place and locates the start of the message.
///
/// `em` is laid out as `0x00 || maskedSeed || maskedDB`, with an `h_size`-byte
/// seed. `mgf(seed, db)` is called once to unmask both parts in place. After it
/// returns, `db` is expected to be `lHash || PS || 0x01 || M`, where `PS` is
/// zero or more 0x00 bytes.
///
/// The padding checks use `ct_eq`/`ct_assign` and are combined with bitwise
/// `Choice` operations rather than branches. The data-dependent outcome is
/// therefore exposed only through the returned `CtOption`.
///
/// Returns a `CtOption` holding the byte offset of `M` within `em`. It is
/// valid only if all of the following hold:
/// * the first byte of `em` is zero;
/// * the recovered label hash equals `expected_p_hash`;
/// * a 0x01 separator is present;
/// * only 0x00 bytes precede that separator.
///
/// # Errors
///
/// Returns [`Error::Decryption`] if `k < 11` or `k < 2 * h_size + 2`.
///
/// NOTE(review): `em.len()` is not checked against `k` here. Presumably callers
/// guarantee `em.len() == k`; a shorter `em` would panic on indexing/splitting.
/// Confirm at call sites.
#[inline]
fn decrypt_inner<MGF: FnMut(&mut [u8], &mut [u8])>(
    em: &mut [u8],
    h_size: usize,
    expected_p_hash: &[u8],
    k: usize,
    mut mgf: MGF,
) -> Result<CtOption<u32>> {
    // Public-length checks only; these branches depend on `k` and `h_size`,
    // not on the contents of `em`.
    if k < 11 {
        return Err(Error::Decryption);
    }

    if k < h_size * 2 + 2 {
        return Err(Error::Decryption);
    }

    let first_byte_is_zero = em[0].ct_eq(&0u8);

    // em = 0x00 || seed (h_size bytes) || db (the rest)
    let (_, payload) = em.split_at_mut(1);
    let (seed, db) = payload.split_at_mut(h_size);

    mgf(seed, db);

    let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash);

    // Scan the part of DB after the label hash:
    //   looking_for_index: true until the first 0x01 byte has been seen
    //   index: position of that first 0x01, relative to the end of lHash
    //   nonzero_before_one: set if any non-0x00 byte precedes the first 0x01
    let mut looking_for_index = Choice::TRUE;
    let mut index = 0u32;
    let mut nonzero_before_one = Choice::FALSE;

    for (i, el) in db.iter().skip(h_size).enumerate() {
        let equals0 = el.ct_eq(&0u8);
        let equals1 = el.ct_eq(&1u8);
        // Record the position only for the first 0x01 encountered.
        index.ct_assign(&(i as u32), looking_for_index & equals1);
        looking_for_index &= !equals1;
        // `looking_for_index` is already cleared on the 0x01 byte itself, so
        // the separator does not count as a non-zero padding byte.
        nonzero_before_one |= looking_for_index & !equals0;
    }

    let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index;

    // Message offset in `em`: 1 (leading 0x00) + h_size (seed) + h_size (lHash)
    // + index + 1 (the 0x01 separator) = index + 2 + 2 * h_size.
    Ok(CtOption::new(index + 2 + (h_size * 2) as u32, valid))
}