// Source: rsa/algorithms/mgf.rs

//! Mask generation function (MGF1) common to both PSS and OAEP padding.

3use digest::{Digest, FixedOutputReset};
4
5/// Mask generation function.
6///
7/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1
8pub(crate) fn mgf1_xor<D>(out: &mut [u8], digest: &mut D, seed: &[u8])
9where
10    D: Digest + FixedOutputReset,
11{
12    let mut counter = [0u8; 4];
13    let mut i = 0;
14
15    const MAX_LEN: u64 = u32::MAX as u64 + 1;
16    assert!(out.len() as u64 <= MAX_LEN);
17
18    while i < out.len() {
19        let mut digest_input = vec![0u8; seed.len() + 4];
20        digest_input[0..seed.len()].copy_from_slice(seed);
21        digest_input[seed.len()..].copy_from_slice(&counter);
22
23        Digest::update(digest, digest_input.as_slice());
24        let digest_output = &*digest.finalize_reset();
25        let mut j = 0;
26        loop {
27            if j >= digest_output.len() || i >= out.len() {
28                break;
29            }
30
31            out[i] ^= digest_output[j];
32            j += 1;
33            i += 1;
34        }
35        inc_counter(&mut counter);
36    }
37}
38
39/// Mask generation function.
40///
41/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1
42pub(crate) fn mgf1_xor_digest<D>(out: &mut [u8], digest: &mut D, seed: &[u8])
43where
44    D: Digest + FixedOutputReset,
45{
46    let mut counter = [0u8; 4];
47    let mut i = 0;
48
49    const MAX_LEN: u64 = u32::MAX as u64 + 1;
50    assert!(out.len() as u64 <= MAX_LEN);
51
52    while i < out.len() {
53        Digest::update(digest, seed);
54        Digest::update(digest, counter);
55
56        let digest_output = digest.finalize_reset();
57        let mut j = 0;
58        loop {
59            if j >= digest_output.len() || i >= out.len() {
60                break;
61            }
62
63            out[i] ^= digest_output[j];
64            j += 1;
65            i += 1;
66        }
67        inc_counter(&mut counter);
68    }
69}
/// Adds one to `counter`, interpreted as a big-endian 32-bit integer.
///
/// Wraps around: `[0xFF, 0xFF, 0xFF, 0xFF]` becomes `[0, 0, 0, 0]`.
fn inc_counter(counter: &mut [u8; 4]) {
    let next = u32::from_be_bytes(*counter).wrapping_add(1);
    *counter = next.to_be_bytes();
}