ml_kem/
crypto.rs

1#![allow(dead_code)]
2
3use hybrid_array::{Array, ArraySize};
4use rand_core::CryptoRngCore;
5use sha3::{
6    digest::{ExtendableOutput, Update, XofReader},
7    Digest, Sha3_256, Sha3_512, Shake128, Shake256,
8};
9
10use crate::param::{CbdSamplingSize, EncodedPolynomial};
11use crate::util::B32;
12
13pub fn rand<L: ArraySize>(rng: &mut impl CryptoRngCore) -> Array<u8, L> {
14    let mut val = Array::default();
15    rng.fill_bytes(&mut val);
16    val
17}
18
19pub fn G(inputs: &[impl AsRef<[u8]>]) -> (B32, B32) {
20    let mut h = Sha3_512::new();
21    for x in inputs {
22        Digest::update(&mut h, x);
23    }
24    let out = h.finalize();
25
26    let mut a = B32::default();
27    let mut b = B32::default();
28
29    a.copy_from_slice(&out[..32]);
30    b.copy_from_slice(&out[32..]);
31    (a, b)
32}
33
34pub fn H(x: impl AsRef<[u8]>) -> B32 {
35    let mut h = Sha3_256::new();
36    Digest::update(&mut h, x);
37
38    // This odd conversion is needed because the `sha3` crate links against an old version of
39    // the `generic-array` crate.  It should be pretty cheap though, since there's only one
40    // allocation / no copies.
41    let mut out = B32::default();
42    h.finalize_into(out.as_mut_slice().into());
43    out
44}
45
46pub fn J(inputs: &[impl AsRef<[u8]>]) -> B32 {
47    let mut h = Shake256::default();
48    for x in inputs {
49        h.update(x.as_ref());
50    }
51    let mut r = h.finalize_xof();
52
53    let mut out = B32::default();
54    r.read(&mut out);
55    out
56}
57
58pub type PrfOutput<Eta> = EncodedPolynomial<<Eta as CbdSamplingSize>::SampleSize>;
59
60pub fn PRF<Eta>(s: &B32, b: u8) -> PrfOutput<Eta>
61where
62    Eta: CbdSamplingSize,
63{
64    let mut h = Shake256::default();
65    h.update(s.as_ref());
66    h.update(&[b]);
67    let mut r = h.finalize_xof();
68
69    let mut out = PrfOutput::<Eta>::default();
70    r.read(&mut out);
71    out
72}
73
74pub fn XOF(rho: &B32, i: u8, j: u8) -> impl XofReader {
75    let mut h = Shake128::default();
76    h.update(rho);
77    h.update(&[i, j]);
78    h.finalize_xof()
79}
80
// A Go script used to generate the expected outputs in the tests below.
//
// Some Rust tests split a message across several inputs. G and J absorb their inputs in order,
// so the expected output is the one for the concatenated message hashed here.
//
// package main
//
// import (
// 	"fmt"
// 	"golang.org/x/crypto/sha3"
// )
//
// func main() {
// 	// G: B* -> B32 || B32 = SHA3_512(c)
// 	msgG := []byte("Input to an invocation of G")
// 	hG := sha3.New512()
// 	hG.Write(msgG)
// 	fmt.Printf("G: %x\n", hG.Sum(nil))
//
// 	// H: B* -> B32 = SHA3_256(s)
// 	msgH := []byte("Input to an invocation of H")
// 	hH := sha3.New256()
// 	hH.Write(msgH)
// 	fmt.Printf("H: %x\n", hH.Sum(nil))
//
// 	// J: B* -> B32 = SHAKE256(s, 32)
// 	msgJ := []byte("Input to an invocation of J")
// 	outJ := make([]byte, 32)
// 	sha3.ShakeSum256(outJ, msgJ)
// 	fmt.Printf("J: %x\n", outJ)
//
// 	// PRF<2>: B32 x B -> B64eta = SHAKE256(s || b, 64 * eta)
// 	msgPRF2s := []byte("Input s to an invocation of PRF2")
// 	msgPRF2b := []byte("b")
// 	msgPRF2 := append(msgPRF2s, msgPRF2b...)
// 	outPRF2 := make([]byte, 64 * 2)
// 	sha3.ShakeSum256(outPRF2, msgPRF2)
// 	fmt.Printf("PRF<2>: %x\n", outPRF2)
//
// 	// PRF<3>: B32 x B -> B64eta = SHAKE256(s || b, 64 * eta)
// 	msgPRF3s := []byte("Input s to an invocation of PRF3")
// 	msgPRF3b := []byte("b")
// 	msgPRF3 := append(msgPRF3s, msgPRF3b...)
// 	outPRF3 := make([]byte, 64 * 3)
// 	sha3.ShakeSum256(outPRF3, msgPRF3)
// 	fmt.Printf("PRF<3>: %x\n", outPRF3)
//
// 	// XOF: B32 x B x B -> B* = SHAKE128(rho || i || j)
// 	msgXOFrho := []byte("Input rho, to an XOF invocation!")
// 	msgXOFi := []byte("i")
// 	msgXOFj := []byte("j")
// 	msgXOF := append(append(msgXOFrho, msgXOFi...), msgXOFj...)
// 	outXOF := make([]byte, 32)
// 	sha3.ShakeSum128(outXOF, msgXOF)
// 	fmt.Printf("XOF: %x\n", outXOF)
// }
135
#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use hybrid_array::typenum::{U2, U3};

    // Expected values come from the Go script above (golang.org/x/crypto/sha3).

    #[test]
    fn g() {
        let msg1 = "Input to ".as_bytes();
        let msg2 = "an invocation of G".as_bytes();
        let (actual_a, actual_b) = G(&[msg1, msg2]);
        let expected_a = hex!("07dfced2a3a3feb3277cee1709818828ea6d2f42800152e9c312e848122231c2");
        let expected_b = hex!("272969098a1bbd5a0a9844e2f89f206d8f7f4599e36aecaa4793af400fd880d8");
        assert_eq!(actual_a, expected_a);
        assert_eq!(actual_b, expected_b);
    }

    #[test]
    fn h() {
        let msg = "Input to an invocation of H".as_bytes();
        let actual = H(msg);
        let expected = hex!("0ee3ce94213d7dd0069b24b8b15cdd0bcf8eb1c6b3c21c441dc6a19e979cc7eb");
        assert_eq!(actual, expected);
    }

    #[test]
    fn j() {
        let msg1 = "Input to ".as_bytes();
        let msg2 = "an invocation of J".as_bytes();
        let actual = J(&[msg1, msg2]);
        let expected = hex!("a5292293d70c8eca049cbb475c48fabd625ed2b20785a18248504d3741196b52");
        assert_eq!(actual, expected);
    }

    /// Splitting the input across several slices must be equivalent to hashing their
    /// concatenation, which is what the Go reference script does.
    #[test]
    fn multi_input_is_concatenation() {
        let whole = "Input to an invocation of G".as_bytes();
        assert_eq!(G(&["Input to ".as_bytes(), "an invocation of G".as_bytes()]), G(&[whole]));

        let whole = "Input to an invocation of J".as_bytes();
        assert_eq!(J(&["Input to ".as_bytes(), "an invocation of J".as_bytes()]), J(&[whole]));
    }

    #[test]
    fn prf() {
        let s = B32::try_from("Input s to an invocation of PRF2".as_bytes())
            .expect("Failed to create B32 from slice");
        let b = b'b';
        let actual = PRF::<U2>(&s, b);
        let expected = hex!(
            "54c002415c2219b564d5c17b0df0c82f83ddf3fdecc7d814ed5d85457c06c2c3\
             ed0b0584f926dffb1e57c6105f8604e81c4605b93f8284e44585104101042075\
             568113c861516d91bed227638654fc7f872df205c113b8364091755b62284eec\
             a6124f2cd4c1cdf598cb8324a4f373470a8f81ee618c75cc33f66facee01c213"
        );
        assert_eq!(actual, expected);

        let s = B32::try_from("Input s to an invocation of PRF3".as_bytes())
            .expect("Failed to create B32 from slice");
        let b = b'b';
        let actual = PRF::<U3>(&s, b);
        let expected = hex!(
            "5e12028f67479b862a12713cda833e21b8ccd51bff9ddc2bfb9ab2910a9dc2e6\
             c58264a3f51ccc9ef4ff936a15505e016f60c36ffe300be01b9fb12eacd57867\
             0873c24709d6146b42c42a07873522eac100d61942ae53e73fbf9095b29b1ab7\
             169e954213c062703dad88c1c5f57f92af143f0364fe057b134b54ea8a55d94c\
             67764b3fc6b37376453978b8f0caeb6b18c188c28ee8681e28339477e042d5a1\
             b4a12deb1de8b9dad026b4e323e03973ffbe25dd511eed5460d22a9851cfc220"
        );
        assert_eq!(actual, expected);
    }

    #[test]
    fn xof() {
        let rho = B32::try_from("Input rho, to an XOF invocation!".as_bytes())
            .expect("Failed to create B32 from slice");
        let i = b'i';
        let j = b'j';

        let mut reader = XOF(&rho, i, j);
        let mut actual = [0u8; 32];
        reader.read(&mut actual);

        let expected = hex!("0d2c3e65f754d074cb366cf1b099ae105cc40f018342509f15f1ba8a1a4144cb");
        assert_eq!(actual, expected);
    }

    /// Squeezing the XOF in several reads must yield the same byte stream as one large read.
    #[test]
    fn xof_streams_consistently() {
        let rho = B32::try_from("Input rho, to an XOF invocation!".as_bytes())
            .expect("Failed to create B32 from slice");

        let mut whole = [0u8; 32];
        XOF(&rho, b'i', b'j').read(&mut whole);

        let mut reader = XOF(&rho, b'i', b'j');
        let mut halves = [0u8; 32];
        let (first, second) = halves.split_at_mut(16);
        reader.read(first);
        reader.read(second);

        assert_eq!(whole, halves);
    }
}