1use crate::B32;
2use crate::algebra::{
3 Ntt, NttInverse, NttMatrix, NttVector, Polynomial, Vector, matrix_sample_ntt, sample_poly_cbd,
4 sample_poly_vec_cbd,
5};
6use crate::compress::Compress;
7use crate::crypto::{G, PRF};
8use crate::param::{EncodedDecryptionKey, EncodedEncryptionKey, PkeParams};
9use array::typenum::{U1, Unsigned};
10use kem::{Ciphertext, InvalidKey};
11use module_lattice::{
12 Encode,
13 ctutils::{Choice, CtEq},
14};
15
16#[cfg(feature = "zeroize")]
17use zeroize::Zeroize;
18
19#[derive(Clone, Default, Debug)]
22pub(crate) struct DecryptionKey<P>
23where
24 P: PkeParams,
25{
26 s_hat: NttVector<P::K>,
27}
28
29impl<P> CtEq for DecryptionKey<P>
30where
31 P: PkeParams,
32{
33 fn ct_eq(&self, other: &Self) -> Choice {
34 self.s_hat.ct_eq(&other.s_hat)
35 }
36}
37
38impl<P> Eq for DecryptionKey<P> where P: PkeParams {}
39impl<P> PartialEq for DecryptionKey<P>
40where
41 P: PkeParams,
42{
43 fn eq(&self, other: &Self) -> bool {
44 self.ct_eq(other).into()
46 }
47}
48
#[cfg(feature = "zeroize")]
impl<P: PkeParams> Zeroize for DecryptionKey<P> {
    /// Zeroizes the secret vector, which is the only field of the key.
    fn zeroize(&mut self) {
        self.s_hat.zeroize();
    }
}
58
59impl<P> DecryptionKey<P>
60where
61 P: PkeParams,
62{
63 pub(crate) fn generate(d: &B32) -> (Self, EncryptionKey<P>) {
66 let k = P::K::U8;
68 let (rho, sigma) = G(&[&d[..], &[k]]);
69
70 let A_hat: NttMatrix<P::K> = matrix_sample_ntt(&rho, false);
72 let s: Vector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, 0);
73 let e: Vector<P::K> = sample_poly_vec_cbd::<P::Eta1, P::K>(&sigma, P::K::U8);
74
75 let s_hat = s.ntt();
77 let e_hat = e.ntt();
78
79 let t_hat = &(&A_hat * &s_hat) + &e_hat;
81
82 let dk = DecryptionKey { s_hat };
84 let ek = EncryptionKey { t_hat, rho };
85 (dk, ek)
86 }
87
88 pub(crate) fn decrypt(&self, ciphertext: &Ciphertext<P>) -> B32 {
91 let (c1, c2) = P::split_ct(ciphertext);
92
93 let mut u: Vector<P::K> = Encode::<P::Du>::decode(c1);
94 u.decompress::<P::Du>();
95
96 let mut v: Polynomial = Encode::<P::Dv>::decode(c2);
97 v.decompress::<P::Dv>();
98
99 let u_hat = u.ntt();
100 let sTu = (&self.s_hat * &u_hat).ntt_inverse();
101 let mut w = &v - &sTu;
102 Encode::<U1>::encode(w.compress::<U1>())
103 }
104
105 pub(crate) fn to_bytes(&self) -> EncodedDecryptionKey<P> {
107 P::encode_u12(&self.s_hat)
108 }
109
110 pub(crate) fn from_bytes(enc: &EncodedDecryptionKey<P>) -> Self {
112 let s_hat = P::decode_u12(enc);
113 Self { s_hat }
114 }
115}
116
117#[derive(Clone, Default, Debug, Eq, PartialEq)]
120pub(crate) struct EncryptionKey<P>
121where
122 P: PkeParams,
123{
124 t_hat: NttVector<P::K>,
125 rho: B32,
126}
127
128impl<P> EncryptionKey<P>
129where
130 P: PkeParams,
131{
132 pub(crate) fn encrypt(&self, message: &B32, randomness: &B32) -> Ciphertext<P> {
135 let r = sample_poly_vec_cbd::<P::Eta1, P::K>(randomness, 0);
136 let e1 = sample_poly_vec_cbd::<P::Eta2, P::K>(randomness, P::K::U8);
137
138 let prf_output = PRF::<P::Eta2>(randomness, 2 * P::K::U8);
139 let e2: Polynomial = sample_poly_cbd::<P::Eta2>(&prf_output);
140
141 let A_hat_t: NttMatrix<P::K> = matrix_sample_ntt(&self.rho, true);
142 let r_hat: NttVector<P::K> = r.ntt();
143 let ATr: Vector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
144 let mut u = ATr + e1;
145
146 let mut mu: Polynomial = Encode::<U1>::decode(message);
147 mu.decompress::<U1>();
148
149 let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse();
150 let mut v = &(&tTr + &e2) + μ
151
152 let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());
153 let c2 = Encode::<P::Dv>::encode(v.compress::<P::Dv>());
154 P::concat_ct(c1, c2)
155 }
156
157 pub(crate) fn to_bytes(&self) -> EncodedEncryptionKey<P> {
159 let t_hat = P::encode_u12(&self.t_hat);
160 P::concat_ek(t_hat, self.rho.clone())
161 }
162
163 pub(crate) fn from_bytes(enc: &EncodedEncryptionKey<P>) -> Result<Self, InvalidKey> {
169 let (t_hat, rho) = P::split_ek(enc);
170 let t_hat = P::decode_u12(t_hat);
171 let ret = Self {
172 t_hat,
173 rho: rho.clone(),
174 };
175
176 if &ret.to_bytes() == enc {
203 Ok(ret)
204 } else {
205 Err(InvalidKey)
206 }
207 }
208}
209
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem512, MlKem768, MlKem1024};
    use ::kem::Generate;
    use getrandom::{SysRng, rand_core::UnwrapErr};

    /// Generates a fresh key pair, encrypts a default-valued message with
    /// default-valued randomness, and checks that decryption recovers it.
    fn round_trip_test<P: PkeParams>() {
        let mut rng = UnwrapErr(SysRng);
        let seed = B32::generate_from_rng(&mut rng);
        let (dk, ek) = DecryptionKey::<P>::generate(&seed);

        let original = B32::default();
        let coins = B32::default();

        let ciphertext = ek.encrypt(&original, &coins);
        let decrypted = dk.decrypt(&ciphertext);
        assert_eq!(original, decrypted);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512>();
        round_trip_test::<MlKem768>();
        round_trip_test::<MlKem1024>();
    }

    /// Checks that both key types survive a `to_bytes` / `from_bytes` cycle.
    fn codec_test<P: PkeParams>() {
        let mut rng = UnwrapErr(SysRng);
        let seed = B32::generate_from_rng(&mut rng);
        let (dk_original, ek_original) = DecryptionKey::<P>::generate(&seed);

        // Decryption key: decoding is infallible.
        let dk_bytes = dk_original.to_bytes();
        let dk_decoded = DecryptionKey::from_bytes(&dk_bytes);
        assert_eq!(dk_original, dk_decoded);

        // Encryption key: decoding validates, so a generated key must pass.
        let ek_bytes = ek_original.to_bytes();
        let ek_decoded = EncryptionKey::from_bytes(&ek_bytes).unwrap();
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn codec() {
        codec_test::<MlKem512>();
        codec_test::<MlKem768>();
        codec_test::<MlKem1024>();
    }

    /// An all-`0xFF` buffer of the ML-KEM-768 encoded key length must be
    /// rejected by `EncryptionKey::from_bytes`.
    #[test]
    fn reject_invalid_encryption_keys() {
        let all_ones = [0xFF; 1184];
        let parsed = EncryptionKey::<MlKem768>::from_bytes(&all_ones.into());
        assert!(parsed.is_err());
    }

    /// Two independently seeded decryption keys must compare unequal.
    fn key_inequality_test<P: PkeParams>() {
        let mut rng = UnwrapErr(SysRng);
        let seed_a = B32::generate_from_rng(&mut rng);
        let seed_b = B32::generate_from_rng(&mut rng);

        let (dk1, _) = DecryptionKey::<P>::generate(&seed_a);
        let (dk2, _) = DecryptionKey::<P>::generate(&seed_b);

        assert_ne!(dk1, dk2);
    }

    #[test]
    fn key_inequality() {
        key_inequality_test::<MlKem512>();
        key_inequality_test::<MlKem768>();
        key_inequality_test::<MlKem1024>();
    }
}