Skip to main content

ml_kem/
crypto.rs

1#![allow(dead_code)]
2
3use crate::{B32, param::CbdSamplingSize};
4use module_lattice::EncodedPolynomial;
5use sha3::{
6    Digest, Sha3_256, Sha3_512, Shake128, Shake256,
7    digest::{ExtendableOutput, Update, XofReader},
8};
9
10pub(crate) fn G(inputs: &[impl AsRef<[u8]>]) -> (B32, B32) {
11    let mut h = Sha3_512::new();
12    for x in inputs {
13        Digest::update(&mut h, x);
14    }
15    let out = h.finalize();
16
17    let mut a = B32::default();
18    let mut b = B32::default();
19
20    a.copy_from_slice(&out[..32]);
21    b.copy_from_slice(&out[32..]);
22    (a, b)
23}
24
25pub(crate) fn H(x: impl AsRef<[u8]>) -> B32 {
26    let mut h = Sha3_256::new();
27    Digest::update(&mut h, x);
28
29    // This odd conversion is needed because the `sha3` crate links against an old version of
30    // the `generic-array` crate.  It should be pretty cheap though, since there's only one
31    // allocation / no copies.
32    let mut out = B32::default();
33    h.finalize_into(&mut out);
34    out
35}
36
37pub(crate) fn J(inputs: &[impl AsRef<[u8]>]) -> B32 {
38    let mut h = Shake256::default();
39    for x in inputs {
40        h.update(x.as_ref());
41    }
42    let mut r = h.finalize_xof();
43
44    let mut out = B32::default();
45    r.read(&mut out);
46    out
47}
48
49pub(crate) type PrfOutput<Eta> = EncodedPolynomial<<Eta as CbdSamplingSize>::SampleSize>;
50
51pub(crate) fn PRF<Eta>(s: &B32, b: u8) -> PrfOutput<Eta>
52where
53    Eta: CbdSamplingSize,
54{
55    let mut h = Shake256::default();
56    h.update(s.as_ref());
57    h.update(&[b]);
58    let mut r = h.finalize_xof();
59
60    let mut out = PrfOutput::<Eta>::default();
61    r.read(&mut out);
62    out
63}
64
65pub(crate) fn XOF(rho: &B32, i: u8, j: u8) -> impl XofReader {
66    let mut h = Shake128::default();
67    h.update(rho);
68    h.update(&[i, j]);
69    h.finalize_xof()
70}
71
72// // A Go script to generate the test vector outputs
73//
74// package main
75//
76// import (
77// 	"fmt"
78// 	"golang.org/x/crypto/sha3"
79// )
80//
81// func main() {
82// 	// G: B* -> B32 || B32 = SHA3_512(c)
83//   msgG := []byte("Input to an invocation of G")
84//   hG := sha3.New512()
85//   hG.Write(msgG)
86//   fmt.Printf("G: %x\n", hG.Sum(nil))
87//
88//   // H: B* -> B32 = SHA3_256(s)
89//   msgH := []byte("Input to an invocation of H")
90//   hH := sha3.New256()
91//   hH.Write(msgH)
92//   fmt.Printf("H: %x\n", hH.Sum(nil))
93//
94//   // J: B* -> B32 = SHAKE256(s, 32)
95//   msgJ := []byte("Input to an invocation of J")
96//   outJ := make([]byte, 32)
97//   sha3.ShakeSum256(outJ, msgJ)
98//   fmt.Printf("J: %x\n", outJ)
99//
100//   // PRF<2>: B32 x B -> B64eta = SHAKE256(s || b, 64 * eta)
101//   msgPRF2s := []byte("Input s to an invocation of PRF2")
102//   msgPRF2b := []byte("b")
103//   msgPRF2 := append(msgPRF2s, msgPRF2b...)
104//   outPRF2 := make([]byte, 64 * 2)
105//   sha3.ShakeSum256(outPRF2, msgPRF2)
106//   fmt.Printf("PRF<2>: %x\n", outPRF2)
107//
//   // PRF<3>: B32 x B -> B64eta = SHAKE256(s || b, 64 * eta)
109//   msgPRF3s := []byte("Input s to an invocation of PRF3")
110//   msgPRF3b := []byte("b")
111//   msgPRF3 := append(msgPRF3s, msgPRF3b...)
112//   outPRF3 := make([]byte, 64 * 3)
113//   sha3.ShakeSum256(outPRF3, msgPRF3)
114//   fmt.Printf("PRF<3>: %x\n", outPRF3)
115//
116//   // XOF: B32 x B x B -> B* = SHAKE128(rho || i || j)
117//   msgXOFrho := []byte("Input rho, to an XOF invocation!")
118//   msgXOFi := []byte("i")
119//   msgXOFj := []byte("j")
120//   msgXOF := append(append(msgXOFrho, msgXOFi...), msgXOFj...)
121//   outXOF := make([]byte, 32)
122//   sha3.ShakeSum128(outXOF, msgXOF)
123//   fmt.Printf("XOF: %x\n", outXOF)
124//
125// }
126
#[cfg(test)]
mod test {
    use super::*;
    use array::typenum::{U2, U3};
    use hex_literal::hex;

    /// Builds a `B32` from a test string that must be exactly 32 ASCII bytes long.
    fn b32(text: &str) -> B32 {
        B32::try_from(text.as_bytes()).expect("Failed to create B32 from slice")
    }

    #[test]
    fn g() {
        let (first, second) = G(&["Input to ".as_bytes(), "an invocation of G".as_bytes()]);
        assert_eq!(
            first,
            hex!("07dfced2a3a3feb3277cee1709818828ea6d2f42800152e9c312e848122231c2")
        );
        assert_eq!(
            second,
            hex!("272969098a1bbd5a0a9844e2f89f206d8f7f4599e36aecaa4793af400fd880d8")
        );
    }

    #[test]
    fn h() {
        let digest = H("Input to an invocation of H".as_bytes());
        assert_eq!(
            digest,
            hex!("0ee3ce94213d7dd0069b24b8b15cdd0bcf8eb1c6b3c21c441dc6a19e979cc7eb")
        );
    }

    #[test]
    fn j() {
        let digest = J(&["Input to ".as_bytes(), "an invocation of J".as_bytes()]);
        assert_eq!(
            digest,
            hex!("a5292293d70c8eca049cbb475c48fabd625ed2b20785a18248504d3741196b52")
        );
    }

    #[test]
    fn prf() {
        let eta2_expected = hex!(
            "54c002415c2219b564d5c17b0df0c82f83ddf3fdecc7d814ed5d85457c06c2c3\
             ed0b0584f926dffb1e57c6105f8604e81c4605b93f8284e44585104101042075\
             568113c861516d91bed227638654fc7f872df205c113b8364091755b62284eec\
             a6124f2cd4c1cdf598cb8324a4f373470a8f81ee618c75cc33f66facee01c213"
        );
        let eta2_seed = b32("Input s to an invocation of PRF2");
        assert_eq!(PRF::<U2>(&eta2_seed, b'b'), eta2_expected);

        let eta3_expected = hex!(
            "5e12028f67479b862a12713cda833e21b8ccd51bff9ddc2bfb9ab2910a9dc2e6\
             c58264a3f51ccc9ef4ff936a15505e016f60c36ffe300be01b9fb12eacd57867\
             0873c24709d6146b42c42a07873522eac100d61942ae53e73fbf9095b29b1ab7\
             169e954213c062703dad88c1c5f57f92af143f0364fe057b134b54ea8a55d94c\
             67764b3fc6b37376453978b8f0caeb6b18c188c28ee8681e28339477e042d5a1\
             b4a12deb1de8b9dad026b4e323e03973ffbe25dd511eed5460d22a9851cfc220"
        );
        let eta3_seed = b32("Input s to an invocation of PRF3");
        assert_eq!(PRF::<U3>(&eta3_seed, b'b'), eta3_expected);
    }

    #[test]
    fn xof() {
        // Keep `rho` in a named binding: the returned reader may borrow it.
        let rho = b32("Input rho, to an XOF invocation!");
        let mut reader = XOF(&rho, b'i', b'j');

        let mut squeezed = [0u8; 32];
        reader.read(&mut squeezed);

        assert_eq!(
            squeezed,
            hex!("0d2c3e65f754d074cb366cf1b099ae105cc40f018342509f15f1ba8a1a4144cb")
        );
    }
}