// ml_dsa/crypto.rs

1use hybrid_array::Array;
2use module_lattice::ArraySize;
3use shake::{
4    Shake128, Shake256,
5    digest::{ExtendableOutput, XofReader},
6};
7
8pub(crate) enum ShakeState<Shake: ExtendableOutput> {
9    Absorbing(Shake),
10    Squeezing(Shake::Reader),
11}
12
13impl<Shake: ExtendableOutput + Default> Default for ShakeState<Shake> {
14    fn default() -> Self {
15        Self::Absorbing(Shake::default())
16    }
17}
18
19impl<Shake: ExtendableOutput + Default + Clone> ShakeState<Shake> {
20    pub(crate) fn updatable(&mut self) -> &mut Shake {
21        match self {
22            Self::Absorbing(sponge) => sponge,
23            Self::Squeezing(_) => unreachable!(),
24        }
25    }
26
27    pub(crate) fn absorb(mut self, input: &[u8]) -> Self {
28        match &mut self {
29            Self::Absorbing(sponge) => sponge.update(input),
30            Self::Squeezing(_) => unreachable!(),
31        }
32
33        self
34    }
35
36    pub(crate) fn squeeze(&mut self, output: &mut [u8]) -> &mut Self {
37        match self {
38            Self::Absorbing(sponge) => {
39                // Clone required to satisfy borrow checker
40                let mut reader = sponge.clone().finalize_xof();
41                reader.read(output);
42                *self = Self::Squeezing(reader);
43            }
44            Self::Squeezing(reader) => {
45                reader.read(output.as_mut());
46            }
47        }
48
49        self
50    }
51
52    pub(crate) fn squeeze_new<N: ArraySize>(&mut self) -> Array<u8, N> {
53        let mut v = Array::default();
54        self.squeeze(&mut v);
55        v
56    }
57}
58
59pub(crate) type G = ShakeState<Shake128>;
60pub(crate) type H = ShakeState<Shake256>;
61
#[cfg(test)]
mod test {
    use super::*;
    use crate::B32;
    use hex_literal::hex;

    /// Message absorbed by every test in this module.
    const MESSAGE: &[u8] = b"hello world";

    /// Absorbs [`MESSAGE`] into a fresh `ShakeState<S>` and checks the first two
    /// 32-byte output blocks.
    ///
    /// The first block goes through `squeeze` (absorbing -> squeezing
    /// transition); the second goes through `squeeze_new` (already squeezing).
    fn check_first_two_blocks<S>(first: [u8; 32], second: [u8; 32])
    where
        S: ExtendableOutput + Default + Clone,
    {
        let mut state = ShakeState::<S>::default().absorb(MESSAGE);

        let mut block = [0u8; 32];
        state.squeeze(&mut block);
        assert_eq!(block, first);

        let block: B32 = state.squeeze_new();
        assert_eq!(block, second);
    }

    #[test]
    fn g() {
        check_first_two_blocks::<Shake128>(
            hex!("3a9159f071e4dd1c8c4f968607c30942e120d8156b8b1e72e0d376e8871cb8b8"),
            hex!("99072665674f26cc494a4bcf027c58267e8ee2da60e942759de86d2670bba1aa"),
        );
    }

    #[test]
    fn h() {
        check_first_two_blocks::<Shake256>(
            hex!("369771bb2cb9d2b04c1d54cca487e372d9f187f73f7ba3f65b95c8ee7798c527"),
            hex!("f4f3c2d55c2d46a29f2e945d469c3df27853a8735271f5cc2d9e889544357116"),
        );
    }
}