Skip to main content

ml_kem/
decapsulation_key.rs

1use crate::{
2    B32, EncapsulationKey, Seed, SharedKey,
3    crypto::{G, J},
4    param::{DecapsulationKeySize, ExpandedDecapsulationKey, KemParams},
5    pke::{DecryptionKey, EncryptionKey},
6};
7use array::{
8    Array, ArraySize,
9    sizes::{U32, U64},
10};
11use kem::{
12    Ciphertext, Decapsulate, Decapsulator, Generate, InvalidKey, Kem, KeyExport, KeyInit,
13    KeySizeUser,
14};
15use module_lattice::{
16    MaybeBox,
17    ctutils::{CtEq, CtSelect},
18};
19use rand_core::{TryCryptoRng, TryRng};
20
21#[cfg(feature = "zeroize")]
22use zeroize::{Zeroize, ZeroizeOnDrop};
23
24/// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an
25/// encapsulated shared key.
26#[derive(Clone, Debug)]
27pub struct DecapsulationKey<P>
28where
29    P: KemParams,
30{
31    /// Decryption key.
32    dk_pke: MaybeBox<DecryptionKey<P>>,
33
34    /// Associated encapsulation key.
35    ek: EncapsulationKey<P>,
36
37    /// Seed this key was initialized from.
38    d: Option<MaybeBox<B32>>,
39
40    /// Random string used during the implicit rejection process.
41    z: MaybeBox<B32>,
42}
43
44impl<P> DecapsulationKey<P>
45where
46    P: KemParams,
47{
48    /// Create a [`DecapsulationKey`] instance from a 64-byte random seed value.
49    #[inline]
50    #[must_use]
51    pub fn from_seed(seed: Seed) -> Self {
52        let (d, z) = seed.split();
53        Self::generate_deterministic(d, z)
54    }
55
56    /// Initialize a [`DecapsulationKey`] from the serialized expanded key form.
57    ///
58    /// Note that this form is deprecated in practice; use [`DecapsulationKey::from_seed`].
59    /// See [`ExpandedKeyEncoding`] for more information.
60    ///
61    /// # Errors
62    /// - Returns [`InvalidKey`] in the event the expanded key failed validation
63    #[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
64    pub fn from_expanded(enc: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
65        let (dk_pke, ek_pke, h, z) = P::split_dk(enc);
66        let dk_pke = MaybeBox::new(DecryptionKey::from_bytes(dk_pke));
67        let ek_pke = EncryptionKey::from_bytes(ek_pke)?;
68
69        let ek = EncapsulationKey::from_encryption_key(ek_pke);
70        if ek.h() != *h {
71            return Err(InvalidKey);
72        }
73
74        Ok(Self {
75            dk_pke,
76            ek,
77            d: None,
78            z: MaybeBox::new(z.clone()),
79        })
80    }
81
82    /// Serialize the [`Seed`] value: 64-bytes which can be used to reconstruct the
83    /// [`DecapsulationKey`].
84    ///
85    /// <div class="warning">
86    /// <b>Warning!</B>
87    ///
88    /// This value is key material. Please treat it with care.
89    /// </div>
90    ///
91    /// # Returns
92    /// - `Some` if the [`DecapsulationKey`] was initialized using `from_seed` or `generate`.
93    /// - `None` if the [`DecapsulationKey`] was initialized from the expanded form.
94    #[inline]
95    #[must_use]
96    pub fn to_seed(&self) -> Option<Seed> {
97        self.d.as_ref().map(|d| d.concat(*self.z))
98    }
99
100    /// Get the [`EncapsulationKey`] which corresponds to this [`DecapsulationKey`].
101    #[must_use]
102    pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
103        &self.ek
104    }
105
106    #[inline]
107    pub(crate) fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
108    where
109        R: TryCryptoRng + ?Sized,
110    {
111        let d = B32::try_generate_from_rng(rng)?;
112        let z = B32::try_generate_from_rng(rng)?;
113        Ok(Self::generate_deterministic(d, z))
114    }
115
116    #[inline]
117    #[must_use]
118    #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec
119    pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self {
120        let (dk_pke, ek_pke) = DecryptionKey::generate(&d);
121        let ek = EncapsulationKey::from_encryption_key(ek_pke);
122
123        let dk_pke = MaybeBox::new(dk_pke);
124        let d = Some(MaybeBox::new(d));
125        let z = MaybeBox::new(z);
126
127        Self { dk_pke, ek, d, z }
128    }
129}
130
131// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
132// keys initialized from the expanded form
133impl<P> PartialEq for DecapsulationKey<P>
134where
135    P: KemParams,
136{
137    fn eq(&self, other: &Self) -> bool {
138        self.dk_pke.ct_eq(&other.dk_pke).into() && self.ek.eq(&other.ek) && self.z.eq(&other.z)
139    }
140}
141
#[cfg(feature = "zeroize")]
impl<P> Drop for DecapsulationKey<P>
where
    P: KemParams,
{
    /// Wipe every secret field: the decryption key, the seed (when present), and `z`.
    /// The encapsulation key is public and is not wiped.
    fn drop(&mut self) {
        let Self { dk_pke, d, z, .. } = self;
        dk_pke.zeroize();
        if let Some(seed) = d {
            seed.zeroize();
        }
        z.zeroize();
    }
}
155
// Marker impl: justified by the feature-gated `Drop` impl for `DecapsulationKey`, which zeroizes
// `dk_pke`, `d`, and `z` when the key is dropped.
#[cfg(feature = "zeroize")]
impl<P> ZeroizeOnDrop for DecapsulationKey<P>
where
    P: KemParams,
{
}
158
159impl<P> From<Seed> for DecapsulationKey<P>
160where
161    P: KemParams,
162{
163    fn from(seed: Seed) -> Self {
164        Self::from_seed(seed)
165    }
166}
167
168impl<P> Decapsulate for DecapsulationKey<P>
169where
170    P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
171{
172    fn decapsulate(&self, encapsulated_key: &Ciphertext<P>) -> SharedKey {
173        let mp = self.dk_pke.decrypt(encapsulated_key);
174        let (Kp, rp) = G(&[&mp, &self.ek.h()]);
175        let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
176        let cp = self.ek.ek_pke().encrypt(&mp, &rp);
177        Kbar.ct_select(&Kp, cp.ct_eq(encapsulated_key))
178    }
179}
180
181impl<P> Decapsulator for DecapsulationKey<P>
182where
183    P: Kem<EncapsulationKey = EncapsulationKey<P>, SharedKeySize = U32> + KemParams,
184{
185    type Kem = P;
186
187    fn encapsulation_key(&self) -> &EncapsulationKey<P> {
188        &self.ek
189    }
190}
191
192impl<P> Generate for DecapsulationKey<P>
193where
194    P: KemParams,
195{
196    fn try_generate_from_rng<R>(rng: &mut R) -> Result<Self, <R as TryRng>::Error>
197    where
198        R: TryCryptoRng + ?Sized,
199    {
200        Self::try_generate_from_rng(rng)
201    }
202}
203
204impl<P> KeySizeUser for DecapsulationKey<P>
205where
206    P: KemParams,
207{
208    type KeySize = U64;
209}
210
211/// Initialize [`DecapsulationKey`] from a 64-byte uniformly random [`Seed`] value.
212impl<P> KeyInit for DecapsulationKey<P>
213where
214    P: KemParams,
215{
216    #[inline]
217    fn new(seed: &Seed) -> Self {
218        Self::from_seed(*seed)
219    }
220}
221
222/// Serialize the 64-byte [`Seed`] value used to initialize this [`DecapsulationKey`].
223///
224/// # Panics
225/// If this [`DecapsulationKey`] was initialized using legacy expanded key support
226/// (see [`ExpandedKeyEncoding`]).
227impl<P> KeyExport for DecapsulationKey<P>
228where
229    P: KemParams,
230{
231    fn to_bytes(&self) -> Seed {
232        self.to_seed().expect("should be initialized from a seed")
233    }
234}
235
236/// DEPRECATED: support for encoding and decoding [`DecapsulationKey`]s in the legacy expanded form,
237/// as opposed to the more widely adopted [`Seed`] form.
238///
239/// The expanded encoding format is problematic for several reasons, notably they need to validated
240/// whereas generation from seeds is always correct, meaning there is no performance advantage to
241/// using them, only additional complexity.
242///
243/// They are significantly larger than seeds (which are 64-bytes) and their sizes vary depending on
244/// security level whereas the size of a seed is constant:
245/// - ML-KEM-512: 1632 bytes
246/// - ML-KEM-768: 2400 bytes
247/// - ML-KEM-1024: 3168 bytes
248///
249/// Many ML-KEM libraries have dropped support for this format entirely.
250#[deprecated(since = "0.3.0", note = "use `DecapsulationKey::from_seed` instead")]
251pub trait ExpandedKeyEncoding: Sized {
252    /// The size of an expanded decapsulation key.
253    type EncodedSize: ArraySize;
254
255    /// Parse a [`DecapsulationKey`] from its legacy expanded form.
256    ///
257    /// # Errors
258    /// - If the key fails to validate successfully.
259    fn from_expanded_bytes(enc: &Array<u8, Self::EncodedSize>) -> Result<Self, InvalidKey>;
260
261    /// Serialize a [`DecapsulationKey`] to its legacy expanded form.
262    fn to_expanded_bytes(&self) -> Array<u8, Self::EncodedSize>;
263}
264
265#[allow(deprecated)]
266impl<P> ExpandedKeyEncoding for DecapsulationKey<P>
267where
268    P: KemParams,
269{
270    type EncodedSize = DecapsulationKeySize<P>;
271
272    fn from_expanded_bytes(expanded: &ExpandedDecapsulationKey<P>) -> Result<Self, InvalidKey> {
273        Self::from_expanded(expanded)
274    }
275
276    fn to_expanded_bytes(&self) -> ExpandedDecapsulationKey<P> {
277        let dk_pke = self.dk_pke.to_bytes();
278        let ek = self.ek.to_bytes();
279        P::concat_dk(dk_pke, ek, self.ek.h(), *self.z)
280    }
281}