1use crate::aws_lc::{HKDF_expand, HKDF};
41use crate::error::Unspecified;
42use crate::fips::indicator_check;
43use crate::{digest, hmac};
44use alloc::sync::Arc;
45use core::fmt;
46use zeroize::Zeroize;
47
48#[derive(Clone, Copy, Debug, Eq, PartialEq)]
50pub struct Algorithm(hmac::Algorithm);
51
52impl Algorithm {
53 #[inline]
55 #[must_use]
56 pub fn hmac_algorithm(&self) -> hmac::Algorithm {
57 self.0
58 }
59}
60
61pub const HKDF_SHA1_FOR_LEGACY_USE_ONLY: Algorithm = Algorithm(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY);
63
64pub const HKDF_SHA256: Algorithm = Algorithm(hmac::HMAC_SHA256);
66
67pub const HKDF_SHA384: Algorithm = Algorithm(hmac::HMAC_SHA384);
69
70pub const HKDF_SHA512: Algorithm = Algorithm(hmac::HMAC_SHA512);
72
73const HKDF_INFO_DEFAULT_CAPACITY_LEN: usize = 300;
77
78const MAX_HKDF_PRK_LEN: usize = digest::MAX_OUTPUT_LEN;
81
82impl KeyType for Algorithm {
83 fn len(&self) -> usize {
84 self.0.digest_algorithm().output_len
85 }
86}
87
88pub struct Salt {
90 algorithm: Algorithm,
91 bytes: Box<[u8]>,
92}
93
94#[allow(clippy::missing_fields_in_debug)]
95impl fmt::Debug for Salt {
96 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97 f.debug_struct("hkdf::Salt")
98 .field("algorithm", &self.algorithm.0)
99 .finish()
100 }
101}
102
103impl Drop for Salt {
104 fn drop(&mut self) {
105 self.bytes.zeroize();
107 }
108}
109
110impl Salt {
111 #[must_use]
129 pub fn new(algorithm: Algorithm, value: &[u8]) -> Self {
130 let bytes = value.to_vec().into_boxed_slice();
131 Self { algorithm, bytes }
132 }
133
134 #[inline]
141 #[must_use]
142 pub fn extract(&self, secret: &[u8]) -> Prk {
143 Prk {
144 algorithm: self.algorithm,
145 mode: PrkMode::ExtractExpand {
146 secret: Arc::from(ZeroizeBoxSlice::from(secret)),
147 salt: self.bytes.clone(),
148 },
149 }
150 }
151
152 #[inline]
154 #[must_use]
155 pub fn algorithm(&self) -> Algorithm {
156 Algorithm(self.algorithm.hmac_algorithm())
157 }
158}
159
160impl From<Okm<'_, Algorithm>> for Salt {
161 fn from(okm: Okm<'_, Algorithm>) -> Self {
162 let algorithm = okm.prk.algorithm;
163 let salt_len = okm.len().len();
164 let mut salt_bytes = vec![0u8; salt_len];
165 okm.fill(&mut salt_bytes).unwrap();
166 Self {
167 algorithm,
168 bytes: salt_bytes.into_boxed_slice(),
169 }
170 }
171}
172
/// Describes the length of HKDF output, such as the `len` argument of `Prk::expand`.
// The trait exposes only a length; no `is_empty` counterpart is provided.
#[allow(clippy::len_without_is_empty)]
pub trait KeyType {
    /// The number of output bytes.
    fn len(&self) -> usize;
}
179
180#[derive(Clone)]
181enum PrkMode {
182 Expand {
183 key_bytes: [u8; MAX_HKDF_PRK_LEN],
184 key_len: usize,
185 },
186 ExtractExpand {
187 secret: Arc<ZeroizeBoxSlice<u8>>,
188 salt: Box<[u8]>,
189 },
190}
191
impl PrkMode {
    /// Writes HKDF output keying material for `algorithm` and `info` into `out`.
    ///
    /// - `Expand`: calls AWS-LC's `HKDF_expand` over the stored PRK bytes (the first
    ///   `key_len` bytes of `key_bytes`).
    /// - `ExtractExpand`: calls AWS-LC's one-shot `HKDF` with the stored secret and salt.
    ///
    /// Exactly `out.len()` bytes are requested. The length check against the requested
    /// `KeyType` is done by the caller (`Okm::fill`).
    ///
    /// # Errors
    /// Returns `Unspecified` if the AWS-LC call, as wrapped by `indicator_check!`,
    /// does not return 1.
    fn fill(&self, algorithm: Algorithm, out: &mut [u8], info: &[u8]) -> Result<(), Unspecified> {
        // Translate the digest algorithm into the pointer form the AWS-LC functions take.
        // NOTE(review): assumed to be a valid, non-null digest handle for every
        // supported algorithm — TODO confirm in `digest::match_digest_type`.
        let digest = digest::match_digest_type(&algorithm.0.digest_algorithm().id).as_const_ptr();

        match &self {
            PrkMode::Expand { key_bytes, key_len } => unsafe {
                // SAFETY: `out` and `info` are live slices passed with their own lengths.
                // `*key_len <= key_bytes.len()` holds by construction:
                // - `Prk::try_new_less_safe` rejects values longer than `MAX_HKDF_PRK_LEN`.
                // - `From<Okm> for Prk` slices `key_bytes[0..key_len]`, which panics if
                //   `key_len` is out of range.
                // `indicator_check!` comes from `crate::fips`. Presumably it records the
                // FIPS service indicator around the call — confirm there.
                if 1 != indicator_check!(HKDF_expand(
                    out.as_mut_ptr(),
                    out.len(),
                    digest,
                    key_bytes.as_ptr(),
                    *key_len,
                    info.as_ptr(),
                    info.len(),
                )) {
                    return Err(Unspecified);
                }
            },
            PrkMode::ExtractExpand { secret, salt } => {
                // SAFETY: every pointer comes from a live slice (`out`, the secret
                // behind the `Arc`, `salt`, `info`) and is paired with that slice's length.
                if 1 != indicator_check!(unsafe {
                    HKDF(
                        out.as_mut_ptr(),
                        out.len(),
                        digest,
                        secret.as_ptr(),
                        secret.len(),
                        salt.as_ptr(),
                        salt.len(),
                        info.as_ptr(),
                        info.len(),
                    )
                }) {
                    return Err(Unspecified);
                }
            }
        }

        Ok(())
    }
}
232
233impl fmt::Debug for PrkMode {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 match self {
236 Self::Expand { .. } => f.debug_struct("Expand").finish_non_exhaustive(),
237 Self::ExtractExpand { .. } => f.debug_struct("ExtractExpand").finish_non_exhaustive(),
238 }
239 }
240}
241
242struct ZeroizeBoxSlice<T: Zeroize>(Box<[T]>);
243
244impl<T: Zeroize> core::ops::Deref for ZeroizeBoxSlice<T> {
245 type Target = [T];
246
247 fn deref(&self) -> &Self::Target {
248 &self.0
249 }
250}
251
252impl<T: Clone + Zeroize> From<&[T]> for ZeroizeBoxSlice<T> {
253 fn from(value: &[T]) -> Self {
254 Self(Vec::from(value).into_boxed_slice())
255 }
256}
257
258impl<T: Zeroize> Drop for ZeroizeBoxSlice<T> {
259 fn drop(&mut self) {
260 self.0.zeroize();
261 }
262}
263
264#[derive(Clone)]
266pub struct Prk {
267 algorithm: Algorithm,
268 mode: PrkMode,
269}
270
271#[allow(clippy::missing_fields_in_debug)]
272impl fmt::Debug for Prk {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 f.debug_struct("hkdf::Prk")
275 .field("algorithm", &self.algorithm.0)
276 .field("mode", &self.mode)
277 .finish()
278 }
279}
280
281impl Prk {
282 #[must_use]
296 pub fn new_less_safe(algorithm: Algorithm, value: &[u8]) -> Self {
297 Prk::try_new_less_safe(algorithm, value).expect("Prk length limit exceeded.")
298 }
299
300 fn try_new_less_safe(algorithm: Algorithm, value: &[u8]) -> Result<Prk, Unspecified> {
301 let key_len = value.len();
302 if key_len > MAX_HKDF_PRK_LEN {
303 return Err(Unspecified);
304 }
305 let mut key_bytes = [0u8; MAX_HKDF_PRK_LEN];
306 key_bytes[0..key_len].copy_from_slice(value);
307 Ok(Self {
308 algorithm,
309 mode: PrkMode::Expand { key_bytes, key_len },
310 })
311 }
312
313 #[inline]
326 pub fn expand<'a, L: KeyType>(
327 &'a self,
328 info: &'a [&'a [u8]],
329 len: L,
330 ) -> Result<Okm<'a, L>, Unspecified> {
331 let len_cached = len.len();
332 if len_cached > 255 * self.algorithm.0.digest_algorithm().output_len {
333 return Err(Unspecified);
334 }
335 let mut info_bytes: Vec<u8> = Vec::with_capacity(HKDF_INFO_DEFAULT_CAPACITY_LEN);
336 let mut info_len = 0;
337 for &byte_ary in info {
338 info_bytes.extend_from_slice(byte_ary);
339 info_len += byte_ary.len();
340 }
341 let info_bytes = info_bytes.into_boxed_slice();
342 Ok(Okm {
343 prk: self,
344 info_bytes,
345 info_len,
346 len,
347 })
348 }
349}
350
351impl From<Okm<'_, Algorithm>> for Prk {
352 fn from(okm: Okm<Algorithm>) -> Self {
353 let algorithm = okm.len;
354 let key_len = okm.len.len();
355 let mut key_bytes = [0u8; MAX_HKDF_PRK_LEN];
356 okm.fill(&mut key_bytes[0..key_len]).unwrap();
357
358 Self {
359 algorithm,
360 mode: PrkMode::Expand { key_bytes, key_len },
361 }
362 }
363}
364
365pub struct Okm<'a, L: KeyType> {
370 prk: &'a Prk,
371 info_bytes: Box<[u8]>,
372 info_len: usize,
373 len: L,
374}
375
376impl<L: KeyType> fmt::Debug for Okm<'_, L> {
377 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
378 f.debug_struct("hkdf::Okm").field("prk", &self.prk).finish()
379 }
380}
381
382impl<L: KeyType> Drop for Okm<'_, L> {
383 fn drop(&mut self) {
384 self.info_bytes.zeroize();
385 }
386}
387
388impl<L: KeyType> Okm<'_, L> {
389 #[inline]
391 pub fn len(&self) -> &L {
392 &self.len
393 }
394
395 #[inline]
413 pub fn fill(self, out: &mut [u8]) -> Result<(), Unspecified> {
414 if out.len() != self.len.len() {
415 return Err(Unspecified);
416 }
417
418 self.prk
419 .mode
420 .fill(self.prk.algorithm, out, &self.info_bytes[..self.info_len])?;
421
422 Ok(())
423 }
424}
425
#[cfg(test)]
mod tests {
    use crate::hkdf::{Salt, HKDF_SHA256, HKDF_SHA384};

    #[cfg(feature = "fips")]
    mod fips;

    #[test]
    fn hkdf_coverage() {
        assert_ne!(HKDF_SHA256, HKDF_SHA384);
        assert_eq!("Algorithm(Algorithm(SHA256))", format!("{HKDF_SHA256:?}"));
    }

    #[test]
    fn test_debug() {
        const SALT: &[u8; 32] = &[
            29, 113, 120, 243, 11, 202, 39, 222, 206, 81, 163, 184, 122, 153, 52, 192, 98, 195,
            240, 32, 34, 19, 160, 128, 178, 111, 97, 232, 113, 101, 221, 143,
        ];
        const SECRET1: &[u8; 32] = &[
            157, 191, 36, 107, 110, 131, 193, 6, 175, 226, 193, 3, 168, 133, 165, 181, 65, 120,
            194, 152, 31, 92, 37, 191, 73, 222, 41, 112, 207, 236, 196, 174,
        ];
        const INFO1: &[&[u8]] = &[
            &[
                2, 130, 61, 83, 192, 248, 63, 60, 211, 73, 169, 66, 101, 160, 196, 212, 250, 113,
            ],
            &[
                80, 46, 248, 123, 78, 204, 171, 178, 67, 204, 96, 27, 131, 24,
            ],
        ];

        let algorithm = HKDF_SHA256;
        let salt = Salt::new(algorithm, SALT);
        let prk = salt.extract(SECRET1);
        let okm = prk.expand(INFO1, algorithm).unwrap();

        // Debug output must show the algorithm and mode but never key material.
        assert_eq!(
            "hkdf::Salt { algorithm: Algorithm(SHA256) }",
            format!("{salt:?}")
        );
        assert_eq!(
            "hkdf::Prk { algorithm: Algorithm(SHA256), mode: ExtractExpand { .. } }",
            format!("{prk:?}")
        );
        assert_eq!(
            "hkdf::Okm { prk: hkdf::Prk { algorithm: Algorithm(SHA256), mode: ExtractExpand { .. } } }",
            format!("{okm:?}")
        );
    }

    #[test]
    fn test_long_salt() {
        // Salts of 100 and 500 bytes are both accepted, and they yield different output.
        let secret = b"test secret key material";
        let info = [b"test context info".as_slice()];

        let derive = |salt_bytes: &[u8]| {
            let salt = Salt::new(HKDF_SHA256, salt_bytes);
            let prk = salt.extract(secret);
            let okm = prk.expand(&info, HKDF_SHA256).unwrap();
            let mut output = [0u8; 32];
            okm.fill(&mut output).unwrap();
            output
        };

        let output = derive(&[0x42u8; 100]);
        let output2 = derive(&[0x55u8; 500]);
        assert_ne!(output, output2);
    }
}