Skip to main content

rsa/algorithms/
pss.rs

//! Support for the [Probabilistic Signature Scheme] (PSS) a.k.a. RSASSA-PSS.
//!
//! Designed by Mihir Bellare and Phillip Rogaway. Specified in [RFC8017 § 8.1].
//!
//! # Usage
//!
//! See [code example in the toplevel rustdoc](../index.html#pss-signatures).
//!
//! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme
//! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
11
12use alloc::vec::Vec;
13use crypto_bigint::{Choice, CtEq, CtSelect};
14use digest::{Digest, FixedOutputReset};
15
16use super::mgf::{mgf1_xor, mgf1_xor_digest};
17use crate::errors::{Error, Result};
18
19pub(crate) fn emsa_pss_encode<D>(
20    m_hash: &[u8],
21    em_bits: usize,
22    salt: &[u8],
23    hash: &mut D,
24) -> Result<Vec<u8>>
25where
26    D: Digest + FixedOutputReset,
27{
28    // See [1], section 9.1.1
29    let h_len = <D as Digest>::output_size();
30    let s_len = salt.len();
31    let em_len = em_bits.div_ceil(8);
32
33    // 1. If the length of M is greater than the input limitation for the
34    //     hash function (2^61 - 1 octets for SHA-1), output "message too
35    //     long" and stop.
36    //
37    // 2.  Let mHash = Hash(M), an octet string of length hLen.
38    if m_hash.len() != h_len {
39        return Err(Error::InputNotHashed);
40    }
41
42    // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop.
43    if em_len < h_len + s_len + 2 {
44        // TODO: Key size too small
45        return Err(Error::Internal);
46    }
47
48    let mut em = vec![0; em_len];
49
50    let (db, h) = em.split_at_mut(em_len - h_len - 1);
51    let h = &mut h[..(em_len - 1) - db.len()];
52
53    // 4. Generate a random octet string salt of length s_len; if s_len = 0,
54    //     then salt is the empty string.
55    //
56    // 5.  Let
57    //       M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt;
58    //
59    //     M' is an octet string of length 8 + h_len + s_len with eight
60    //     initial zero octets.
61    //
62    // 6.  Let H = Hash(M'), an octet string of length h_len.
63    let prefix = [0u8; 8];
64
65    Digest::update(hash, prefix);
66    Digest::update(hash, m_hash);
67    Digest::update(hash, salt);
68
69    let hashed = hash.finalize_reset();
70    h.copy_from_slice(&hashed);
71
72    // 7.  Generate an octet string PS consisting of em_len - s_len - h_len - 2
73    //     zero octets. The length of PS may be 0.
74    //
75    // 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
76    //     emLen - hLen - 1.
77    db[em_len - s_len - h_len - 2] = 0x01;
78    db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
79
80    // 9.  Let dbMask = MGF(H, emLen - hLen - 1).
81    //
82    // 10. Let maskedDB = DB \xor dbMask.
83    mgf1_xor(db, hash, h);
84
85    // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
86    //     maskedDB to zero.
87    db[0] &= 0xFF >> (8 * em_len - em_bits);
88
89    // 12. Let EM = maskedDB || H || 0xbc.
90    em[em_len - 1] = 0xBC;
91
92    Ok(em)
93}
94
95pub(crate) fn emsa_pss_encode_digest<D>(
96    m_hash: &[u8],
97    em_bits: usize,
98    salt: &[u8],
99) -> Result<Vec<u8>>
100where
101    D: Digest + FixedOutputReset,
102{
103    // See [1], section 9.1.1
104    let h_len = <D as Digest>::output_size();
105    let s_len = salt.len();
106    let em_len = em_bits.div_ceil(8);
107
108    // 1. If the length of M is greater than the input limitation for the
109    //     hash function (2^61 - 1 octets for SHA-1), output "message too
110    //     long" and stop.
111    //
112    // 2.  Let mHash = Hash(M), an octet string of length hLen.
113    if m_hash.len() != h_len {
114        return Err(Error::InputNotHashed);
115    }
116
117    // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop.
118    if em_len < h_len + s_len + 2 {
119        // TODO: Key size too small
120        return Err(Error::Internal);
121    }
122
123    let mut em = vec![0; em_len];
124
125    let (db, h) = em.split_at_mut(em_len - h_len - 1);
126    let h = &mut h[..(em_len - 1) - db.len()];
127
128    // 4. Generate a random octet string salt of length s_len; if s_len = 0,
129    //     then salt is the empty string.
130    //
131    // 5.  Let
132    //       M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt;
133    //
134    //     M' is an octet string of length 8 + h_len + s_len with eight
135    //     initial zero octets.
136    //
137    // 6.  Let H = Hash(M'), an octet string of length h_len.
138    let prefix = [0u8; 8];
139
140    let mut hash = D::new();
141
142    Digest::update(&mut hash, prefix);
143    Digest::update(&mut hash, m_hash);
144    Digest::update(&mut hash, salt);
145
146    let hashed = hash.finalize_reset();
147    h.copy_from_slice(&hashed);
148
149    // 7.  Generate an octet string PS consisting of em_len - s_len - h_len - 2
150    //     zero octets. The length of PS may be 0.
151    //
152    // 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
153    //     emLen - hLen - 1.
154    db[em_len - s_len - h_len - 2] = 0x01;
155    db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
156
157    // 9.  Let dbMask = MGF(H, emLen - hLen - 1).
158    //
159    // 10. Let maskedDB = DB \xor dbMask.
160    mgf1_xor_digest(db, &mut hash, h);
161
162    // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
163    //     maskedDB to zero.
164    db[0] &= 0xFF >> (8 * em_len - em_bits);
165
166    // 12. Let EM = maskedDB || H || 0xbc.
167    em[em_len - 1] = 0xBC;
168
169    Ok(em)
170}
171
172fn emsa_pss_verify_pre<'a>(
173    m_hash: &[u8],
174    em: &'a mut [u8],
175    em_bits: usize,
176    s_len: Option<usize>,
177    h_len: usize,
178) -> Result<(&'a mut [u8], &'a mut [u8])> {
179    // 1. If the length of M is greater than the input limitation for the
180    //    hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
181    //    and stop.
182    //
183    // 2. Let mHash = Hash(M), an octet string of length hLen
184    if m_hash.len() != h_len {
185        return Err(Error::Verification);
186    }
187
188    let em_len = em.len(); //(em_bits + 7) / 8;
189    if let Some(s_len) = s_len {
190        // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
191        if em_len < h_len + s_len + 2 {
192            return Err(Error::Verification);
193        }
194    }
195
196    // 4. If the rightmost octet of EM does not have hexadecimal value
197    //    0xbc, output "inconsistent" and stop.
198    if em[em.len() - 1] != 0xBC {
199        return Err(Error::Verification);
200    }
201
202    // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
203    //    let H be the next hLen octets.
204    let (db, h) = em.split_at_mut(em_len - h_len - 1);
205    let h = &mut h[..h_len];
206
207    // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in
208    //    maskedDB are not all equal to zero, output "inconsistent" and
209    //    stop.
210    if db[0]
211        & (0xFF_u8
212            .checked_shl(8 - (8 * em_len - em_bits) as u32)
213            .unwrap_or(0))
214        != 0
215    {
216        return Err(Error::Verification);
217    }
218
219    Ok((db, h))
220}
221
222fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice {
223    // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
224    //     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
225    //     position is "position 1") does not have hexadecimal value 0x01,
226    //     output "inconsistent" and stop.
227    let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
228    let valid: Choice = zeroes.iter().fold(Choice::TRUE, |a, e| a & e.ct_eq(&0x00));
229
230    valid & rest[0].ct_eq(&0x01)
231}
232
233/// Detect salt length by scanning DB for the 0x01 separator byte.
234/// Returns (s_len, valid) where s_len is 0 on failure.
235fn emsa_pss_get_salt_len(db: &[u8], em_len: usize, h_len: usize) -> (usize, Choice) {
236    let em_len = em_len as u32;
237    let h_len = h_len as u32;
238    let max_scan_len = em_len - h_len - 2;
239
240    let mut separator_pos = 0u32;
241    let mut found_separator = Choice::FALSE;
242    let mut padding_valid = Choice::TRUE;
243
244    // Single forward scan to find separator and validate padding
245    for i in 0..=max_scan_len {
246        let byte_val = db[i as usize];
247        let is_zero = byte_val.ct_eq(&0x00);
248        let is_separator = byte_val.ct_eq(&0x01);
249        let is_invalid = !(is_zero | is_separator);
250
251        // Update separator position if we found one and haven't found one before
252        let should_update_pos = is_separator & !found_separator;
253        separator_pos = u32::ct_select(&separator_pos, &i, should_update_pos);
254        found_separator = Choice::ct_select(&found_separator, &Choice::TRUE, should_update_pos);
255
256        // Padding is invalid if we see a non-zero, non-separator byte before finding separator
257        let corrupts_padding = is_invalid & !found_separator;
258        padding_valid &= !corrupts_padding;
259    }
260
261    let salt_len = max_scan_len.wrapping_sub(separator_pos);
262    let final_valid = found_separator & padding_valid;
263
264    // Return 0 length on failure
265    let result_len = u32::ct_select(&0u32, &salt_len, final_valid);
266
267    (result_len as usize, final_valid)
268}
269
270pub(crate) fn emsa_pss_verify<D>(
271    m_hash: &[u8],
272    em: &mut [u8],
273    s_len: Option<usize>,
274    hash: &mut D,
275    key_bits: usize,
276) -> Result<()>
277where
278    D: Digest + FixedOutputReset,
279{
280    let em_bits = key_bits - 1;
281    let em_len = em_bits.div_ceil(8);
282    let key_len = key_bits.div_ceil(8);
283    let h_len = <D as Digest>::output_size();
284
285    let em = &mut em[key_len - em_len..];
286
287    let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
288
289    // 7. Let dbMask = MGF(H, em_len - h_len - 1)
290    //
291    // 8. Let DB = maskedDB \xor dbMask
292    mgf1_xor(db, hash, &*h);
293
294    // 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
295    //     to zero.
296    db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);
297
298    let (s_len, salt_valid) = match s_len {
299        Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)),
300        None => emsa_pss_get_salt_len(db, em_len, h_len),
301    };
302
303    // 11. Let salt be the last s_len octets of DB.
304    let salt = &db[db.len() - s_len..];
305
306    // 12. Let
307    //          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
308    //     M' is an octet string of length 8 + hLen + sLen with eight
309    //     initial zero octets.
310    //
311    // 13. Let H' = Hash(M'), an octet string of length hLen.
312    let prefix = [0u8; 8];
313
314    Digest::update(hash, &prefix[..]);
315    Digest::update(hash, m_hash);
316    Digest::update(hash, salt);
317    let h0 = hash.finalize_reset();
318
319    // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
320    if (salt_valid & h0.as_slice().ct_eq(h)).into() {
321        Ok(())
322    } else {
323        Err(Error::Verification)
324    }
325}
326
327pub(crate) fn emsa_pss_verify_digest<D>(
328    m_hash: &[u8],
329    em: &mut [u8],
330    s_len: Option<usize>,
331    key_bits: usize,
332) -> Result<()>
333where
334    D: Digest + FixedOutputReset,
335{
336    let em_bits = key_bits - 1;
337    let em_len = em_bits.div_ceil(8);
338    let key_len = key_bits.div_ceil(8);
339    let h_len = <D as Digest>::output_size();
340
341    let em = &mut em[key_len - em_len..];
342
343    let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
344
345    let mut hash = D::new();
346
347    // 7. Let dbMask = MGF(H, em_len - h_len - 1)
348    //
349    // 8. Let DB = maskedDB \xor dbMask
350    mgf1_xor_digest::<D>(db, &mut hash, &*h);
351
352    // 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
353    //     to zero.
354    db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);
355
356    let (s_len, salt_valid) = match s_len {
357        Some(s_len) => (s_len, emsa_pss_verify_salt(db, em_len, s_len, h_len)),
358        None => emsa_pss_get_salt_len(db, em_len, h_len),
359    };
360
361    // 11. Let salt be the last s_len octets of DB.
362    let salt = &db[db.len() - s_len..];
363
364    // 12. Let
365    //          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
366    //     M' is an octet string of length 8 + hLen + sLen with eight
367    //     initial zero octets.
368    //
369    // 13. Let H' = Hash(M'), an octet string of length hLen.
370    let prefix = [0u8; 8];
371
372    Digest::update(&mut hash, &prefix[..]);
373    Digest::update(&mut hash, m_hash);
374    Digest::update(&mut hash, salt);
375    let h0 = hash.finalize_reset();
376
377    // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
378    if (salt_valid & h0.as_slice().ct_eq(h)).into() {
379        Ok(())
380    } else {
381        Err(Error::Verification)
382    }
383}