1use hybrid_array::Array;
2use sha3::{
3 Shake128, Shake256,
4 digest::{ExtendableOutput, XofReader},
5};
6
7use crate::module_lattice::encode::ArraySize;
8
9pub enum ShakeState<Shake: ExtendableOutput> {
10 Absorbing(Shake),
11 Squeezing(Shake::Reader),
12}
13
14impl<Shake: ExtendableOutput + Default> Default for ShakeState<Shake> {
15 fn default() -> Self {
16 Self::Absorbing(Shake::default())
17 }
18}
19
20impl<Shake: ExtendableOutput + Default + Clone> ShakeState<Shake> {
21 pub fn absorb(mut self, input: &[u8]) -> Self {
22 match &mut self {
23 Self::Absorbing(sponge) => sponge.update(input),
24 Self::Squeezing(_) => unreachable!(),
25 }
26
27 self
28 }
29
30 pub fn squeeze(&mut self, output: &mut [u8]) -> &mut Self {
31 match self {
32 Self::Absorbing(sponge) => {
33 let mut reader = sponge.clone().finalize_xof();
35 reader.read(output);
36 *self = Self::Squeezing(reader);
37 }
38 Self::Squeezing(reader) => {
39 reader.read(output.as_mut());
40 }
41 }
42
43 self
44 }
45
46 pub fn squeeze_new<N: ArraySize>(&mut self) -> Array<u8, N> {
47 let mut v = Array::default();
48 self.squeeze(&mut v);
49 v
50 }
51}
52
// NOTE(review): the `G` / `H` names presumably follow the hash-function naming of the
// specification this crate implements — confirm against the spec.

/// Sponge state backed by SHAKE128.
pub type G = ShakeState<Shake128>;
/// Sponge state backed by SHAKE256.
pub type H = ShakeState<Shake256>;
55
#[cfg(test)]
mod test {
    use super::*;
    use crate::util::B32;
    use hex_literal::hex;

    /// Absorbs `input` into `fresh`, then checks two consecutive 32-byte squeezes:
    /// the first into a caller-provided buffer via `squeeze`, the second via
    /// `squeeze_new`.
    fn check_two_squeezes<Shake>(
        fresh: ShakeState<Shake>,
        input: &[u8],
        expected_first: [u8; 32],
        expected_second: [u8; 32],
    ) where
        Shake: ExtendableOutput + Default + Clone,
    {
        let mut state = fresh.absorb(input);

        let mut buf = [0u8; 32];
        state.squeeze(&mut buf);
        assert_eq!(buf, expected_first);

        let next: B32 = state.squeeze_new();
        assert_eq!(next, expected_second);
    }

    #[test]
    fn g() {
        check_two_squeezes(
            G::default(),
            b"hello world",
            hex!("3a9159f071e4dd1c8c4f968607c30942e120d8156b8b1e72e0d376e8871cb8b8"),
            hex!("99072665674f26cc494a4bcf027c58267e8ee2da60e942759de86d2670bba1aa"),
        );
    }

    #[test]
    fn h() {
        check_two_squeezes(
            H::default(),
            b"hello world",
            hex!("369771bb2cb9d2b04c1d54cca487e372d9f187f73f7ba3f65b95c8ee7798c527"),
            hex!("f4f3c2d55c2d46a29f2e945d469c3df27853a8735271f5cc2d9e889544357116"),
        );
    }
}