1use crate::aws_lc::{
49 EVP_PKEY_CTX_kem_set_params, EVP_PKEY_decapsulate, EVP_PKEY_encapsulate,
50 EVP_PKEY_kem_new_raw_public_key, EVP_PKEY_kem_new_raw_secret_key, EVP_PKEY, EVP_PKEY_KEM,
51};
52use crate::buffer::Buffer;
53use crate::encoding::generated_encodings;
54use crate::error::{KeyRejected, Unspecified};
55use crate::ptr::LcPtr;
56use alloc::borrow::Cow;
57use core::cmp::Ordering;
58use zeroize::Zeroize;
59
// Buffer lengths, in bytes, for the ML-KEM-512 parameter set.
const ML_KEM_512_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_512_PUBLIC_KEY_LENGTH: usize = 800;
const ML_KEM_512_SECRET_KEY_LENGTH: usize = 1632;
const ML_KEM_512_CIPHERTEXT_LENGTH: usize = 768;

// Buffer lengths, in bytes, for the ML-KEM-768 parameter set.
const ML_KEM_768_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_768_PUBLIC_KEY_LENGTH: usize = 1184;
const ML_KEM_768_SECRET_KEY_LENGTH: usize = 2400;
const ML_KEM_768_CIPHERTEXT_LENGTH: usize = 1088;

// Buffer lengths, in bytes, for the ML-KEM-1024 parameter set.
const ML_KEM_1024_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_1024_PUBLIC_KEY_LENGTH: usize = 1568;
const ML_KEM_1024_SECRET_KEY_LENGTH: usize = 3168;
const ML_KEM_1024_CIPHERTEXT_LENGTH: usize = 1568;
74
/// Algorithm descriptor for ML-KEM-512.
///
/// Pairs the `AlgorithmId::MlKem512` identifier with the byte lengths of the
/// decapsulation key, encapsulation key, ciphertext and shared secret for this
/// parameter set.
pub const ML_KEM_512: Algorithm<AlgorithmId> = Algorithm {
    id: AlgorithmId::MlKem512,
    decapsulate_key_size: ML_KEM_512_SECRET_KEY_LENGTH,
    encapsulate_key_size: ML_KEM_512_PUBLIC_KEY_LENGTH,
    ciphertext_size: ML_KEM_512_CIPHERTEXT_LENGTH,
    shared_secret_size: ML_KEM_512_SHARED_SECRET_LENGTH,
};
83
/// Algorithm descriptor for ML-KEM-768.
///
/// Pairs the `AlgorithmId::MlKem768` identifier with the byte lengths of the
/// decapsulation key, encapsulation key, ciphertext and shared secret for this
/// parameter set.
pub const ML_KEM_768: Algorithm<AlgorithmId> = Algorithm {
    id: AlgorithmId::MlKem768,
    decapsulate_key_size: ML_KEM_768_SECRET_KEY_LENGTH,
    encapsulate_key_size: ML_KEM_768_PUBLIC_KEY_LENGTH,
    ciphertext_size: ML_KEM_768_CIPHERTEXT_LENGTH,
    shared_secret_size: ML_KEM_768_SHARED_SECRET_LENGTH,
};
92
/// Algorithm descriptor for ML-KEM-1024.
///
/// Pairs the `AlgorithmId::MlKem1024` identifier with the byte lengths of the
/// decapsulation key, encapsulation key, ciphertext and shared secret for this
/// parameter set.
pub const ML_KEM_1024: Algorithm<AlgorithmId> = Algorithm {
    id: AlgorithmId::MlKem1024,
    decapsulate_key_size: ML_KEM_1024_SECRET_KEY_LENGTH,
    encapsulate_key_size: ML_KEM_1024_PUBLIC_KEY_LENGTH,
    ciphertext_size: ML_KEM_1024_CIPHERTEXT_LENGTH,
    shared_secret_size: ML_KEM_1024_SHARED_SECRET_LENGTH,
};
101
102use crate::aws_lc::{NID_MLKEM1024, NID_MLKEM512, NID_MLKEM768};
103
/// Identifier type for a KEM algorithm.
///
/// Implementors must be `Copy`, comparable, debuggable and `'static`, and must
/// also implement `crate::sealed::Sealed`.
// NOTE(review): the `Sealed` supertrait presumably prevents implementations
// outside this crate — confirm against `crate::sealed`.
pub trait AlgorithmIdentifier:
    Copy + Clone + Debug + PartialEq + crate::sealed::Sealed + 'static
{
    /// Returns the numeric identifier (NID) handed to the underlying library
    /// when creating or generating keys for this algorithm.
    fn nid(self) -> i32;
}
111
/// A KEM algorithm: an identifier plus the fixed byte lengths of its keys,
/// ciphertext and shared secret.
#[derive(PartialEq)]
pub struct Algorithm<Id = AlgorithmId>
where
    Id: AlgorithmIdentifier,
{
    /// Identifier of the algorithm.
    pub(crate) id: Id,
    /// Length in bytes of a raw decapsulation (secret) key.
    pub(crate) decapsulate_key_size: usize,
    /// Length in bytes of a raw encapsulation (public) key.
    pub(crate) encapsulate_key_size: usize,
    /// Length in bytes of a ciphertext.
    pub(crate) ciphertext_size: usize,
    /// Length in bytes of a shared secret.
    pub(crate) shared_secret_size: usize,
}
124
impl<Id> Algorithm<Id>
where
    Id: AlgorithmIdentifier,
{
    /// Returns the identifier of this algorithm.
    #[must_use]
    pub fn id(&self) -> Id {
        self.id
    }

    /// Returns the length in bytes of a raw decapsulation (secret) key.
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn decapsulate_key_size(&self) -> usize {
        self.decapsulate_key_size
    }

    /// Returns the length in bytes of a raw encapsulation (public) key.
    #[inline]
    pub(crate) fn encapsulate_key_size(&self) -> usize {
        self.encapsulate_key_size
    }

    /// Returns the length in bytes of a ciphertext.
    #[inline]
    pub(crate) fn ciphertext_size(&self) -> usize {
        self.ciphertext_size
    }

    /// Returns the length in bytes of a shared secret.
    #[inline]
    pub(crate) fn shared_secret_size(&self) -> usize {
        self.shared_secret_size
    }
}
156
157impl<Id> Debug for Algorithm<Id>
158where
159 Id: AlgorithmIdentifier,
160{
161 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
162 Debug::fmt(&self.id, f)
163 }
164}
165
/// A KEM decapsulation (secret) key: the algorithm descriptor it belongs to
/// plus the handle to the underlying `EVP_PKEY`.
pub struct DecapsulationKey<Id = AlgorithmId>
where
    Id: AlgorithmIdentifier,
{
    /// Algorithm this key was generated or constructed for.
    algorithm: &'static Algorithm<Id>,
    /// Handle to the underlying key object.
    evp_pkey: LcPtr<EVP_PKEY>,
}
174
/// Identifiers for the ML-KEM algorithms provided by this module.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AlgorithmId {
    /// ML-KEM-512 (see `ML_KEM_512`).
    MlKem512,

    /// ML-KEM-768 (see `ML_KEM_768`).
    MlKem768,

    /// ML-KEM-1024 (see `ML_KEM_1024`).
    MlKem1024,
}
188
189impl AlgorithmIdentifier for AlgorithmId {
190 fn nid(self) -> i32 {
191 match self {
192 AlgorithmId::MlKem512 => NID_MLKEM512,
193 AlgorithmId::MlKem768 => NID_MLKEM768,
194 AlgorithmId::MlKem1024 => NID_MLKEM1024,
195 }
196 }
197}
198
// Required because `AlgorithmIdentifier` lists `crate::sealed::Sealed` as a supertrait.
impl crate::sealed::Sealed for AlgorithmId {}
200
201impl<Id> DecapsulationKey<Id>
202where
203 Id: AlgorithmIdentifier,
204{
205 pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
235 match bytes.len().cmp(&alg.decapsulate_key_size()) {
236 Ordering::Less => Err(KeyRejected::too_small()),
237 Ordering::Greater => Err(KeyRejected::too_large()),
238 Ordering::Equal => Ok(()),
239 }?;
240 let evp_pkey = LcPtr::new(unsafe {
241 EVP_PKEY_kem_new_raw_secret_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
242 })?;
243 Ok(DecapsulationKey {
244 algorithm: alg,
245 evp_pkey,
246 })
247 }
248
249 pub fn generate(alg: &'static Algorithm<Id>) -> Result<Self, Unspecified> {
254 let kyber_key = kem_key_generate(alg.id.nid())?;
255 Ok(DecapsulationKey {
256 algorithm: alg,
257 evp_pkey: kyber_key,
258 })
259 }
260
261 #[must_use]
263 pub fn algorithm(&self) -> &'static Algorithm<Id> {
264 self.algorithm
265 }
266
267 pub fn key_bytes(&self) -> Result<DecapsulationKeyBytes<'static>, Unspecified> {
276 let decapsulation_key_bytes = self.evp_pkey.as_const().marshal_raw_private_key()?;
277 debug_assert_eq!(
278 decapsulation_key_bytes.len(),
279 self.algorithm.decapsulate_key_size()
280 );
281 Ok(DecapsulationKeyBytes::new(decapsulation_key_bytes))
282 }
283
284 #[allow(clippy::missing_panics_doc)]
294 pub fn encapsulation_key(&self) -> Result<EncapsulationKey<Id>, Unspecified> {
295 let evp_pkey = self.evp_pkey.clone();
296
297 let encapsulation_key = EncapsulationKey {
298 algorithm: self.algorithm,
299 evp_pkey,
300 };
301
302 if encapsulation_key.key_bytes().is_err() {
305 return Err(Unspecified);
306 }
307
308 Ok(encapsulation_key)
309 }
310
311 #[allow(clippy::needless_pass_by_value)]
325 pub fn decapsulate(&self, ciphertext: Ciphertext<'_>) -> Result<SharedSecret, Unspecified> {
326 let mut shared_secret_len = self.algorithm.shared_secret_size();
327 let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
328
329 let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
330
331 let ciphertext = ciphertext.as_ref();
332
333 if 1 != unsafe {
334 EVP_PKEY_decapsulate(
335 ctx.as_mut_ptr(),
336 shared_secret.as_mut_ptr(),
337 &mut shared_secret_len,
338 ciphertext.as_ptr().cast_mut(),
340 ciphertext.len(),
341 )
342 } {
343 return Err(Unspecified);
344 }
345
346 debug_assert_eq!(shared_secret_len, shared_secret.len());
351 shared_secret.truncate(shared_secret_len);
352
353 Ok(SharedSecret(shared_secret.into_boxed_slice()))
354 }
355}
356
// NOTE(review): these impls assert that a `DecapsulationKey` (and therefore its
// `LcPtr<EVP_PKEY>` handle) may be sent to and shared between threads. Nothing
// in this file establishes that — confirm against the thread-safety guarantees
// of the underlying library.
unsafe impl<Id> Send for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}

unsafe impl<Id> Sync for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
360
361impl<Id> Debug for DecapsulationKey<Id>
362where
363 Id: AlgorithmIdentifier,
364{
365 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
366 f.debug_struct("DecapsulationKey")
367 .field("algorithm", &self.algorithm)
368 .finish_non_exhaustive()
369 }
370}
371
// Invokes the crate's `generated_encodings!` macro for the two key-bytes types
// returned by `EncapsulationKey::key_bytes` and `DecapsulationKey::key_bytes`.
// NOTE(review): presumably this is what defines `EncapsulationKeyBytes` and
// `DecapsulationKeyBytes` — confirm in `crate::encoding`.
generated_encodings!(
    (EncapsulationKeyBytes, EncapsulationKeyBytesType),
    (DecapsulationKeyBytes, DecapsulationKeyBytesType)
);
376
/// A KEM encapsulation (public) key: the algorithm descriptor it belongs to
/// plus the handle to the underlying `EVP_PKEY`.
pub struct EncapsulationKey<Id = AlgorithmId>
where
    Id: AlgorithmIdentifier,
{
    /// Algorithm this key belongs to.
    algorithm: &'static Algorithm<Id>,
    /// Handle to the underlying key object.
    evp_pkey: LcPtr<EVP_PKEY>,
}
386
387impl<Id> EncapsulationKey<Id>
388where
389 Id: AlgorithmIdentifier,
390{
391 #[must_use]
393 pub fn algorithm(&self) -> &'static Algorithm<Id> {
394 self.algorithm
395 }
396
397 pub fn encapsulate(&self) -> Result<(Ciphertext<'static>, SharedSecret), Unspecified> {
403 let mut ciphertext_len = self.algorithm.ciphertext_size();
404 let mut shared_secret_len = self.algorithm.shared_secret_size();
405 let mut ciphertext: Vec<u8> = vec![0u8; ciphertext_len];
406 let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
407
408 let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
409
410 if 1 != unsafe {
411 EVP_PKEY_encapsulate(
412 ctx.as_mut_ptr(),
413 ciphertext.as_mut_ptr(),
414 &mut ciphertext_len,
415 shared_secret.as_mut_ptr(),
416 &mut shared_secret_len,
417 )
418 } {
419 return Err(Unspecified);
420 }
421
422 debug_assert_eq!(ciphertext_len, ciphertext.len());
428 ciphertext.truncate(ciphertext_len);
429 debug_assert_eq!(shared_secret_len, shared_secret.len());
430 shared_secret.truncate(shared_secret_len);
431
432 Ok((
433 Ciphertext::new(ciphertext),
434 SharedSecret::new(shared_secret.into_boxed_slice()),
435 ))
436 }
437
438 pub fn key_bytes(&self) -> Result<EncapsulationKeyBytes<'static>, Unspecified> {
443 let mut encapsulate_bytes = vec![0u8; self.algorithm.encapsulate_key_size()];
444 let encapsulate_key_size = self
445 .evp_pkey
446 .as_const()
447 .marshal_raw_public_to_buffer(&mut encapsulate_bytes)?;
448
449 debug_assert_eq!(encapsulate_key_size, encapsulate_bytes.len());
450 encapsulate_bytes.truncate(encapsulate_key_size);
451
452 Ok(EncapsulationKeyBytes::new(encapsulate_bytes))
453 }
454
455 pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
466 match bytes.len().cmp(&alg.encapsulate_key_size()) {
467 Ordering::Less => Err(KeyRejected::too_small()),
468 Ordering::Greater => Err(KeyRejected::too_large()),
469 Ordering::Equal => Ok(()),
470 }?;
471 let pubkey = LcPtr::new(unsafe {
472 EVP_PKEY_kem_new_raw_public_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
473 })?;
474 Ok(EncapsulationKey {
475 algorithm: alg,
476 evp_pkey: pubkey,
477 })
478 }
479}
480
// NOTE(review): these impls assert that an `EncapsulationKey` (and therefore its
// `LcPtr<EVP_PKEY>` handle) may be sent to and shared between threads. Nothing
// in this file establishes that — confirm against the thread-safety guarantees
// of the underlying library.
unsafe impl<Id> Send for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}

unsafe impl<Id> Sync for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
484
485impl<Id> Debug for EncapsulationKey<Id>
486where
487 Id: AlgorithmIdentifier,
488{
489 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
490 f.debug_struct("EncapsulationKey")
491 .field("algorithm", &self.algorithm)
492 .finish_non_exhaustive()
493 }
494}
495
/// Ciphertext bytes, either borrowed from a caller-provided slice or owned
/// (as produced by `EncapsulationKey::encapsulate`).
pub struct Ciphertext<'a>(Cow<'a, [u8]>);
499
500impl<'a> Ciphertext<'a> {
501 fn new(value: Vec<u8>) -> Ciphertext<'a> {
502 Self(Cow::Owned(value))
503 }
504}
505
506impl Drop for Ciphertext<'_> {
507 fn drop(&mut self) {
508 if let Cow::Owned(ref mut v) = self.0 {
509 v.zeroize();
510 }
511 }
512}
513
514impl AsRef<[u8]> for Ciphertext<'_> {
515 fn as_ref(&self) -> &[u8] {
516 match self.0 {
517 Cow::Borrowed(v) => v,
518 Cow::Owned(ref v) => v.as_ref(),
519 }
520 }
521}
522
523impl<'a> From<&'a [u8]> for Ciphertext<'a> {
524 fn from(value: &'a [u8]) -> Self {
525 Self(Cow::Borrowed(value))
526 }
527}
528
/// Shared secret bytes produced by encapsulation or decapsulation.
///
/// The bytes are zeroized when the value is dropped (see its `Drop` impl).
pub struct SharedSecret(Box<[u8]>);
531
impl SharedSecret {
    /// Wraps already-boxed secret bytes.
    fn new(value: Box<[u8]>) -> Self {
        Self(value)
    }
}
537
impl Drop for SharedSecret {
    /// Zeroizes the secret bytes before the backing allocation is released.
    fn drop(&mut self) {
        self.0.zeroize();
    }
}
543
544impl AsRef<[u8]> for SharedSecret {
545 fn as_ref(&self) -> &[u8] {
546 self.0.as_ref()
547 }
548}
549
550#[inline]
552fn kem_key_generate(nid: i32) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
553 let params_fn = |ctx| {
554 if 1 == unsafe { EVP_PKEY_CTX_kem_set_params(ctx, nid) } {
555 Ok(())
556 } else {
557 Err(())
558 }
559 };
560
561 LcPtr::<EVP_PKEY>::generate(EVP_PKEY_KEM, Some(params_fn))
562}
563
#[cfg(test)]
mod tests {
    use super::{Ciphertext, DecapsulationKey, EncapsulationKey, SharedSecret};
    use crate::error::KeyRejected;

    use crate::kem::{ML_KEM_1024, ML_KEM_512, ML_KEM_768};

    /// `Ciphertext` exposes the same bytes whether it borrows or owns them.
    #[test]
    fn ciphertext() {
        let source = vec![42u8; 4];
        {
            let borrowed = Ciphertext::from(source.as_slice());
            assert_eq!(borrowed.as_ref(), &[42, 42, 42, 42]);
        }

        let owned = Ciphertext::<'static>::new(vec![42u8; 4]);
        assert_eq!(owned.as_ref(), &[42, 42, 42, 42]);
    }

    /// `SharedSecret` exposes the bytes it was constructed from.
    #[test]
    fn shared_secret() {
        let boxed = vec![42u8; 4].into_boxed_slice();
        let secret = SharedSecret::new(boxed);
        assert_eq!(secret.as_ref(), &[42, 42, 42, 42]);
    }

    /// Raw key bytes survive an export/import round trip for both key kinds.
    #[test]
    fn test_kem_serialize() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let decap_key = DecapsulationKey::generate(algorithm).unwrap();
            assert_eq!(decap_key.algorithm(), algorithm);

            // The exported secret key has the advertised size and can be re-imported.
            let decap_raw = decap_key.key_bytes().unwrap();
            assert_eq!(decap_raw.as_ref().len(), algorithm.decapsulate_key_size());
            let decap_key_reloaded =
                DecapsulationKey::new(algorithm, decap_raw.as_ref()).unwrap();

            assert_eq!(
                decap_key.key_bytes().unwrap().as_ref(),
                decap_key_reloaded.key_bytes().unwrap().as_ref()
            );
            assert_eq!(decap_key.algorithm(), decap_key_reloaded.algorithm());

            // Same round trip for the public half.
            let encap_key = decap_key.encapsulation_key().unwrap();
            let encap_raw = encap_key.key_bytes().unwrap();
            let encap_key_reloaded =
                EncapsulationKey::new(algorithm, encap_raw.as_ref()).unwrap();

            assert_eq!(
                encap_key.key_bytes().unwrap().as_ref(),
                encap_key_reloaded.key_bytes().unwrap().as_ref()
            );
            assert_eq!(encap_key.algorithm(), encap_key_reloaded.algorithm());
        }
    }

    /// Inputs one byte too long or too short are rejected with the matching error.
    #[test]
    fn test_kem_wrong_sizes() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let encap_size = algorithm.encapsulate_key_size();

            let oversized_public = vec![0u8; encap_size + 1];
            assert_eq!(
                EncapsulationKey::new(algorithm, &oversized_public).err(),
                Some(KeyRejected::too_large())
            );

            let undersized_public = vec![0u8; encap_size - 1];
            assert_eq!(
                EncapsulationKey::new(algorithm, &undersized_public).err(),
                Some(KeyRejected::too_small())
            );

            let decap_size = algorithm.decapsulate_key_size();

            let oversized_secret = vec![0u8; decap_size + 1];
            assert_eq!(
                DecapsulationKey::new(algorithm, &oversized_secret).err(),
                Some(KeyRejected::too_large())
            );

            let undersized_secret = vec![0u8; decap_size - 1];
            assert_eq!(
                DecapsulationKey::new(algorithm, &undersized_secret).err(),
                Some(KeyRejected::too_small())
            );
        }
    }

    /// Encapsulating with the generated key and decapsulating with a re-imported
    /// secret key yields the same shared secret.
    #[test]
    fn test_kem_e2e() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let generated = DecapsulationKey::generate(algorithm).unwrap();
            assert_eq!(generated.algorithm(), algorithm);

            let raw = generated.key_bytes().unwrap();
            let reloaded = DecapsulationKey::new(algorithm, raw.as_ref()).unwrap();

            // A key rebuilt from raw bytes cannot hand out an encapsulation key.
            assert!(reloaded.encapsulation_key().is_err());

            let encap_key = generated.encapsulation_key().unwrap();

            let (ciphertext, sender_secret) =
                encap_key.encapsulate().expect("encapsulate successful");

            let receiver_secret = reloaded
                .decapsulate(ciphertext)
                .expect("decapsulate successful");

            assert_eq!(sender_secret.as_ref(), receiver_secret.as_ref());
        }
    }

    /// Both keys are exported, dropped, and re-imported before being used.
    #[test]
    fn test_serialized_kem_e2e() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            // Only the serialized bytes leave this scope; the key objects are
            // dropped at its end (public key first, then the secret key).
            let (encap_bytes, decap_bytes) = {
                let decap_key = DecapsulationKey::generate(algorithm).unwrap();
                assert_eq!(decap_key.algorithm(), algorithm);

                let encap_key = decap_key.encapsulation_key().unwrap();

                (
                    encap_key.key_bytes().unwrap(),
                    decap_key.key_bytes().unwrap(),
                )
            };

            let restored_encap_key =
                EncapsulationKey::new(algorithm, encap_bytes.as_ref()).unwrap();
            let (ciphertext, sender_secret) = restored_encap_key
                .encapsulate()
                .expect("encapsulate successful");

            let restored_decap_key =
                DecapsulationKey::new(algorithm, decap_bytes.as_ref()).unwrap();
            let receiver_secret = restored_decap_key
                .decapsulate(ciphertext)
                .expect("decapsulate successful");

            assert_eq!(receiver_secret.as_ref(), sender_secret.as_ref());
        }
    }

    /// A re-imported decapsulation key serializes identically and decapsulates
    /// to the same secret as the original.
    #[test]
    fn test_decapsulation_key_serialization_roundtrip() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let original = DecapsulationKey::generate(algorithm).unwrap();

            let exported = original.key_bytes().unwrap();
            assert_eq!(exported.as_ref().len(), algorithm.decapsulate_key_size());

            let rebuilt = DecapsulationKey::new(algorithm, exported.as_ref()).unwrap();

            assert_eq!(original.algorithm(), rebuilt.algorithm());
            assert_eq!(original.algorithm(), algorithm);

            // Exporting the rebuilt key reproduces the very same bytes.
            let re_exported = rebuilt.key_bytes().unwrap();
            assert_eq!(exported.as_ref(), re_exported.as_ref());

            let encap_key = original.encapsulation_key().unwrap();
            let (ciphertext, expected_secret) =
                encap_key.encapsulate().expect("encapsulate successful");

            let via_original = original
                .decapsulate(Ciphertext::from(ciphertext.as_ref()))
                .expect("decapsulate with original key");
            let via_rebuilt = rebuilt
                .decapsulate(Ciphertext::from(ciphertext.as_ref()))
                .expect("decapsulate with reconstructed key");

            assert_eq!(expected_secret.as_ref(), via_original.as_ref());
            assert_eq!(expected_secret.as_ref(), via_rebuilt.as_ref());

            assert_eq!(
                expected_secret.as_ref().len(),
                algorithm.shared_secret_size()
            );
        }
    }

    /// All-zero bytes of the right length are accepted at construction time,
    /// but the resulting key is unusable for decapsulation.
    #[test]
    fn test_decapsulation_key_zeroed_bytes() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let all_zero = vec![0u8; algorithm.decapsulate_key_size()];

            let construction = DecapsulationKey::new(algorithm, &all_zero);
            assert!(
                construction.is_ok(),
                "DecapsulationKey::new should accept zeroed bytes of correct size for {:?}",
                algorithm.id()
            );

            let zero_key = construction.unwrap();

            let exported = zero_key.key_bytes();
            assert!(
                exported.is_ok(),
                "key_bytes() should succeed for key constructed from zeroed bytes"
            );
            assert_eq!(exported.unwrap().as_ref(), all_zero.as_slice());

            assert!(
                zero_key.encapsulation_key().is_err(),
                "encapsulation_key() should fail for key constructed from raw bytes"
            );

            // Produce a genuine ciphertext with an unrelated, valid key pair.
            let genuine_key = DecapsulationKey::generate(algorithm).unwrap();
            let genuine_encap_key = genuine_key.encapsulation_key().unwrap();
            let (ciphertext, _) = genuine_encap_key.encapsulate().unwrap();

            let outcome = zero_key.decapsulate(Ciphertext::from(ciphertext.as_ref()));
            assert!(
                outcome.is_err(),
                "decapsulate should fail with invalid (zeroed) key material for {:?}",
                algorithm.id()
            );
        }
    }

    /// Key bytes exported under one algorithm are rejected by every other one.
    #[test]
    fn test_cross_algorithm_key_rejection() {
        let algorithms = [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024];

        // Decapsulation keys: the rejection reason follows the size relationship.
        for source_alg in algorithms {
            let source_key = DecapsulationKey::generate(source_alg).unwrap();
            let source_bytes = source_key.key_bytes().unwrap();

            for target_alg in algorithms {
                let import = DecapsulationKey::new(target_alg, source_bytes.as_ref());

                if source_alg.id() == target_alg.id() {
                    assert!(
                        import.is_ok(),
                        "Same algorithm should accept its own key bytes"
                    );
                    continue;
                }

                assert!(
                    import.is_err(),
                    "Algorithm {:?} should reject key bytes from {:?}",
                    target_alg.id(),
                    source_alg.id()
                );

                let err = import.err().unwrap();
                let source_size = source_alg.decapsulate_key_size();
                let target_size = target_alg.decapsulate_key_size();
                if source_size < target_size {
                    assert_eq!(
                        err,
                        KeyRejected::too_small(),
                        "Smaller key should be rejected as too_small"
                    );
                } else {
                    assert_eq!(
                        err,
                        KeyRejected::too_large(),
                        "Larger key should be rejected as too_large"
                    );
                }
            }
        }

        // Encapsulation keys: only acceptance versus rejection is checked.
        for source_alg in algorithms {
            let decap_key = DecapsulationKey::generate(source_alg).unwrap();
            let encap_key = decap_key.encapsulation_key().unwrap();
            let source_bytes = encap_key.key_bytes().unwrap();

            for target_alg in algorithms {
                let import = EncapsulationKey::new(target_alg, source_bytes.as_ref());

                if source_alg.id() == target_alg.id() {
                    assert!(
                        import.is_ok(),
                        "Same algorithm should accept its own encapsulation key bytes"
                    );
                } else {
                    assert!(
                        import.is_err(),
                        "Algorithm {:?} should reject encapsulation key bytes from {:?}",
                        target_alg.id(),
                        source_alg.id()
                    );
                }
            }
        }
    }

    /// `Debug` output names the algorithm and nothing else.
    #[test]
    fn test_debug_fmt() {
        let decap_key = DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
        assert_eq!(
            format!("{decap_key:?}"),
            "DecapsulationKey { algorithm: MlKem512, .. }"
        );

        let encap_key = decap_key
            .encapsulation_key()
            .expect("public key retrievable");
        assert_eq!(
            format!("{encap_key:?}"),
            "EncapsulationKey { algorithm: MlKem512, .. }"
        );
    }
}