aws_lc_rs/
kem.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0 OR ISC
3
//! Key-Encapsulation Mechanisms (KEMs), including support for the NIST FIPS 203 ML-KEM algorithms (ML-KEM-512, ML-KEM-768, and ML-KEM-1024).
5//!
6//! # Example
7//!
//! Note that this example uses the ML-KEM-512 algorithm, but the other algorithms can be used
//! in the exact same way by substituting
//! `kem::<desired_algorithm_here>` for `kem::ML_KEM_512`.
11//!
12//! ```rust
13//! use aws_lc_rs::{
14//!     kem::{Ciphertext, DecapsulationKey, EncapsulationKey},
15//!     kem::{ML_KEM_512}
16//! };
17//!
18//! // Alice generates their (private) decapsulation key.
19//! let decapsulation_key = DecapsulationKey::generate(&ML_KEM_512)?;
20//!
//! // Alice computes the (public) encapsulation key.
22//! let encapsulation_key = decapsulation_key.encapsulation_key()?;
23//!
24//! let encapsulation_key_bytes = encapsulation_key.key_bytes()?;
25//!
//! // Alice sends the encapsulation key bytes to Bob through some
//! // protocol message.
28//! let encapsulation_key_bytes = encapsulation_key_bytes.as_ref();
29//!
30//! // Bob constructs the (public) encapsulation key from the key bytes provided by Alice.
31//! let retrieved_encapsulation_key = EncapsulationKey::new(&ML_KEM_512, encapsulation_key_bytes)?;
32//!
//! // Bob executes the encapsulation algorithm to produce their copy of the secret and the associated ciphertext.
34//! let (ciphertext, bob_secret) = retrieved_encapsulation_key.encapsulate()?;
35//!
//! // Bob sends the ciphertext bytes to Alice through some protocol message.
37//! let ciphertext_bytes = ciphertext.as_ref();
38//!
//! // Alice reconstructs the ciphertext from the received bytes and runs decapsulation to derive
//! // their copy of the secret.
41//! let alice_secret = decapsulation_key.decapsulate(Ciphertext::from(ciphertext_bytes))?;
42//!
//! // Alice and Bob have now arrived at the same secret.
44//! assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
45//!
46//! # Ok::<(), aws_lc_rs::error::Unspecified>(())
47//! ```
48use crate::aws_lc::{
49    EVP_PKEY_CTX_kem_set_params, EVP_PKEY_decapsulate, EVP_PKEY_encapsulate,
50    EVP_PKEY_kem_new_raw_public_key, EVP_PKEY_kem_new_raw_secret_key, EVP_PKEY, EVP_PKEY_KEM,
51};
52use crate::buffer::Buffer;
53use crate::encoding::generated_encodings;
54use crate::error::{KeyRejected, Unspecified};
55use crate::ptr::LcPtr;
56use alloc::borrow::Cow;
57use core::cmp::Ordering;
58use zeroize::Zeroize;
59
// Byte lengths for each NIST FIPS 203 ML-KEM parameter set. "Public key" is the
// encapsulation key and "secret key" is the decapsulation key.

// ML-KEM-512
const ML_KEM_512_PUBLIC_KEY_LENGTH: usize = 800;
const ML_KEM_512_SECRET_KEY_LENGTH: usize = 1632;
const ML_KEM_512_CIPHERTEXT_LENGTH: usize = 768;
const ML_KEM_512_SHARED_SECRET_LENGTH: usize = 32;

// ML-KEM-768
const ML_KEM_768_PUBLIC_KEY_LENGTH: usize = 1184;
const ML_KEM_768_SECRET_KEY_LENGTH: usize = 2400;
const ML_KEM_768_CIPHERTEXT_LENGTH: usize = 1088;
const ML_KEM_768_SHARED_SECRET_LENGTH: usize = 32;

// ML-KEM-1024
const ML_KEM_1024_PUBLIC_KEY_LENGTH: usize = 1568;
const ML_KEM_1024_SECRET_KEY_LENGTH: usize = 3168;
const ML_KEM_1024_CIPHERTEXT_LENGTH: usize = 1568;
const ML_KEM_1024_SHARED_SECRET_LENGTH: usize = 32;
74
75/// NIST FIPS 203 ML-KEM-512 algorithm.
76pub const ML_KEM_512: Algorithm<AlgorithmId> = Algorithm {
77    id: AlgorithmId::MlKem512,
78    decapsulate_key_size: ML_KEM_512_SECRET_KEY_LENGTH,
79    encapsulate_key_size: ML_KEM_512_PUBLIC_KEY_LENGTH,
80    ciphertext_size: ML_KEM_512_CIPHERTEXT_LENGTH,
81    shared_secret_size: ML_KEM_512_SHARED_SECRET_LENGTH,
82};
83
84/// NIST FIPS 203 ML-KEM-768 algorithm.
85pub const ML_KEM_768: Algorithm<AlgorithmId> = Algorithm {
86    id: AlgorithmId::MlKem768,
87    decapsulate_key_size: ML_KEM_768_SECRET_KEY_LENGTH,
88    encapsulate_key_size: ML_KEM_768_PUBLIC_KEY_LENGTH,
89    ciphertext_size: ML_KEM_768_CIPHERTEXT_LENGTH,
90    shared_secret_size: ML_KEM_768_SHARED_SECRET_LENGTH,
91};
92
93/// NIST FIPS 203 ML-KEM-1024 algorithm.
94pub const ML_KEM_1024: Algorithm<AlgorithmId> = Algorithm {
95    id: AlgorithmId::MlKem1024,
96    decapsulate_key_size: ML_KEM_1024_SECRET_KEY_LENGTH,
97    encapsulate_key_size: ML_KEM_1024_PUBLIC_KEY_LENGTH,
98    ciphertext_size: ML_KEM_1024_CIPHERTEXT_LENGTH,
99    shared_secret_size: ML_KEM_1024_SHARED_SECRET_LENGTH,
100};
101
102use crate::aws_lc::{NID_MLKEM1024, NID_MLKEM512, NID_MLKEM768};
103
104/// An identifier for a KEM algorithm.
105pub trait AlgorithmIdentifier:
106    Copy + Clone + Debug + PartialEq + crate::sealed::Sealed + 'static
107{
108    /// Returns the algorithm's associated AWS-LC nid.
109    fn nid(self) -> i32;
110}
111
112/// A KEM algorithm
113#[derive(PartialEq)]
114pub struct Algorithm<Id = AlgorithmId>
115where
116    Id: AlgorithmIdentifier,
117{
118    pub(crate) id: Id,
119    pub(crate) decapsulate_key_size: usize,
120    pub(crate) encapsulate_key_size: usize,
121    pub(crate) ciphertext_size: usize,
122    pub(crate) shared_secret_size: usize,
123}
124
125impl<Id> Algorithm<Id>
126where
127    Id: AlgorithmIdentifier,
128{
129    /// Returns the identifier for this algorithm.
130    #[must_use]
131    pub fn id(&self) -> Id {
132        self.id
133    }
134
135    #[inline]
136    #[allow(dead_code)]
137    pub(crate) fn decapsulate_key_size(&self) -> usize {
138        self.decapsulate_key_size
139    }
140
141    #[inline]
142    pub(crate) fn encapsulate_key_size(&self) -> usize {
143        self.encapsulate_key_size
144    }
145
146    #[inline]
147    pub(crate) fn ciphertext_size(&self) -> usize {
148        self.ciphertext_size
149    }
150
151    #[inline]
152    pub(crate) fn shared_secret_size(&self) -> usize {
153        self.shared_secret_size
154    }
155}
156
157impl<Id> Debug for Algorithm<Id>
158where
159    Id: AlgorithmIdentifier,
160{
161    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
162        Debug::fmt(&self.id, f)
163    }
164}
165
166/// A serializable decapulsation key usable with KEMs. This can be randomly generated with `DecapsulationKey::generate`.
167pub struct DecapsulationKey<Id = AlgorithmId>
168where
169    Id: AlgorithmIdentifier,
170{
171    algorithm: &'static Algorithm<Id>,
172    evp_pkey: LcPtr<EVP_PKEY>,
173}
174
/// Identifier for a KEM algorithm.
///
/// Implements `Eq` and `Hash`, so identifiers can be used as keys in maps and sets.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AlgorithmId {
    /// NIST FIPS 203 ML-KEM-512 algorithm.
    MlKem512,

    /// NIST FIPS 203 ML-KEM-768 algorithm.
    MlKem768,

    /// NIST FIPS 203 ML-KEM-1024 algorithm.
    MlKem1024,
}
188
189impl AlgorithmIdentifier for AlgorithmId {
190    fn nid(self) -> i32 {
191        match self {
192            AlgorithmId::MlKem512 => NID_MLKEM512,
193            AlgorithmId::MlKem768 => NID_MLKEM768,
194            AlgorithmId::MlKem1024 => NID_MLKEM1024,
195        }
196    }
197}
198
199impl crate::sealed::Sealed for AlgorithmId {}
200
201impl<Id> DecapsulationKey<Id>
202where
203    Id: AlgorithmIdentifier,
204{
205    /// Creates a new KEM decapsulation key from raw bytes. This method MUST NOT be used to generate
206    /// a new decapsulation key, rather it MUST be used to construct `DecapsulationKey` previously serialized
207    /// to raw bytes.
208    ///
209    /// `alg` is the [`Algorithm`] to be associated with the generated `DecapsulationKey`.
210    ///
211    /// `bytes` is a slice of raw bytes representing a `DecapsulationKey`.
212    ///
213    /// # Security Considerations
214    ///
215    /// This function performs size validation but does not fully validate key material integrity.
216    /// Invalid key bytes (e.g., corrupted or tampered data) may be accepted by this function but
217    /// will cause [`Self::decapsulate`] to fail. Only use bytes that were previously obtained from
218    /// [`Self::key_bytes`] on a validly generated key.
219    ///
220    /// # Limitations
221    ///
222    /// The `DecapsulationKey` returned by this function will NOT provide the associated
223    /// `EncapsulationKey` via [`Self::encapsulation_key`]. The `EncapsulationKey` must be
224    /// serialized and restored separately using [`EncapsulationKey::key_bytes`] and
225    /// [`EncapsulationKey::new`].
226    ///
227    /// # Errors
228    ///
229    /// Returns `KeyRejected::too_small()` if `bytes.len() < alg.decapsulate_key_size()`.
230    ///
231    /// Returns `KeyRejected::too_large()` if `bytes.len() > alg.decapsulate_key_size()`.
232    ///
233    /// Returns `KeyRejected::unexpected_error()` if the underlying cryptographic operation fails.
234    pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
235        match bytes.len().cmp(&alg.decapsulate_key_size()) {
236            Ordering::Less => Err(KeyRejected::too_small()),
237            Ordering::Greater => Err(KeyRejected::too_large()),
238            Ordering::Equal => Ok(()),
239        }?;
240        let evp_pkey = LcPtr::new(unsafe {
241            EVP_PKEY_kem_new_raw_secret_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
242        })?;
243        Ok(DecapsulationKey {
244            algorithm: alg,
245            evp_pkey,
246        })
247    }
248
249    /// Generate a new KEM decapsulation key for the given algorithm.
250    ///
251    /// # Errors
252    /// `error::Unspecified` when operation fails due to internal error.
253    pub fn generate(alg: &'static Algorithm<Id>) -> Result<Self, Unspecified> {
254        let kyber_key = kem_key_generate(alg.id.nid())?;
255        Ok(DecapsulationKey {
256            algorithm: alg,
257            evp_pkey: kyber_key,
258        })
259    }
260
261    /// Return the algorithm associated with the given KEM decapsulation key.
262    #[must_use]
263    pub fn algorithm(&self) -> &'static Algorithm<Id> {
264        self.algorithm
265    }
266
267    /// Returns the raw bytes of the `DecapsulationKey`.
268    ///
269    /// The returned bytes can be used with [`Self::new`] to reconstruct the `DecapsulationKey`.
270    ///
271    /// # Errors
272    ///
273    /// Returns [`Unspecified`] if the key bytes cannot be retrieved from the underlying
274    /// cryptographic implementation.
275    pub fn key_bytes(&self) -> Result<DecapsulationKeyBytes<'static>, Unspecified> {
276        let decapsulation_key_bytes = self.evp_pkey.as_const().marshal_raw_private_key()?;
277        debug_assert_eq!(
278            decapsulation_key_bytes.len(),
279            self.algorithm.decapsulate_key_size()
280        );
281        Ok(DecapsulationKeyBytes::new(decapsulation_key_bytes))
282    }
283
284    /// Returns the `EncapsulationKey` associated with this `DecapsulationKey`.
285    ///
286    /// # Errors
287    ///
288    /// Returns [`Unspecified`] in the following cases:
289    /// * The `DecapsulationKey` was constructed from raw bytes using [`Self::new`],
290    ///   as the underlying key representation does not include the public key component.
291    ///   In this case, the `EncapsulationKey` must be serialized and restored separately.
292    /// * An internal error occurs while extracting the public key.
293    #[allow(clippy::missing_panics_doc)]
294    pub fn encapsulation_key(&self) -> Result<EncapsulationKey<Id>, Unspecified> {
295        let evp_pkey = self.evp_pkey.clone();
296
297        let encapsulation_key = EncapsulationKey {
298            algorithm: self.algorithm,
299            evp_pkey,
300        };
301
302        // Verify the encapsulation key is valid by attempting to get its bytes.
303        // Keys constructed from raw secret bytes may not have a valid public key.
304        if encapsulation_key.key_bytes().is_err() {
305            return Err(Unspecified);
306        }
307
308        Ok(encapsulation_key)
309    }
310
311    /// Performs the decapsulate operation using this `DecapsulationKey` on the given ciphertext.
312    ///
313    /// `ciphertext` is the ciphertext generated by the encapsulate operation using the `EncapsulationKey`
314    /// associated with this `DecapsulationKey`.
315    ///
316    /// # Errors
317    ///
318    /// Returns [`Unspecified`] in the following cases:
319    /// * The `ciphertext` is malformed or was not generated for this key's algorithm.
320    /// * The `DecapsulationKey` was constructed from invalid bytes (e.g., corrupted or tampered
321    ///   key material passed to [`Self::new`]). Note that [`Self::new`] only validates the size
322    ///   of the key bytes, not their cryptographic validity.
323    /// * An internal cryptographic error occurs.
324    #[allow(clippy::needless_pass_by_value)]
325    pub fn decapsulate(&self, ciphertext: Ciphertext<'_>) -> Result<SharedSecret, Unspecified> {
326        let mut shared_secret_len = self.algorithm.shared_secret_size();
327        let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
328
329        let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
330
331        let ciphertext = ciphertext.as_ref();
332
333        if 1 != unsafe {
334            EVP_PKEY_decapsulate(
335                ctx.as_mut_ptr(),
336                shared_secret.as_mut_ptr(),
337                &mut shared_secret_len,
338                // AWS-LC incorrectly has this as an unqualified `uint8_t *`, it should be qualified with const
339                ciphertext.as_ptr().cast_mut(),
340                ciphertext.len(),
341            )
342        } {
343            return Err(Unspecified);
344        }
345
346        // This is currently pedantic but done for safety in-case the shared_secret buffer
347        // size changes in the future. `EVP_PKEY_decapsulate` updates `shared_secret_len` with
348        // the length of the shared secret in the event the buffer provided was larger then the secret.
349        // This truncates the buffer to the proper length to match the shared secret written.
350        debug_assert_eq!(shared_secret_len, shared_secret.len());
351        shared_secret.truncate(shared_secret_len);
352
353        Ok(SharedSecret(shared_secret.into_boxed_slice()))
354    }
355}
356
357unsafe impl<Id> Send for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
358
359unsafe impl<Id> Sync for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
360
361impl<Id> Debug for DecapsulationKey<Id>
362where
363    Id: AlgorithmIdentifier,
364{
365    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
366        f.debug_struct("DecapsulationKey")
367            .field("algorithm", &self.algorithm)
368            .finish_non_exhaustive()
369    }
370}
371
372generated_encodings!(
373    (EncapsulationKeyBytes, EncapsulationKeyBytesType),
374    (DecapsulationKeyBytes, DecapsulationKeyBytesType)
375);
376
377/// A serializable encapsulation key usable with KEM algorithms. Constructed
378/// from either a `DecapsulationKey` or raw bytes.
379pub struct EncapsulationKey<Id = AlgorithmId>
380where
381    Id: AlgorithmIdentifier,
382{
383    algorithm: &'static Algorithm<Id>,
384    evp_pkey: LcPtr<EVP_PKEY>,
385}
386
387impl<Id> EncapsulationKey<Id>
388where
389    Id: AlgorithmIdentifier,
390{
391    /// Return the algorithm associated with the given KEM encapsulation key.
392    #[must_use]
393    pub fn algorithm(&self) -> &'static Algorithm<Id> {
394        self.algorithm
395    }
396
397    /// Performs the encapsulate operation using this KEM encapsulation key, generating a ciphertext
398    /// and associated shared secret.
399    ///
400    /// # Errors
401    /// `error::Unspecified` when operation fails due to internal error.
402    pub fn encapsulate(&self) -> Result<(Ciphertext<'static>, SharedSecret), Unspecified> {
403        let mut ciphertext_len = self.algorithm.ciphertext_size();
404        let mut shared_secret_len = self.algorithm.shared_secret_size();
405        let mut ciphertext: Vec<u8> = vec![0u8; ciphertext_len];
406        let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
407
408        let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
409
410        if 1 != unsafe {
411            EVP_PKEY_encapsulate(
412                ctx.as_mut_ptr(),
413                ciphertext.as_mut_ptr(),
414                &mut ciphertext_len,
415                shared_secret.as_mut_ptr(),
416                &mut shared_secret_len,
417            )
418        } {
419            return Err(Unspecified);
420        }
421
422        // The following two steps are currently pedantic but done for safety in-case the buffer allocation
423        // sizes change in the future. `EVP_PKEY_encapsulate` updates `ciphertext_len` and `shared_secret_len` with
424        // the length of the ciphertext and shared secret respectivly in the event the buffer provided for each was
425        // larger then the actual values. Thus these two steps truncate the buffers to the proper length to match the
426        // value lengths written.
427        debug_assert_eq!(ciphertext_len, ciphertext.len());
428        ciphertext.truncate(ciphertext_len);
429        debug_assert_eq!(shared_secret_len, shared_secret.len());
430        shared_secret.truncate(shared_secret_len);
431
432        Ok((
433            Ciphertext::new(ciphertext),
434            SharedSecret::new(shared_secret.into_boxed_slice()),
435        ))
436    }
437
438    /// Returns the `EnscapsulationKey` bytes.
439    ///
440    /// # Errors
441    /// * `Unspecified`: Any failure to retrieve the `EnscapsulationKey` bytes.
442    pub fn key_bytes(&self) -> Result<EncapsulationKeyBytes<'static>, Unspecified> {
443        let mut encapsulate_bytes = vec![0u8; self.algorithm.encapsulate_key_size()];
444        let encapsulate_key_size = self
445            .evp_pkey
446            .as_const()
447            .marshal_raw_public_to_buffer(&mut encapsulate_bytes)?;
448
449        debug_assert_eq!(encapsulate_key_size, encapsulate_bytes.len());
450        encapsulate_bytes.truncate(encapsulate_key_size);
451
452        Ok(EncapsulationKeyBytes::new(encapsulate_bytes))
453    }
454
455    /// Creates a new KEM encapsulation key from raw bytes. This method MUST NOT be used to generate
456    /// a new encapsulation key, rather it MUST be used to construct `EncapsulationKey` previously serialized
457    /// to raw bytes.
458    ///
459    /// `alg` is the [`Algorithm`] to be associated with the generated `EncapsulationKey`.
460    ///
461    /// `bytes` is a slice of raw bytes representing a `EncapsulationKey`.
462    ///
463    /// # Errors
464    /// `error::KeyRejected` when operation fails during key creation.
465    pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
466        match bytes.len().cmp(&alg.encapsulate_key_size()) {
467            Ordering::Less => Err(KeyRejected::too_small()),
468            Ordering::Greater => Err(KeyRejected::too_large()),
469            Ordering::Equal => Ok(()),
470        }?;
471        let pubkey = LcPtr::new(unsafe {
472            EVP_PKEY_kem_new_raw_public_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
473        })?;
474        Ok(EncapsulationKey {
475            algorithm: alg,
476            evp_pkey: pubkey,
477        })
478    }
479}
480
481unsafe impl<Id> Send for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
482
483unsafe impl<Id> Sync for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
484
485impl<Id> Debug for EncapsulationKey<Id>
486where
487    Id: AlgorithmIdentifier,
488{
489    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
490        f.debug_struct("EncapsulationKey")
491            .field("algorithm", &self.algorithm)
492            .finish_non_exhaustive()
493    }
494}
495
496/// A set of encrypted bytes produced by [`EncapsulationKey::encapsulate`],
497/// and used as an input to [`DecapsulationKey::decapsulate`].
498pub struct Ciphertext<'a>(Cow<'a, [u8]>);
499
500impl<'a> Ciphertext<'a> {
501    fn new(value: Vec<u8>) -> Ciphertext<'a> {
502        Self(Cow::Owned(value))
503    }
504}
505
506impl Drop for Ciphertext<'_> {
507    fn drop(&mut self) {
508        if let Cow::Owned(ref mut v) = self.0 {
509            v.zeroize();
510        }
511    }
512}
513
514impl AsRef<[u8]> for Ciphertext<'_> {
515    fn as_ref(&self) -> &[u8] {
516        match self.0 {
517            Cow::Borrowed(v) => v,
518            Cow::Owned(ref v) => v.as_ref(),
519        }
520    }
521}
522
523impl<'a> From<&'a [u8]> for Ciphertext<'a> {
524    fn from(value: &'a [u8]) -> Self {
525        Self(Cow::Borrowed(value))
526    }
527}
528
529/// The cryptographic shared secret output from the KEM encapsulate / decapsulate process.
530pub struct SharedSecret(Box<[u8]>);
531
532impl SharedSecret {
533    fn new(value: Box<[u8]>) -> Self {
534        Self(value)
535    }
536}
537
538impl Drop for SharedSecret {
539    fn drop(&mut self) {
540        self.0.zeroize();
541    }
542}
543
544impl AsRef<[u8]> for SharedSecret {
545    fn as_ref(&self) -> &[u8] {
546        self.0.as_ref()
547    }
548}
549
550// Returns an LcPtr to an EVP_PKEY
551#[inline]
552fn kem_key_generate(nid: i32) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
553    let params_fn = |ctx| {
554        if 1 == unsafe { EVP_PKEY_CTX_kem_set_params(ctx, nid) } {
555            Ok(())
556        } else {
557            Err(())
558        }
559    };
560
561    LcPtr::<EVP_PKEY>::generate(EVP_PKEY_KEM, Some(params_fn))
562}
563
564#[cfg(test)]
565mod tests {
566    use super::{Ciphertext, DecapsulationKey, EncapsulationKey, SharedSecret};
567    use crate::error::KeyRejected;
568
569    use crate::kem::{ML_KEM_1024, ML_KEM_512, ML_KEM_768};
570
571    #[test]
572    fn ciphertext() {
573        let ciphertext_bytes = vec![42u8; 4];
574        let ciphertext = Ciphertext::from(ciphertext_bytes.as_ref());
575        assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
576        drop(ciphertext);
577
578        let ciphertext_bytes = vec![42u8; 4];
579        let ciphertext = Ciphertext::<'static>::new(ciphertext_bytes);
580        assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
581    }
582
583    #[test]
584    fn shared_secret() {
585        let secret_bytes = vec![42u8; 4];
586        let shared_secret = SharedSecret::new(secret_bytes.into_boxed_slice());
587        assert_eq!(shared_secret.as_ref(), &[42, 42, 42, 42]);
588    }
589
590    #[test]
591    fn test_kem_serialize() {
592        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
593            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
594            assert_eq!(priv_key.algorithm(), algorithm);
595
596            // Test DecapsulationKey serialization
597            let priv_key_raw_bytes = priv_key.key_bytes().unwrap();
598            assert_eq!(
599                priv_key_raw_bytes.as_ref().len(),
600                algorithm.decapsulate_key_size()
601            );
602            let priv_key_from_bytes =
603                DecapsulationKey::new(algorithm, priv_key_raw_bytes.as_ref()).unwrap();
604
605            assert_eq!(
606                priv_key.key_bytes().unwrap().as_ref(),
607                priv_key_from_bytes.key_bytes().unwrap().as_ref()
608            );
609            assert_eq!(priv_key.algorithm(), priv_key_from_bytes.algorithm());
610
611            // Test EncapsulationKey serialization
612            let pub_key = priv_key.encapsulation_key().unwrap();
613            let pubkey_raw_bytes = pub_key.key_bytes().unwrap();
614            let pub_key_from_bytes =
615                EncapsulationKey::new(algorithm, pubkey_raw_bytes.as_ref()).unwrap();
616
617            assert_eq!(
618                pub_key.key_bytes().unwrap().as_ref(),
619                pub_key_from_bytes.key_bytes().unwrap().as_ref()
620            );
621            assert_eq!(pub_key.algorithm(), pub_key_from_bytes.algorithm());
622        }
623    }
624
625    #[test]
626    fn test_kem_wrong_sizes() {
627        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
628            // Test EncapsulationKey size validation
629            let too_long_bytes = vec![0u8; algorithm.encapsulate_key_size() + 1];
630            let long_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_long_bytes);
631            assert_eq!(
632                long_pub_key_from_bytes.err(),
633                Some(KeyRejected::too_large())
634            );
635
636            let too_short_bytes = vec![0u8; algorithm.encapsulate_key_size() - 1];
637            let short_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_short_bytes);
638            assert_eq!(
639                short_pub_key_from_bytes.err(),
640                Some(KeyRejected::too_small())
641            );
642
643            // Test DecapsulationKey size validation
644            let too_long_bytes = vec![0u8; algorithm.decapsulate_key_size() + 1];
645            let long_priv_key_from_bytes = DecapsulationKey::new(algorithm, &too_long_bytes);
646            assert_eq!(
647                long_priv_key_from_bytes.err(),
648                Some(KeyRejected::too_large())
649            );
650
651            let too_short_bytes = vec![0u8; algorithm.decapsulate_key_size() - 1];
652            let short_priv_key_from_bytes = DecapsulationKey::new(algorithm, &too_short_bytes);
653            assert_eq!(
654                short_priv_key_from_bytes.err(),
655                Some(KeyRejected::too_small())
656            );
657        }
658    }
659
660    #[test]
661    fn test_kem_e2e() {
662        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
663            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
664            assert_eq!(priv_key.algorithm(), algorithm);
665
666            // Serialize and reconstruct the decapsulation key
667            let priv_key_bytes = priv_key.key_bytes().unwrap();
668            let priv_key_from_bytes =
669                DecapsulationKey::new(algorithm, priv_key_bytes.as_ref()).unwrap();
670
671            // Keys reconstructed from bytes cannot provide encapsulation_key()
672            assert!(priv_key_from_bytes.encapsulation_key().is_err());
673
674            let pub_key = priv_key.encapsulation_key().unwrap();
675
676            let (alice_ciphertext, alice_secret) =
677                pub_key.encapsulate().expect("encapsulate successful");
678
679            // Decapsulate using the reconstructed key
680            let bob_secret = priv_key_from_bytes
681                .decapsulate(alice_ciphertext)
682                .expect("decapsulate successful");
683
684            assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
685        }
686    }
687
688    #[test]
689    fn test_serialized_kem_e2e() {
690        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
691            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
692            assert_eq!(priv_key.algorithm(), algorithm);
693
694            let pub_key = priv_key.encapsulation_key().unwrap();
695
696            // Generate public key bytes to send to bob
697            let pub_key_bytes = pub_key.key_bytes().unwrap();
698
699            // Generate private key bytes for alice to store securely
700            let priv_key_bytes = priv_key.key_bytes().unwrap();
701
702            // Test that priv_key's EVP_PKEY isn't entirely freed since we remove this pub_key's reference.
703            drop(pub_key);
704            drop(priv_key);
705
706            let retrieved_pub_key =
707                EncapsulationKey::new(algorithm, pub_key_bytes.as_ref()).unwrap();
708            let (ciphertext, bob_secret) = retrieved_pub_key
709                .encapsulate()
710                .expect("encapsulate successful");
711
712            // Alice reconstructs her private key from stored bytes
713            let retrieved_priv_key =
714                DecapsulationKey::new(algorithm, priv_key_bytes.as_ref()).unwrap();
715            let alice_secret = retrieved_priv_key
716                .decapsulate(ciphertext)
717                .expect("decapsulate successful");
718
719            assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
720        }
721    }
722
723    #[test]
724    fn test_decapsulation_key_serialization_roundtrip() {
725        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
726            // Generate original key
727            let original_key = DecapsulationKey::generate(algorithm).unwrap();
728
729            // Test key_bytes() returns correct size
730            let key_bytes = original_key.key_bytes().unwrap();
731            assert_eq!(key_bytes.as_ref().len(), algorithm.decapsulate_key_size());
732
733            // Test round-trip serialization/deserialization
734            let reconstructed_key = DecapsulationKey::new(algorithm, key_bytes.as_ref()).unwrap();
735
736            // Verify algorithm consistency
737            assert_eq!(original_key.algorithm(), reconstructed_key.algorithm());
738            assert_eq!(original_key.algorithm(), algorithm);
739
740            // Test serialization produces identical bytes (stability check)
741            let key_bytes_2 = reconstructed_key.key_bytes().unwrap();
742            assert_eq!(key_bytes.as_ref(), key_bytes_2.as_ref());
743
744            // Test functional equivalence: both keys decrypt the same ciphertext identically
745            let pub_key = original_key.encapsulation_key().unwrap();
746            let (ciphertext, expected_secret) =
747                pub_key.encapsulate().expect("encapsulate successful");
748
749            let secret_from_original = original_key
750                .decapsulate(Ciphertext::from(ciphertext.as_ref()))
751                .expect("decapsulate with original key");
752            let secret_from_reconstructed = reconstructed_key
753                .decapsulate(Ciphertext::from(ciphertext.as_ref()))
754                .expect("decapsulate with reconstructed key");
755
756            // Verify both keys produce identical secrets
757            assert_eq!(expected_secret.as_ref(), secret_from_original.as_ref());
758            assert_eq!(expected_secret.as_ref(), secret_from_reconstructed.as_ref());
759
760            // Verify secret length matches algorithm specification
761            assert_eq!(expected_secret.as_ref().len(), algorithm.shared_secret_size);
762        }
763    }
764
765    #[test]
766    fn test_decapsulation_key_zeroed_bytes() {
767        // Test behavior when constructing DecapsulationKey from zeroed bytes of correct size.
768        // ML-KEM accepts any bytes of the correct size as a valid secret key (seed-based).
769        // This test documents the expected behavior.
770        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
771            let zeroed_bytes = vec![0u8; algorithm.decapsulate_key_size()];
772
773            // Constructing a key from zeroed bytes should succeed (ML-KEM treats any
774            // correctly-sized byte sequence as a valid seed)
775            let key_from_zeroed = DecapsulationKey::new(algorithm, &zeroed_bytes);
776            assert!(
777                key_from_zeroed.is_ok(),
778                "DecapsulationKey::new should accept zeroed bytes of correct size for {:?}",
779                algorithm.id()
780            );
781
782            let key = key_from_zeroed.unwrap();
783
784            // The key should be able to serialize back to bytes
785            let key_bytes = key.key_bytes();
786            assert!(
787                key_bytes.is_ok(),
788                "key_bytes() should succeed for key constructed from zeroed bytes"
789            );
790            assert_eq!(key_bytes.unwrap().as_ref(), zeroed_bytes.as_slice());
791
792            // encapsulation_key() should fail since key was constructed from raw bytes
793            assert!(
794                key.encapsulation_key().is_err(),
795                "encapsulation_key() should fail for key constructed from raw bytes"
796            );
797
798            // Test decapsulation behavior with zeroed-seed key.
799            // Generate a valid ciphertext from a properly generated key pair
800            let valid_key = DecapsulationKey::generate(algorithm).unwrap();
801            let valid_pub_key = valid_key.encapsulation_key().unwrap();
802            let (ciphertext, _) = valid_pub_key.encapsulate().unwrap();
803
804            // Decapsulating with a zeroed-seed key fails because the key material
805            // doesn't represent a valid ML-KEM private key structure.
806            // This documents that ML-KEM validates key integrity during decapsulation.
807            let decapsulate_result = key.decapsulate(Ciphertext::from(ciphertext.as_ref()));
808            assert!(
809                decapsulate_result.is_err(),
810                "decapsulate should fail with invalid (zeroed) key material for {:?}",
811                algorithm.id()
812            );
813        }
814    }
815
816    #[test]
817    fn test_cross_algorithm_key_rejection() {
818        // Test that keys from one algorithm are rejected when used with a different algorithm
819        // due to size mismatches.
820        let algorithms = [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024];
821
822        for source_alg in &algorithms {
823            let key = DecapsulationKey::generate(source_alg).unwrap();
824            let key_bytes = key.key_bytes().unwrap();
825
826            for target_alg in &algorithms {
827                if source_alg.id() == target_alg.id() {
828                    // Same algorithm should succeed
829                    let result = DecapsulationKey::new(target_alg, key_bytes.as_ref());
830                    assert!(
831                        result.is_ok(),
832                        "Same algorithm should accept its own key bytes"
833                    );
834                } else {
835                    // Different algorithm should fail due to size mismatch
836                    let result = DecapsulationKey::new(target_alg, key_bytes.as_ref());
837                    assert!(
838                        result.is_err(),
839                        "Algorithm {:?} should reject key bytes from {:?}",
840                        target_alg.id(),
841                        source_alg.id()
842                    );
843
844                    // Verify the error is size-related
845                    let err = result.err().unwrap();
846                    let source_size = source_alg.decapsulate_key_size();
847                    let target_size = target_alg.decapsulate_key_size();
848                    if source_size < target_size {
849                        assert_eq!(
850                            err,
851                            KeyRejected::too_small(),
852                            "Smaller key should be rejected as too_small"
853                        );
854                    } else {
855                        assert_eq!(
856                            err,
857                            KeyRejected::too_large(),
858                            "Larger key should be rejected as too_large"
859                        );
860                    }
861                }
862            }
863        }
864
865        // Also test EncapsulationKey cross-algorithm rejection for completeness
866        for source_alg in &algorithms {
867            let decap_key = DecapsulationKey::generate(source_alg).unwrap();
868            let encap_key = decap_key.encapsulation_key().unwrap();
869            let key_bytes = encap_key.key_bytes().unwrap();
870
871            for target_alg in &algorithms {
872                if source_alg.id() == target_alg.id() {
873                    let result = EncapsulationKey::new(target_alg, key_bytes.as_ref());
874                    assert!(
875                        result.is_ok(),
876                        "Same algorithm should accept its own encapsulation key bytes"
877                    );
878                } else {
879                    let result = EncapsulationKey::new(target_alg, key_bytes.as_ref());
880                    assert!(
881                        result.is_err(),
882                        "Algorithm {:?} should reject encapsulation key bytes from {:?}",
883                        target_alg.id(),
884                        source_alg.id()
885                    );
886                }
887            }
888        }
889    }
890
891    #[test]
892    fn test_debug_fmt() {
893        let private = DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
894        assert_eq!(
895            format!("{private:?}"),
896            "DecapsulationKey { algorithm: MlKem512, .. }"
897        );
898        assert_eq!(
899            format!(
900                "{:?}",
901                private.encapsulation_key().expect("public key retrievable")
902            ),
903            "EncapsulationKey { algorithm: MlKem512, .. }"
904        );
905    }
906}