1use crate::aws_lc::{
49 EVP_PKEY_CTX_kem_set_params, EVP_PKEY_decapsulate, EVP_PKEY_encapsulate,
50 EVP_PKEY_kem_new_raw_public_key, EVP_PKEY, EVP_PKEY_KEM,
51};
52use crate::buffer::Buffer;
53use crate::encoding::generated_encodings;
54use crate::error::{KeyRejected, Unspecified};
55use crate::ptr::LcPtr;
56use alloc::borrow::Cow;
57use core::cmp::Ordering;
58use zeroize::Zeroize;
59
// Byte lengths of ML-KEM keys, ciphertexts and shared secrets (FIPS 203).
// Each parameter set is written in terms of its module rank `k` and its
// ciphertext compression parameters `d_u` / `d_v`:
//   encapsulation (public) key  = 384*k + 32
//   decapsulation (secret) key  = 768*k + 96
//   ciphertext                  = 32*(d_u*k + d_v)
//   shared secret               = 32 bytes for every parameter set

// ML-KEM-512: k = 2, d_u = 10, d_v = 4
const ML_KEM_512_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_512_PUBLIC_KEY_LENGTH: usize = 384 * 2 + 32; // 800
const ML_KEM_512_SECRET_KEY_LENGTH: usize = 768 * 2 + 96; // 1632
const ML_KEM_512_CIPHERTEXT_LENGTH: usize = 32 * (10 * 2 + 4); // 768

// ML-KEM-768: k = 3, d_u = 10, d_v = 4
const ML_KEM_768_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_768_PUBLIC_KEY_LENGTH: usize = 384 * 3 + 32; // 1184
const ML_KEM_768_SECRET_KEY_LENGTH: usize = 768 * 3 + 96; // 2400
const ML_KEM_768_CIPHERTEXT_LENGTH: usize = 32 * (10 * 3 + 4); // 1088

// ML-KEM-1024: k = 4, d_u = 11, d_v = 5
const ML_KEM_1024_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_1024_PUBLIC_KEY_LENGTH: usize = 384 * 4 + 32; // 1568
const ML_KEM_1024_SECRET_KEY_LENGTH: usize = 768 * 4 + 96; // 3168
const ML_KEM_1024_CIPHERTEXT_LENGTH: usize = 32 * (11 * 4 + 5); // 1568
74
75pub const ML_KEM_512: Algorithm<AlgorithmId> = Algorithm {
77 id: AlgorithmId::MlKem512,
78 decapsulate_key_size: ML_KEM_512_SECRET_KEY_LENGTH,
79 encapsulate_key_size: ML_KEM_512_PUBLIC_KEY_LENGTH,
80 ciphertext_size: ML_KEM_512_CIPHERTEXT_LENGTH,
81 shared_secret_size: ML_KEM_512_SHARED_SECRET_LENGTH,
82};
83
84pub const ML_KEM_768: Algorithm<AlgorithmId> = Algorithm {
86 id: AlgorithmId::MlKem768,
87 decapsulate_key_size: ML_KEM_768_SECRET_KEY_LENGTH,
88 encapsulate_key_size: ML_KEM_768_PUBLIC_KEY_LENGTH,
89 ciphertext_size: ML_KEM_768_CIPHERTEXT_LENGTH,
90 shared_secret_size: ML_KEM_768_SHARED_SECRET_LENGTH,
91};
92
93pub const ML_KEM_1024: Algorithm<AlgorithmId> = Algorithm {
95 id: AlgorithmId::MlKem1024,
96 decapsulate_key_size: ML_KEM_1024_SECRET_KEY_LENGTH,
97 encapsulate_key_size: ML_KEM_1024_PUBLIC_KEY_LENGTH,
98 ciphertext_size: ML_KEM_1024_CIPHERTEXT_LENGTH,
99 shared_secret_size: ML_KEM_1024_SHARED_SECRET_LENGTH,
100};
101
102use crate::aws_lc::{NID_MLKEM1024, NID_MLKEM512, NID_MLKEM768};
103
104pub trait AlgorithmIdentifier:
106 Copy + Clone + Debug + PartialEq + crate::sealed::Sealed + 'static
107{
108 fn nid(self) -> i32;
110}
111
112#[derive(PartialEq)]
114pub struct Algorithm<Id = AlgorithmId>
115where
116 Id: AlgorithmIdentifier,
117{
118 pub(crate) id: Id,
119 pub(crate) decapsulate_key_size: usize,
120 pub(crate) encapsulate_key_size: usize,
121 pub(crate) ciphertext_size: usize,
122 pub(crate) shared_secret_size: usize,
123}
124
125impl<Id> Algorithm<Id>
126where
127 Id: AlgorithmIdentifier,
128{
129 #[must_use]
131 pub fn id(&self) -> Id {
132 self.id
133 }
134
135 #[inline]
136 #[allow(dead_code)]
137 pub(crate) fn decapsulate_key_size(&self) -> usize {
138 self.decapsulate_key_size
139 }
140
141 #[inline]
142 pub(crate) fn encapsulate_key_size(&self) -> usize {
143 self.encapsulate_key_size
144 }
145
146 #[inline]
147 pub(crate) fn ciphertext_size(&self) -> usize {
148 self.ciphertext_size
149 }
150
151 #[inline]
152 pub(crate) fn shared_secret_size(&self) -> usize {
153 self.shared_secret_size
154 }
155}
156
157impl<Id> Debug for Algorithm<Id>
158where
159 Id: AlgorithmIdentifier,
160{
161 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
162 Debug::fmt(&self.id, f)
163 }
164}
165
166pub struct DecapsulationKey<Id = AlgorithmId>
168where
169 Id: AlgorithmIdentifier,
170{
171 algorithm: &'static Algorithm<Id>,
172 evp_pkey: LcPtr<EVP_PKEY>,
173}
174
175#[non_exhaustive]
177#[derive(Clone, Copy, Debug, PartialEq)]
178pub enum AlgorithmId {
179 MlKem512,
181
182 MlKem768,
184
185 MlKem1024,
187}
188
189impl AlgorithmIdentifier for AlgorithmId {
190 fn nid(self) -> i32 {
191 match self {
192 AlgorithmId::MlKem512 => NID_MLKEM512,
193 AlgorithmId::MlKem768 => NID_MLKEM768,
194 AlgorithmId::MlKem1024 => NID_MLKEM1024,
195 }
196 }
197}
198
199impl crate::sealed::Sealed for AlgorithmId {}
200
201impl<Id> DecapsulationKey<Id>
202where
203 Id: AlgorithmIdentifier,
204{
205 pub fn generate(alg: &'static Algorithm<Id>) -> Result<Self, Unspecified> {
210 let kyber_key = kem_key_generate(alg.id.nid())?;
211 Ok(DecapsulationKey {
212 algorithm: alg,
213 evp_pkey: kyber_key,
214 })
215 }
216
217 #[must_use]
219 pub fn algorithm(&self) -> &'static Algorithm<Id> {
220 self.algorithm
221 }
222
223 #[allow(clippy::missing_panics_doc)]
228 pub fn encapsulation_key(&self) -> Result<EncapsulationKey<Id>, Unspecified> {
229 let evp_pkey = self.evp_pkey.clone();
230
231 Ok(EncapsulationKey {
232 algorithm: self.algorithm,
233 evp_pkey,
234 })
235 }
236
237 #[allow(clippy::needless_pass_by_value)]
245 pub fn decapsulate(&self, ciphertext: Ciphertext<'_>) -> Result<SharedSecret, Unspecified> {
246 let mut shared_secret_len = self.algorithm.shared_secret_size();
247 let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
248
249 let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
250
251 let ciphertext = ciphertext.as_ref();
252
253 if 1 != unsafe {
254 EVP_PKEY_decapsulate(
255 *ctx.as_mut(),
256 shared_secret.as_mut_ptr(),
257 &mut shared_secret_len,
258 ciphertext.as_ptr() as *mut u8,
260 ciphertext.len(),
261 )
262 } {
263 return Err(Unspecified);
264 }
265
266 debug_assert_eq!(shared_secret_len, shared_secret.len());
271 shared_secret.truncate(shared_secret_len);
272
273 Ok(SharedSecret(shared_secret.into_boxed_slice()))
274 }
275}
276
277unsafe impl<Id> Send for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
278
279unsafe impl<Id> Sync for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
280
281impl<Id> Debug for DecapsulationKey<Id>
282where
283 Id: AlgorithmIdentifier,
284{
285 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
286 f.debug_struct("DecapsulationKey")
287 .field("algorithm", &self.algorithm)
288 .finish_non_exhaustive()
289 }
290}
291
292generated_encodings!((EncapsulationKeyBytes, EncapsulationKeyBytesType));
293
294pub struct EncapsulationKey<Id = AlgorithmId>
297where
298 Id: AlgorithmIdentifier,
299{
300 algorithm: &'static Algorithm<Id>,
301 evp_pkey: LcPtr<EVP_PKEY>,
302}
303
304impl<Id> EncapsulationKey<Id>
305where
306 Id: AlgorithmIdentifier,
307{
308 #[must_use]
310 pub fn algorithm(&self) -> &'static Algorithm<Id> {
311 self.algorithm
312 }
313
314 pub fn encapsulate(&self) -> Result<(Ciphertext<'static>, SharedSecret), Unspecified> {
320 let mut ciphertext_len = self.algorithm.ciphertext_size();
321 let mut shared_secret_len = self.algorithm.shared_secret_size();
322 let mut ciphertext: Vec<u8> = vec![0u8; ciphertext_len];
323 let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
324
325 let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
326
327 if 1 != unsafe {
328 EVP_PKEY_encapsulate(
329 *ctx.as_mut(),
330 ciphertext.as_mut_ptr(),
331 &mut ciphertext_len,
332 shared_secret.as_mut_ptr(),
333 &mut shared_secret_len,
334 )
335 } {
336 return Err(Unspecified);
337 }
338
339 debug_assert_eq!(ciphertext_len, ciphertext.len());
345 ciphertext.truncate(ciphertext_len);
346 debug_assert_eq!(shared_secret_len, shared_secret.len());
347 shared_secret.truncate(shared_secret_len);
348
349 Ok((
350 Ciphertext::new(ciphertext),
351 SharedSecret::new(shared_secret.into_boxed_slice()),
352 ))
353 }
354
355 pub fn key_bytes(&self) -> Result<EncapsulationKeyBytes<'static>, Unspecified> {
360 let mut encapsulate_bytes = vec![0u8; self.algorithm.encapsulate_key_size()];
361 let encapsulate_key_size = self
362 .evp_pkey
363 .as_const()
364 .marshal_raw_public_to_buffer(&mut encapsulate_bytes)?;
365
366 debug_assert_eq!(encapsulate_key_size, encapsulate_bytes.len());
367 encapsulate_bytes.truncate(encapsulate_key_size);
368
369 Ok(EncapsulationKeyBytes::new(encapsulate_bytes))
370 }
371
372 pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
383 match bytes.len().cmp(&alg.encapsulate_key_size()) {
384 Ordering::Less => Err(KeyRejected::too_small()),
385 Ordering::Greater => Err(KeyRejected::too_large()),
386 Ordering::Equal => Ok(()),
387 }?;
388 let pubkey = LcPtr::new(unsafe {
389 EVP_PKEY_kem_new_raw_public_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
390 })?;
391 Ok(EncapsulationKey {
392 algorithm: alg,
393 evp_pkey: pubkey,
394 })
395 }
396}
397
398unsafe impl<Id> Send for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
399
400unsafe impl<Id> Sync for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
401
402impl<Id> Debug for EncapsulationKey<Id>
403where
404 Id: AlgorithmIdentifier,
405{
406 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
407 f.debug_struct("EncapsulationKey")
408 .field("algorithm", &self.algorithm)
409 .finish_non_exhaustive()
410 }
411}
412
413pub struct Ciphertext<'a>(Cow<'a, [u8]>);
416
417impl<'a> Ciphertext<'a> {
418 fn new(value: Vec<u8>) -> Ciphertext<'a> {
419 Self(Cow::Owned(value))
420 }
421}
422
423impl Drop for Ciphertext<'_> {
424 fn drop(&mut self) {
425 if let Cow::Owned(ref mut v) = self.0 {
426 v.zeroize();
427 }
428 }
429}
430
431impl AsRef<[u8]> for Ciphertext<'_> {
432 fn as_ref(&self) -> &[u8] {
433 match self.0 {
434 Cow::Borrowed(v) => v,
435 Cow::Owned(ref v) => v.as_ref(),
436 }
437 }
438}
439
440impl<'a> From<&'a [u8]> for Ciphertext<'a> {
441 fn from(value: &'a [u8]) -> Self {
442 Self(Cow::Borrowed(value))
443 }
444}
445
446pub struct SharedSecret(Box<[u8]>);
448
449impl SharedSecret {
450 fn new(value: Box<[u8]>) -> Self {
451 Self(value)
452 }
453}
454
455impl Drop for SharedSecret {
456 fn drop(&mut self) {
457 self.0.zeroize();
458 }
459}
460
461impl AsRef<[u8]> for SharedSecret {
462 fn as_ref(&self) -> &[u8] {
463 self.0.as_ref()
464 }
465}
466
467#[inline]
469fn kem_key_generate(nid: i32) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
470 let params_fn = |ctx| {
471 if 1 == unsafe { EVP_PKEY_CTX_kem_set_params(ctx, nid) } {
472 Ok(())
473 } else {
474 Err(())
475 }
476 };
477
478 LcPtr::<EVP_PKEY>::generate(EVP_PKEY_KEM, Some(params_fn))
479}
480
#[cfg(test)]
mod tests {
    use super::{Algorithm, Ciphertext, DecapsulationKey, EncapsulationKey, SharedSecret};
    use crate::error::KeyRejected;

    use crate::kem::{ML_KEM_1024, ML_KEM_512, ML_KEM_768};

    /// Every ML-KEM parameter set exercised by the per-algorithm tests.
    const ALL_ALGORITHMS: [&Algorithm; 3] = [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024];

    #[test]
    fn ciphertext() {
        let backing = vec![42u8; 4];
        let borrowed = Ciphertext::from(&backing[..]);
        assert_eq!(borrowed.as_ref(), &[42, 42, 42, 42]);
        drop(borrowed);

        let owned = Ciphertext::<'static>::new(vec![42u8; 4]);
        assert_eq!(owned.as_ref(), &[42, 42, 42, 42]);
    }

    #[test]
    fn shared_secret() {
        let secret = SharedSecret::new(vec![42u8; 4].into_boxed_slice());
        assert_eq!(secret.as_ref(), &[42, 42, 42, 42]);
    }

    #[test]
    fn test_kem_serialize() {
        for algorithm in ALL_ALGORITHMS {
            let private_key = DecapsulationKey::generate(algorithm).unwrap();
            assert_eq!(private_key.algorithm(), algorithm);

            let public_key = private_key.encapsulation_key().unwrap();
            let raw_public = public_key.key_bytes().unwrap();
            let reparsed = EncapsulationKey::new(algorithm, raw_public.as_ref()).unwrap();

            assert_eq!(
                public_key.key_bytes().unwrap().as_ref(),
                reparsed.key_bytes().unwrap().as_ref()
            );
            assert_eq!(public_key.algorithm(), reparsed.algorithm());
        }
    }

    #[test]
    fn test_kem_wrong_sizes() {
        for algorithm in ALL_ALGORITHMS {
            let expected_len = algorithm.encapsulate_key_size();

            let oversized = vec![0u8; expected_len + 1];
            assert_eq!(
                EncapsulationKey::new(algorithm, &oversized).err(),
                Some(KeyRejected::too_large())
            );

            let undersized = vec![0u8; expected_len - 1];
            assert_eq!(
                EncapsulationKey::new(algorithm, &undersized).err(),
                Some(KeyRejected::too_small())
            );
        }
    }

    #[test]
    fn test_kem_e2e() {
        for algorithm in ALL_ALGORITHMS {
            let private_key = DecapsulationKey::generate(algorithm).unwrap();
            assert_eq!(private_key.algorithm(), algorithm);

            let public_key = private_key.encapsulation_key().unwrap();

            let (ciphertext, sender_secret) =
                public_key.encapsulate().expect("encapsulate successful");
            let receiver_secret = private_key
                .decapsulate(ciphertext)
                .expect("decapsulate successful");

            assert_eq!(sender_secret.as_ref(), receiver_secret.as_ref());
        }
    }

    #[test]
    fn test_serialized_kem_e2e() {
        for algorithm in ALL_ALGORITHMS {
            let private_key = DecapsulationKey::generate(algorithm).unwrap();
            assert_eq!(private_key.algorithm(), algorithm);

            let public_key = private_key.encapsulation_key().unwrap();
            let public_bytes = public_key.key_bytes().unwrap();

            // Discard the original public key so only the serialized form is used below.
            drop(public_key);

            let restored = EncapsulationKey::new(algorithm, public_bytes.as_ref()).unwrap();
            let (ciphertext, sender_secret) =
                restored.encapsulate().expect("encapsulate successful");
            let receiver_secret = private_key
                .decapsulate(ciphertext)
                .expect("decapsulate successful");

            assert_eq!(receiver_secret.as_ref(), sender_secret.as_ref());
        }
    }

    #[test]
    fn test_debug_fmt() {
        let private_key =
            DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
        assert_eq!(
            format!("{private_key:?}"),
            "DecapsulationKey { algorithm: MlKem512, .. }"
        );

        let public_key = private_key
            .encapsulation_key()
            .expect("public key retrievable");
        assert_eq!(
            format!("{public_key:?}"),
            "EncapsulationKey { algorithm: MlKem512, .. }"
        );
    }
}