1use hybrid_array::typenum::{Unsigned, U1};
2
3use crate::algebra::{NttMatrix, NttVector, Polynomial, PolynomialVector};
4use crate::compress::Compress;
5use crate::crypto::{G, PRF};
6use crate::encode::Encode;
7use crate::param::{EncodedCiphertext, EncodedDecryptionKey, EncodedEncryptionKey, PkeParams};
8use crate::util::B32;
9
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
13#[derive(Clone, Default, Debug, PartialEq)]
16pub struct DecryptionKey<P>
17where
18 P: PkeParams,
19{
20 s_hat: NttVector<P::K>,
21}
22
/// Wipes the secret vector when the `zeroize` feature is enabled.
#[cfg(feature = "zeroize")]
impl<P: PkeParams> Zeroize for DecryptionKey<P> {
    fn zeroize(&mut self) {
        // Destructure exhaustively so that adding a field to the struct forces
        // this impl to be revisited.
        let Self { s_hat } = self;
        s_hat.zeroize();
    }
}
32
33impl<P> DecryptionKey<P>
34where
35 P: PkeParams,
36{
37 pub fn generate(d: &B32) -> (Self, EncryptionKey<P>) {
40 let k = P::K::U8;
42 let (rho, sigma) = G(&[&d[..], &[k]]);
43
44 let A_hat: NttMatrix<P::K> = NttMatrix::sample_uniform(&rho, false);
46 let s: PolynomialVector<P::K> = PolynomialVector::sample_cbd::<P::Eta1>(&sigma, 0);
47 let e: PolynomialVector<P::K> = PolynomialVector::sample_cbd::<P::Eta1>(&sigma, P::K::U8);
48
49 let s_hat = s.ntt();
51 let e_hat = e.ntt();
52
53 let t_hat = &(&A_hat * &s_hat) + &e_hat;
55
56 let dk = DecryptionKey { s_hat };
58 let ek = EncryptionKey { t_hat, rho };
59 (dk, ek)
60 }
61
62 pub fn decrypt(&self, ciphertext: &EncodedCiphertext<P>) -> B32 {
65 let (c1, c2) = P::split_ct(ciphertext);
66
67 let mut u: PolynomialVector<P::K> = Encode::<P::Du>::decode(c1);
68 u.decompress::<P::Du>();
69
70 let mut v: Polynomial = Encode::<P::Dv>::decode(c2);
71 v.decompress::<P::Dv>();
72
73 let u_hat = u.ntt();
74 let sTu = (&self.s_hat * &u_hat).ntt_inverse();
75 let mut w = &v - &sTu;
76 Encode::<U1>::encode(w.compress::<U1>())
77 }
78
79 pub fn as_bytes(&self) -> EncodedDecryptionKey<P> {
81 P::encode_u12(&self.s_hat)
82 }
83
84 pub fn from_bytes(enc: &EncodedDecryptionKey<P>) -> Self {
86 let s_hat = P::decode_u12(enc);
87 Self { s_hat }
88 }
89}
90
91#[derive(Clone, Default, Debug, PartialEq)]
94pub struct EncryptionKey<P>
95where
96 P: PkeParams,
97{
98 t_hat: NttVector<P::K>,
99 rho: B32,
100}
101
102impl<P> EncryptionKey<P>
103where
104 P: PkeParams,
105{
106 pub fn encrypt(&self, message: &B32, randomness: &B32) -> EncodedCiphertext<P> {
109 let r = PolynomialVector::<P::K>::sample_cbd::<P::Eta1>(randomness, 0);
110 let e1 = PolynomialVector::<P::K>::sample_cbd::<P::Eta2>(randomness, P::K::U8);
111
112 let prf_output = PRF::<P::Eta2>(randomness, 2 * P::K::U8);
113 let e2: Polynomial = Polynomial::sample_cbd::<P::Eta2>(&prf_output);
114
115 let A_hat_t = NttMatrix::<P::K>::sample_uniform(&self.rho, true);
116 let r_hat: NttVector<P::K> = r.ntt();
117 let ATr: PolynomialVector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
118 let mut u = ATr + e1;
119
120 let mut mu: Polynomial = Encode::<U1>::decode(message);
121 mu.decompress::<U1>();
122
123 let tTr: Polynomial = (&self.t_hat * &r_hat).ntt_inverse();
124 let mut v = &(&tTr + &e2) + μ
125
126 let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());
127 let c2 = Encode::<P::Dv>::encode(v.compress::<P::Dv>());
128 P::concat_ct(c1, c2)
129 }
130
131 pub fn as_bytes(&self) -> EncodedEncryptionKey<P> {
133 let t_hat = P::encode_u12(&self.t_hat);
134 P::concat_ek(t_hat, self.rho.clone())
135 }
136
137 pub fn from_bytes(enc: &EncodedEncryptionKey<P>) -> Self {
139 let (t_hat, rho) = P::split_ek(enc);
140 let t_hat = P::decode_u12(t_hat);
141 Self {
142 t_hat,
143 rho: rho.clone(),
144 }
145 }
146}
147
#[cfg(test)]
mod test {
    use super::*;
    use crate::crypto::rand;
    use crate::{MlKem1024Params, MlKem512Params, MlKem768Params};

    /// Encrypting and then decrypting must recover the original message.
    ///
    /// The all-zero message is kept as a fixed edge case. It contributes only a
    /// zero polynomial to `v`, so on its own it cannot catch a broken
    /// message-encoding path. A random message is therefore also checked.
    fn round_trip_test<P>()
    where
        P: PkeParams,
    {
        let mut rng = rand::thread_rng();
        let d: B32 = rand(&mut rng);
        let (dk, ek) = DecryptionKey::<P>::generate(&d);

        let messages: [B32; 2] = [B32::default(), rand(&mut rng)];
        for original in &messages {
            let randomness: B32 = rand(&mut rng);
            let encrypted = ek.encrypt(original, &randomness);
            let decrypted = dk.decrypt(&encrypted);
            assert_eq!(original, &decrypted);
        }
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// Serializing and then deserializing either key must yield an equal key.
    fn codec_test<P>()
    where
        P: PkeParams,
    {
        let mut rng = rand::thread_rng();
        let d: B32 = rand(&mut rng);
        let (dk_original, ek_original) = DecryptionKey::<P>::generate(&d);

        let dk_encoded = dk_original.as_bytes();
        let dk_decoded = DecryptionKey::from_bytes(&dk_encoded);
        assert_eq!(dk_original, dk_decoded);

        let ek_encoded = ek_original.as_bytes();
        let ek_decoded = EncryptionKey::from_bytes(&ek_encoded);
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn codec() {
        codec_test::<MlKem512Params>();
        codec_test::<MlKem768Params>();
        codec_test::<MlKem1024Params>();
    }
}