1use alloc::vec::Vec;
13use crypto_bigint::{Choice, CtEq, CtSelect};
14use digest::{Digest, FixedOutputReset};
15
16use super::mgf::{mgf1_xor, mgf1_xor_digest};
17use crate::errors::{Error, Result};
18
19pub(crate) fn emsa_pss_encode<D>(
20 m_hash: &[u8],
21 em_bits: usize,
22 salt: &[u8],
23 hash: &mut D,
24) -> Result<Vec<u8>>
25where
26 D: Digest + FixedOutputReset,
27{
28 let h_len = <D as Digest>::output_size();
30 let s_len = salt.len();
31 let em_len = em_bits.div_ceil(8);
32
33 if m_hash.len() != h_len {
39 return Err(Error::InputNotHashed);
40 }
41
42 if em_len < h_len + s_len + 2 {
44 return Err(Error::Internal);
46 }
47
48 let mut em = vec![0; em_len];
49
50 let (db, h) = em.split_at_mut(em_len - h_len - 1);
51 let h = &mut h[..(em_len - 1) - db.len()];
52
53 let prefix = [0u8; 8];
64
65 Digest::update(hash, prefix);
66 Digest::update(hash, m_hash);
67 Digest::update(hash, salt);
68
69 let hashed = hash.finalize_reset();
70 h.copy_from_slice(&hashed);
71
72 db[em_len - s_len - h_len - 2] = 0x01;
78 db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
79
80 mgf1_xor(db, hash, h);
84
85 db[0] &= 0xFF >> (8 * em_len - em_bits);
88
89 em[em_len - 1] = 0xBC;
91
92 Ok(em)
93}
94
95pub(crate) fn emsa_pss_encode_digest<D>(
96 m_hash: &[u8],
97 em_bits: usize,
98 salt: &[u8],
99) -> Result<Vec<u8>>
100where
101 D: Digest + FixedOutputReset,
102{
103 let h_len = <D as Digest>::output_size();
105 let s_len = salt.len();
106 let em_len = em_bits.div_ceil(8);
107
108 if m_hash.len() != h_len {
114 return Err(Error::InputNotHashed);
115 }
116
117 if em_len < h_len + s_len + 2 {
119 return Err(Error::Internal);
121 }
122
123 let mut em = vec![0; em_len];
124
125 let (db, h) = em.split_at_mut(em_len - h_len - 1);
126 let h = &mut h[..(em_len - 1) - db.len()];
127
128 let prefix = [0u8; 8];
139
140 let mut hash = D::new();
141
142 Digest::update(&mut hash, prefix);
143 Digest::update(&mut hash, m_hash);
144 Digest::update(&mut hash, salt);
145
146 let hashed = hash.finalize_reset();
147 h.copy_from_slice(&hashed);
148
149 db[em_len - s_len - h_len - 2] = 0x01;
155 db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
156
157 mgf1_xor_digest(db, &mut hash, h);
161
162 db[0] &= 0xFF >> (8 * em_len - em_bits);
165
166 em[em_len - 1] = 0xBC;
168
169 Ok(em)
170}
171
172fn emsa_pss_verify_pre<'a>(
173 m_hash: &[u8],
174 em: &'a mut [u8],
175 em_bits: usize,
176 s_len: Option<usize>,
177 h_len: usize,
178) -> Result<(&'a mut [u8], &'a mut [u8])> {
179 if m_hash.len() != h_len {
185 return Err(Error::Verification);
186 }
187
188 let em_len = em.len(); if let Some(s_len) = s_len {
190 if em_len < h_len + s_len + 2 {
192 return Err(Error::Verification);
193 }
194 }
195
196 if em[em.len() - 1] != 0xBC {
199 return Err(Error::Verification);
200 }
201
202 let (db, h) = em.split_at_mut(em_len - h_len - 1);
205 let h = &mut h[..h_len];
206
207 if db[0]
211 & (0xFF_u8
212 .checked_shl(8 - (8 * em_len - em_bits) as u32)
213 .unwrap_or(0))
214 != 0
215 {
216 return Err(Error::Verification);
217 }
218
219 Ok((db, h))
220}
221
222fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice {
223 let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
228 let valid: Choice = zeroes.iter().fold(Choice::TRUE, |a, e| a & e.ct_eq(&0x00));
229
230 valid & rest[0].ct_eq(&0x01)
231}
232
233fn emsa_pss_get_salt_len(db: &[u8], em_len: usize, h_len: usize) -> (usize, Choice) {
236 let em_len = em_len as u32;
237 let h_len = h_len as u32;
238 let max_scan_len = em_len - h_len - 2;
239
240 let mut separator_pos = 0u32;
241 let mut found_separator = Choice::FALSE;
242 let mut padding_valid = Choice::TRUE;
243
244 for i in 0..=max_scan_len {
246 let byte_val = db[i as usize];
247 let is_zero = byte_val.ct_eq(&0x00);
248 let is_separator = byte_val.ct_eq(&0x01);
249 let is_invalid = !(is_zero | is_separator);
250
251 let should_update_pos = is_separator & !found_separator;
253 separator_pos = u32::ct_select(&separator_pos, &i, should_update_pos);
254 found_separator = Choice::ct_select(&found_separator, &Choice::TRUE, should_update_pos);
255
256 let corrupts_padding = is_invalid & !found_separator;
258 padding_valid &= !corrupts_padding;
259 }
260
261 let salt_len = max_scan_len.wrapping_sub(separator_pos);
262 let final_valid = found_separator & padding_valid;
263
264 let result_len = u32::ct_select(&0u32, &salt_len, final_valid);
266
267 (result_len as usize, final_valid)
268}
269
270pub(crate) fn emsa_pss_verify<D>(
271 m_hash: &[u8],
272 em: &mut [u8],
273 s_len: Option<usize>,
274 hash: &mut D,
275 key_bits: usize,
276) -> Result<()>
277where
278 D: Digest + FixedOutputReset,
279{
280 let em_bits = key_bits - 1;
281 let em_len = em_bits.div_ceil(8);
282 let key_len = key_bits.div_ceil(8);
283 let h_len = <D as Digest>::output_size();
284
285 let em = &mut em[key_len - em_len..];
286
287 let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
288
289 mgf1_xor(db, hash, &*h);
293
294 db[0] &= 0xFF >> (8 * em_len - em_bits);
297
298 let (s_len, salt_valid) = match s_len {
299 Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)),
300 None => emsa_pss_get_salt_len(db, em_len, h_len),
301 };
302
303 let salt = &db[db.len() - s_len..];
305
306 let prefix = [0u8; 8];
313
314 Digest::update(hash, &prefix[..]);
315 Digest::update(hash, m_hash);
316 Digest::update(hash, salt);
317 let h0 = hash.finalize_reset();
318
319 if (salt_valid & h0.as_slice().ct_eq(h)).into() {
321 Ok(())
322 } else {
323 Err(Error::Verification)
324 }
325}
326
327pub(crate) fn emsa_pss_verify_digest<D>(
328 m_hash: &[u8],
329 em: &mut [u8],
330 s_len: Option<usize>,
331 key_bits: usize,
332) -> Result<()>
333where
334 D: Digest + FixedOutputReset,
335{
336 let em_bits = key_bits - 1;
337 let em_len = em_bits.div_ceil(8);
338 let key_len = key_bits.div_ceil(8);
339 let h_len = <D as Digest>::output_size();
340
341 let em = &mut em[key_len - em_len..];
342
343 let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
344
345 let mut hash = D::new();
346
347 mgf1_xor_digest::<D>(db, &mut hash, &*h);
351
352 db[0] &= 0xFF >> (8 * em_len - em_bits);
355
356 let (s_len, salt_valid) = match s_len {
357 Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)),
358 None => emsa_pss_get_salt_len(db, em_len, h_len),
359 };
360
361 let salt = &db[db.len() - s_len..];
363
364 let prefix = [0u8; 8];
371
372 Digest::update(&mut hash, &prefix[..]);
373 Digest::update(&mut hash, m_hash);
374 Digest::update(&mut hash, salt);
375 let h0 = hash.finalize_reset();
376
377 if (salt_valid & h0.as_slice().ct_eq(h)).into() {
379 Ok(())
380 } else {
381 Err(Error::Verification)
382 }
383}