ml_dsa/
crypto.rs

1use hybrid_array::Array;
2use sha3::{
3    Shake128, Shake256,
4    digest::{ExtendableOutput, XofReader},
5};
6
7use crate::module_lattice::encode::ArraySize;
8
9pub enum ShakeState<Shake: ExtendableOutput> {
10    Absorbing(Shake),
11    Squeezing(Shake::Reader),
12}
13
14impl<Shake: ExtendableOutput + Default> Default for ShakeState<Shake> {
15    fn default() -> Self {
16        Self::Absorbing(Shake::default())
17    }
18}
19
20impl<Shake: ExtendableOutput + Default + Clone> ShakeState<Shake> {
21    pub fn absorb(mut self, input: &[u8]) -> Self {
22        match &mut self {
23            Self::Absorbing(sponge) => sponge.update(input),
24            Self::Squeezing(_) => unreachable!(),
25        }
26
27        self
28    }
29
30    pub fn squeeze(&mut self, output: &mut [u8]) -> &mut Self {
31        match self {
32            Self::Absorbing(sponge) => {
33                // Clone required to satisfy borrow checker
34                let mut reader = sponge.clone().finalize_xof();
35                reader.read(output);
36                *self = Self::Squeezing(reader);
37            }
38            Self::Squeezing(reader) => {
39                reader.read(output.as_mut());
40            }
41        }
42
43        self
44    }
45
46    pub fn squeeze_new<N: ArraySize>(&mut self) -> Array<u8, N> {
47        let mut v = Array::default();
48        self.squeeze(&mut v);
49        v
50    }
51}
52
53pub type G = ShakeState<Shake128>;
54pub type H = ShakeState<Shake256>;
55
#[cfg(test)]
mod test {
    use super::*;
    use crate::util::B32;
    use hex_literal::hex;

    /// Message absorbed by both known-answer tests.
    const MESSAGE: &[u8] = b"hello world";

    /// Known-answer test for `G`.
    ///
    /// Checks the first two 32-byte output blocks after absorbing `MESSAGE`.
    /// The first block is read via `squeeze` and the second via `squeeze_new`.
    #[test]
    fn g() {
        let want_first = hex!("3a9159f071e4dd1c8c4f968607c30942e120d8156b8b1e72e0d376e8871cb8b8");
        let want_second = hex!("99072665674f26cc494a4bcf027c58267e8ee2da60e942759de86d2670bba1aa");

        let mut block = [0u8; 32];
        let next: B32 = G::default()
            .absorb(MESSAGE)
            .squeeze(&mut block)
            .squeeze_new();

        assert_eq!(block, want_first);
        assert_eq!(next, want_second);
    }

    /// Known-answer test for `H`, with the same shape as `g`.
    #[test]
    fn h() {
        let want_first = hex!("369771bb2cb9d2b04c1d54cca487e372d9f187f73f7ba3f65b95c8ee7798c527");
        let want_second = hex!("f4f3c2d55c2d46a29f2e945d469c3df27853a8735271f5cc2d9e889544357116");

        let mut block = [0u8; 32];
        let next: B32 = H::default()
            .absorb(MESSAGE)
            .squeeze(&mut block)
            .squeeze_new();

        assert_eq!(block, want_first);
        assert_eq!(next, want_second);
    }
}