// ml_kem/pke.rs

1use crate::B32;
2use crate::algebra::{
3    Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, matrix_sample_ntt, sample_poly_cbd,
4    sample_poly_vec_cbd,
5};
6use crate::compress::Compress;
7use crate::crypto::{G, PRF};
8use crate::param::{EncodedDecryptionKey, EncodedEncryptionKey, PkeParams};
9use array::typenum::{U1, Unsigned};
10use kem::{Ciphertext, InvalidKey};
11use module_lattice::{
12    Encode,
13    ctutils::{Choice, CtEq},
14};
15
16#[cfg(feature = "zeroize")]
17use zeroize::Zeroize;
18
19/// A `DecryptionKey` provides the ability to generate a new key pair, and decrypt an
20/// encrypted value.
21#[derive(Clone, Default, Debug)]
22pub(crate) struct DecryptionKey<P>
23where
24    P: PkeParams,
25{
26    s_hat: NttVector<P::K>,
27}
28
29impl<P> CtEq for DecryptionKey<P>
30where
31    P: PkeParams,
32{
33    fn ct_eq(&self, other: &Self) -> Choice {
34        self.s_hat.ct_eq(&other.s_hat)
35    }
36}
37
38impl<P> Eq for DecryptionKey<P> where P: PkeParams {}
39impl<P> PartialEq for DecryptionKey<P>
40where
41    P: PkeParams,
42{
43    fn eq(&self, other: &Self) -> bool {
44        // Compare decryption keys in constant-time
45        self.ct_eq(other).into()
46    }
47}
48
#[cfg(feature = "zeroize")]
impl<P: PkeParams> Zeroize for DecryptionKey<P> {
    /// Zeroize the secret vector held by this key.
    fn zeroize(&mut self) {
        // Destructure exhaustively so that adding a field to the struct forces this impl to
        // be revisited.
        let Self { s_hat } = self;
        s_hat.zeroize();
    }
}
58
59impl<P> DecryptionKey<P>
60where
61    P: PkeParams,
62{
63    /// Generate a new random decryption key according to the `K-PKE.KeyGen` procedure.
64    // Algorithm 12. K-PKE.KeyGen()
65    pub(crate) fn generate(d: &B32) -> (Self, EncryptionKey<P>) {
66        // Generate random seeds
67        let k = P::K::U8;
68        let (rho, sigma) = G(&[&d[..], &[k]]);
69
70        // Sample pseudo-random matrix and vectors
71        let A_hat: NttMatrix<P::K> = matrix_sample_ntt(&rho, false);
72        let s: Vector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, 0);
73        let e: Vector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, P::K::U8);
74
75        // NTT the vectors
76        let s_hat = s.ntt();
77        let e_hat = e.ntt();
78
79        // Compute the public value
80        let t_hat = &(&A_hat * &s_hat) + &e_hat;
81
82        // Assemble the keys
83        let dk = DecryptionKey { s_hat };
84        let ek = EncryptionKey { t_hat, rho };
85        (dk, ek)
86    }
87
88    /// Decrypt ciphertext to obtain the encrypted value, according to the K-PKE.Decrypt procedure.
89    // Algorithm 14. kK-PKE.Decrypt(dk_PKE, c)
90    pub(crate) fn decrypt(&self, ciphertext: &Ciphertext<P>) -> B32 {
91        let (c1, c2) = P::split_ct(ciphertext);
92
93        let mut u: Vector<P::K> = Encode::<P::Du>::decode(c1);
94        u.decompress::<P::Du>();
95
96        let mut v: Polynomial = Encode::<P::Dv>::decode(c2);
97        v.decompress::<P::Dv>();
98
99        let u_hat = u.ntt();
100        let sTu = (&self.s_hat * &u_hat).ntt_inverse();
101        let mut w = &v - &sTu;
102        Encode::<U1>::encode(w.compress::<U1>())
103    }
104
105    /// Represent this decryption key as a byte array `(s_hat)`
106    pub(crate) fn to_bytes(&self) -> EncodedDecryptionKey<P> {
107        P::encode_u12(&self.s_hat)
108    }
109
110    /// Parse an decryption key from a byte array `(s_hat)`
111    pub(crate) fn from_bytes(enc: &EncodedDecryptionKey<P>) -> Self {
112        let s_hat = P::decode_u12(enc);
113        Self { s_hat }
114    }
115}
116
117/// An `EncryptionKey` provides the ability to encrypt a value so that it can only be
118/// decrypted by the holder of the corresponding decapsulation key.
119#[derive(Clone, Default, Debug, Eq, PartialEq)]
120pub(crate) struct EncryptionKey<P>
121where
122    P: PkeParams,
123{
124    t_hat: NttVector<P::K>,
125    rho: B32,
126}
127
128impl<P> EncryptionKey<P>
129where
130    P: PkeParams,
131{
132    /// Encrypt the specified message for the holder of the corresponding decryption key, using the
133    /// provided randomness, according the `K-PKE.Encrypt` procedure.
134    pub(crate) fn encrypt(&self, message: &B32, randomness: &B32) -> Ciphertext<P> {
135        let r = sample_poly_vec_cbd::<P::Eta1, P::K>(randomness, 0);
136        let e1 = sample_poly_vec_cbd::<P::Eta2, P::K>(randomness, P::K::U8);
137
138        let prf_output = PRF::<P::Eta2>(randomness, 2 * P::K::U8);
139        let e2: Polynomial = sample_poly_cbd::<P::Eta2>(&prf_output);
140
141        let A_hat_t: NttMatrix<P::K> = matrix_sample_ntt(&self.rho, true);
142        let r_hat: NttVector<P::K> = r.ntt();
143        let ATr: Vector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
144        let mut u = ATr + e1;
145
146        let mut mu: Polynomial = Encode::<U1>::decode(message);
147        mu.decompress::<U1>();
148
149        let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse();
150        let mut v = &(&tTr + &e2) + &mu;
151
152        let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());
153        let c2 = Encode::<P::Dv>::encode(v.compress::<P::Dv>());
154        P::concat_ct(c1, c2)
155    }
156
157    /// Represent this encryption key as a byte array `(t_hat || rho)`
158    pub(crate) fn to_bytes(&self) -> EncodedEncryptionKey<P> {
159        let t_hat = P::encode_u12(&self.t_hat);
160        P::concat_ek(t_hat, self.rho.clone())
161    }
162
163    /// Parse an encryption key from a byte array `(t_hat || rho)`.
164    ///
165    /// # Errors
166    /// Returns [`InvalidKey`] in the event that the key fails the encapsulation key checks
167    /// specified in FIPS 203 §7.2.
168    pub(crate) fn from_bytes(enc: &EncodedEncryptionKey<P>) -> Result<Self, InvalidKey> {
169        let (t_hat, rho) = P::split_ek(enc);
170        let t_hat = P::decode_u12(t_hat);
171        let ret = Self {
172            t_hat,
173            rho: rho.clone(),
174        };
175
176        // Check the candidate encapsulation key is valid using the method specified in FIPS 203
177        // §7.2 ML-KEM Encapsulation:
178        //
179        // > Encapsulation key check. To check a candidate encapsulation key `ek`, perform the
180        // > following:
181        // >
182        // > 1. (Type check) If `ek` is not an array of bytes of length 384𝑘+32 for the value of 𝑘
183        // >    specified by the relevant parameter set, then input checking failed.
184        // > 2. (Modulus check) Perform the computation:
185        // >
186        // >    test ← ByteEncode₁₂(ByteDecode₁₂(ek[0:384𝑘]))
187        // >
188        // >    (see Section 4.2.1). If `test ≠ ek[0∶384𝑘]`, then input checking failed. This
189        // >    check ensures that the integers encoded in the public key are in the valid range
190        // >    `[0,q-1]`.
191        // >
192        // > If both checks pass, then `ML-KEM.Encaps` can be run with input `ek`. It is important
193        // > to note that this checking process does not guarantee that ek is a properly produced
194        // > output of `ML-KEM.KeyGen`.
195        // >
196        // > `ML-KEM.Encaps` shall not be run with an encapsulation key that has not been checked as
197        // > above.
198        //
199        // #1 is performed by the `EncodedEncryptionKey` type, and the following check vicariously
200        // performs #2 by encoding the integer-mod-q array using our implementation of ByteEncode₁₂
201        // and comparing the resulting serialization to see if it round-trips.
202        if &ret.to_bytes() == enc {
203            Ok(ret)
204        } else {
205            Err(InvalidKey)
206        }
207    }
208}
209
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512, MlKem768, MlKem1024};
    use ::kem::Generate;
    use getrandom::{SysRng, rand_core::UnwrapErr};

    /// Generate a fresh key pair, encrypt `message` with `randomness`, decrypt the result and
    /// check that the original message comes back.
    fn assert_round_trip<P>(message: &B32, randomness: &B32)
    where
        P: PkeParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let d = B32::generate_from_rng(&mut rng);

        let (dk, ek) = DecryptionKey::<P>::generate(&d);
        let encrypted = ek.encrypt(message, randomness);
        let decrypted = dk.decrypt(&encrypted);
        assert_eq!(*message, decrypted);
    }

    fn round_trip_test<P>()
    where
        P: PkeParams,
    {
        // Degenerate case: all-zero message and all-zero encryption randomness.
        assert_round_trip::<P>(&B32::default(), &B32::default());

        // The all-zero case alone would also pass for a `decrypt` that ignores its input and
        // returns zeros, so additionally round-trip a random message with random randomness.
        let mut rng = UnwrapErr(SysRng);
        let message = B32::generate_from_rng(&mut rng);
        let randomness = B32::generate_from_rng(&mut rng);
        assert_round_trip::<P>(&message, &randomness);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512>();
        round_trip_test::<MlKem768>();
        round_trip_test::<MlKem1024>();
    }

    /// Both key types must survive a `to_bytes` / `from_bytes` round trip unchanged.
    fn codec_test<P>()
    where
        P: PkeParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let d = B32::generate_from_rng(&mut rng);
        let (dk_original, ek_original) = DecryptionKey::<P>::generate(&d);

        let dk_encoded = dk_original.to_bytes();
        let dk_decoded = DecryptionKey::from_bytes(&dk_encoded);
        assert_eq!(dk_original, dk_decoded);

        let ek_encoded = ek_original.to_bytes();
        let ek_decoded = EncryptionKey::from_bytes(&ek_encoded).unwrap();
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn codec() {
        codec_test::<MlKem512>();
        codec_test::<MlKem768>();
        codec_test::<MlKem1024>();
    }

    #[test]
    fn reject_invalid_encryption_keys() {
        // Create an invalid key: all bytes set to 0xFF
        // When decoded as 12-bit coefficients, this produces values of 0xFFF = 4095 > 3329
        let invalid_key = [0xFF; 1184];
        assert!(EncryptionKey::<MlKem768>::from_bytes(&invalid_key.into()).is_err());
    }

    /// Keys generated from two independent random seeds must compare unequal.
    fn key_inequality_test<P>()
    where
        P: PkeParams,
    {
        let mut rng = UnwrapErr(SysRng);
        let d1 = B32::generate_from_rng(&mut rng);
        let d2 = B32::generate_from_rng(&mut rng);

        let (dk1, _) = DecryptionKey::<P>::generate(&d1);
        let (dk2, _) = DecryptionKey::<P>::generate(&d2);

        // Verify inequality (catches PartialEq mutation that returns true unconditionally)
        assert_ne!(dk1, dk2);
    }

    #[test]
    fn key_inequality() {
        key_inequality_test::<MlKem512>();
        key_inequality_test::<MlKem768>();
        key_inequality_test::<MlKem1024>();
    }
}