ml_kem/
pke.rs

1use hybrid_array::typenum::{Unsigned, U1};
2
3use crate::algebra::{NttMatrix, NttVector, Polynomial, PolynomialVector};
4use crate::compress::Compress;
5use crate::crypto::{G, PRF};
6use crate::encode::Encode;
7use crate::param::{EncodedCiphertext, EncodedDecryptionKey, EncodedEncryptionKey, PkeParams};
8use crate::util::B32;
9
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
13/// A `DecryptionKey` provides the ability to generate a new key pair, and decrypt an
14/// encrypted value.
15#[derive(Clone, Default, Debug, PartialEq)]
16pub struct DecryptionKey<P>
17where
18    P: PkeParams,
19{
20    s_hat: NttVector<P::K>,
21}
22
#[cfg(feature = "zeroize")]
impl<P> Zeroize for DecryptionKey<P>
where
    P: PkeParams,
{
    /// Overwrite the secret vector `s_hat` in place.
    fn zeroize(&mut self) {
        // Destructure exhaustively so that adding a field to the struct forces this impl to be
        // revisited.
        let Self { s_hat } = self;
        s_hat.zeroize();
    }
}
32
33impl<P> DecryptionKey<P>
34where
35    P: PkeParams,
36{
37    /// Generate a new random decryption key according to the `K-PKE.KeyGen` procedure.
38    // Algorithm 12. K-PKE.KeyGen()
39    pub fn generate(d: &B32) -> (Self, EncryptionKey<P>) {
40        // Generate random seeds
41        let k = P::K::U8;
42        let (rho, sigma) = G(&[&d[..], &[k]]);
43
44        // Sample pseudo-random matrix and vectors
45        let A_hat: NttMatrix<P::K> = NttMatrix::sample_uniform(&rho, false);
46        let s: PolynomialVector<P::K> = PolynomialVector::sample_cbd::<P::Eta1>(&sigma, 0);
47        let e: PolynomialVector<P::K> = PolynomialVector::sample_cbd::<P::Eta1>(&sigma, P::K::U8);
48
49        // NTT the vectors
50        let s_hat = s.ntt();
51        let e_hat = e.ntt();
52
53        // Compute the public value
54        let t_hat = &(&A_hat * &s_hat) + &e_hat;
55
56        // Assemble the keys
57        let dk = DecryptionKey { s_hat };
58        let ek = EncryptionKey { t_hat, rho };
59        (dk, ek)
60    }
61
62    /// Decrypt ciphertext to obtain the encrypted value, according to the K-PKE.Decrypt procedure.
63    // Algorithm 14. kK-PKE.Decrypt(dk_PKE, c)
64    pub fn decrypt(&self, ciphertext: &EncodedCiphertext<P>) -> B32 {
65        let (c1, c2) = P::split_ct(ciphertext);
66
67        let mut u: PolynomialVector<P::K> = Encode::<P::Du>::decode(c1);
68        u.decompress::<P::Du>();
69
70        let mut v: Polynomial = Encode::<P::Dv>::decode(c2);
71        v.decompress::<P::Dv>();
72
73        let u_hat = u.ntt();
74        let sTu = (&self.s_hat * &u_hat).ntt_inverse();
75        let mut w = &v - &sTu;
76        Encode::<U1>::encode(w.compress::<U1>())
77    }
78
79    /// Represent this decryption key as a byte array `(s_hat)`
80    pub fn as_bytes(&self) -> EncodedDecryptionKey<P> {
81        P::encode_u12(&self.s_hat)
82    }
83
84    /// Parse an decryption key from a byte array `(s_hat)`
85    pub fn from_bytes(enc: &EncodedDecryptionKey<P>) -> Self {
86        let s_hat = P::decode_u12(enc);
87        Self { s_hat }
88    }
89}
90
91/// An `EncryptionKey` provides the ability to encrypt a value so that it can only be
92/// decrypted by the holder of the corresponding decapsulation key.
93#[derive(Clone, Default, Debug, PartialEq)]
94pub struct EncryptionKey<P>
95where
96    P: PkeParams,
97{
98    t_hat: NttVector<P::K>,
99    rho: B32,
100}
101
102impl<P> EncryptionKey<P>
103where
104    P: PkeParams,
105{
106    /// Encrypt the specified message for the holder of the corresponding decryption key, using the
107    /// provided randomness, according the `K-PKE.Encrypt` procedure.
108    pub fn encrypt(&self, message: &B32, randomness: &B32) -> EncodedCiphertext<P> {
109        let r = PolynomialVector::<P::K>::sample_cbd::<P::Eta1>(randomness, 0);
110        let e1 = PolynomialVector::<P::K>::sample_cbd::<P::Eta2>(randomness, P::K::U8);
111
112        let prf_output = PRF::<P::Eta2>(randomness, 2 * P::K::U8);
113        let e2: Polynomial = Polynomial::sample_cbd::<P::Eta2>(&prf_output);
114
115        let A_hat_t = NttMatrix::<P::K>::sample_uniform(&self.rho, true);
116        let r_hat: NttVector<P::K> = r.ntt();
117        let ATr: PolynomialVector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
118        let mut u = ATr + e1;
119
120        let mut mu: Polynomial = Encode::<U1>::decode(message);
121        mu.decompress::<U1>();
122
123        let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse();
124        let mut v = &(&tTr + &e2) + &mu;
125
126        let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());
127        let c2 = Encode::<P::Dv>::encode(v.compress::<P::Dv>());
128        P::concat_ct(c1, c2)
129    }
130
131    /// Represent this encryption key as a byte array `(t_hat || rho)`
132    pub fn as_bytes(&self) -> EncodedEncryptionKey<P> {
133        let t_hat = P::encode_u12(&self.t_hat);
134        P::concat_ek(t_hat, self.rho.clone())
135    }
136
137    /// Parse an encryption key from a byte array `(t_hat || rho)`
138    pub fn from_bytes(enc: &EncodedEncryptionKey<P>) -> Self {
139        let (t_hat, rho) = P::split_ek(enc);
140        let t_hat = P::decode_u12(t_hat);
141        Self {
142            t_hat,
143            rho: rho.clone(),
144        }
145    }
146}
147
#[cfg(test)]
mod test {
    use super::*;
    use crate::crypto::rand;
    use crate::{MlKem1024Params, MlKem512Params, MlKem768Params};

    /// Encrypt under a fresh key pair and check that decryption recovers the message.
    ///
    /// Besides the all-zero case, a random message with random encryption randomness is used:
    /// an all-zero message decodes to the zero polynomial, so on its own it would still pass
    /// if the message term were dropped from `encrypt` or mis-decoded in `decrypt`.
    fn round_trip_test<P>()
    where
        P: PkeParams,
    {
        let mut rng = rand::thread_rng();
        let d: B32 = rand(&mut rng);
        let (dk, ek) = DecryptionKey::<P>::generate(&d);

        let random_message: B32 = rand(&mut rng);
        let random_randomness: B32 = rand(&mut rng);
        let cases = [
            (B32::default(), B32::default()),
            (random_message, random_randomness),
        ];

        for (original, randomness) in &cases {
            let encrypted = ek.encrypt(original, randomness);
            let decrypted = dk.decrypt(&encrypted);
            assert_eq!(original, &decrypted);
        }
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// Check that both key types survive a serialize/parse round trip unchanged.
    fn codec_test<P>()
    where
        P: PkeParams,
    {
        let mut rng = rand::thread_rng();
        let d: B32 = rand(&mut rng);
        let (dk_original, ek_original) = DecryptionKey::<P>::generate(&d);

        let dk_encoded = dk_original.as_bytes();
        let dk_decoded = DecryptionKey::from_bytes(&dk_encoded);
        assert_eq!(dk_original, dk_decoded);

        let ek_encoded = ek_original.as_bytes();
        let ek_decoded = EncryptionKey::from_bytes(&ek_encoded);
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn codec() {
        codec_test::<MlKem512Params>();
        codec_test::<MlKem768Params>();
        codec_test::<MlKem1024Params>();
    }
}