ml_kem/
encapsulation_key.rs1use crate::{
2 B32, SharedKey,
3 crypto::{G, H},
4 kem::{InvalidKey, Kem, Key, KeyExport, KeySizeUser, TryKeyInit},
5 param::{EncapsulationKeySize, KemParams},
6 pke::EncryptionKey,
7};
8use array::sizes::U32;
9use kem::{Ciphertext, Encapsulate, Generate};
10use module_lattice::MaybeBox;
11use rand_core::CryptoRng;
12
13#[derive(Clone, Debug)]
16pub struct EncapsulationKey<P>
17where
18 P: KemParams,
19{
20 ek_pke: MaybeBox<EncryptionKey<P>>,
21 h: B32,
22}
23
24impl<P> EncapsulationKey<P>
25where
26 P: Kem<SharedKeySize = U32> + KemParams,
27{
28 pub fn new(encapsulation_key: &Key<Self>) -> Result<Self, InvalidKey> {
33 EncryptionKey::from_bytes(encapsulation_key)
34 .map(Self::from_encryption_key)
35 .map_err(|_| InvalidKey)
36 }
37
38 #[cfg_attr(not(feature = "hazmat"), doc(hidden))]
44 #[must_use]
45 pub fn encapsulate_deterministic(&self, m: &B32) -> (Ciphertext<P>, SharedKey) {
46 let (K, r) = G(&[m, &self.h]);
47 let c = self.ek_pke.encrypt(m, &r);
48 (c, K)
49 }
50
51 #[inline]
53 pub(crate) fn from_encryption_key(ek_pke: EncryptionKey<P>) -> Self {
54 let h = H(ek_pke.to_bytes());
55 Self {
56 ek_pke: MaybeBox::new(ek_pke),
57 h,
58 }
59 }
60
61 pub(crate) fn ek_pke(&self) -> &EncryptionKey<P> {
63 &self.ek_pke
64 }
65
66 pub(crate) fn h(&self) -> B32 {
68 self.h
69 }
70}
71
72impl<P> Encapsulate for EncapsulationKey<P>
73where
74 P: Kem + KemParams,
75{
76 type Kem = P;
77
78 fn encapsulate_with_rng<R>(&self, rng: &mut R) -> (Ciphertext<P>, SharedKey)
79 where
80 R: CryptoRng + ?Sized,
81 {
82 let m = B32::generate_from_rng(rng);
83 self.encapsulate_deterministic(&m)
84 }
85}
86
87impl<P> KeyExport for EncapsulationKey<P>
88where
89 P: KemParams,
90{
91 fn to_bytes(&self) -> Key<Self> {
92 self.ek_pke.to_bytes()
93 }
94}
95
96impl<P> KeySizeUser for EncapsulationKey<P>
97where
98 P: KemParams,
99{
100 type KeySize = EncapsulationKeySize<P>;
101}
102
103impl<P> TryKeyInit for EncapsulationKey<P>
104where
105 P: KemParams,
106{
107 fn new(encapsulation_key: &Key<Self>) -> Result<Self, InvalidKey> {
108 Self::new(encapsulation_key)
109 }
110}
111
112impl<P> Eq for EncapsulationKey<P> where P: KemParams {}
113impl<P> PartialEq for EncapsulationKey<P>
114where
115 P: KemParams,
116{
117 fn eq(&self, other: &Self) -> bool {
118 self.ek_pke == other.ek_pke && self.h == other.h
120 }
121}