1use super::aead_ctx::{self, AeadCtx};
5use super::{Aad, Algorithm, AlgorithmID, Nonce, Tag, UnboundKey};
6use crate::error::Unspecified;
7use core::fmt::Debug;
8use core::ops::RangeFrom;
9
/// The TLS protocol version that a [`TlsRecordSealingKey`] or [`TlsRecordOpeningKey`] is
/// constructed for.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a wildcard arm
/// when matching on it.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[non_exhaustive]
pub enum TlsProtocolId {
    /// TLS 1.2.
    TLS12,

    /// TLS 1.3.
    TLS13,
}
21
/// AEAD key used to seal (encrypt and tag) TLS records.
///
/// Built with `TlsRecordSealingKey::new`, which only accepts AES-128-GCM and AES-256-GCM.
/// Its sealing methods take `&mut self`.
#[allow(clippy::module_name_repetitions)]
pub struct TlsRecordSealingKey {
    /// Underlying AEAD key, created in `new` from a TLS-specific `AeadCtx` in the seal
    /// direction.
    key: UnboundKey,
    /// TLS protocol version the key was constructed for; reported by `tls_protocol_id`.
    protocol: TlsProtocolId,
}
39
40impl TlsRecordSealingKey {
41 pub fn new(
47 algorithm: &'static Algorithm,
48 protocol: TlsProtocolId,
49 key_bytes: &[u8],
50 ) -> Result<Self, Unspecified> {
51 let ctx = match (algorithm.id, protocol) {
52 (AlgorithmID::AES_128_GCM, TlsProtocolId::TLS12) => AeadCtx::aes_128_gcm_tls12(
53 key_bytes,
54 algorithm.tag_len(),
55 aead_ctx::AeadDirection::Seal,
56 ),
57 (AlgorithmID::AES_128_GCM, TlsProtocolId::TLS13) => AeadCtx::aes_128_gcm_tls13(
58 key_bytes,
59 algorithm.tag_len(),
60 aead_ctx::AeadDirection::Seal,
61 ),
62 (AlgorithmID::AES_256_GCM, TlsProtocolId::TLS12) => AeadCtx::aes_256_gcm_tls12(
63 key_bytes,
64 algorithm.tag_len(),
65 aead_ctx::AeadDirection::Seal,
66 ),
67 (AlgorithmID::AES_256_GCM, TlsProtocolId::TLS13) => AeadCtx::aes_256_gcm_tls13(
68 key_bytes,
69 algorithm.tag_len(),
70 aead_ctx::AeadDirection::Seal,
71 ),
72 (
73 AlgorithmID::AES_128_GCM_SIV
74 | AlgorithmID::AES_192_GCM
75 | AlgorithmID::AES_256_GCM_SIV
76 | AlgorithmID::CHACHA20_POLY1305,
77 _,
78 ) => Err(Unspecified),
79 }?;
80 Ok(Self {
81 key: UnboundKey::from(ctx),
82 protocol,
83 })
84 }
85
86 #[inline]
95 #[allow(clippy::needless_pass_by_value)]
96 pub fn seal_in_place_append_tag<A, InOut>(
97 &mut self,
98 nonce: Nonce,
99 aad: Aad<A>,
100 in_out: &mut InOut,
101 ) -> Result<(), Unspecified>
102 where
103 A: AsRef<[u8]>,
104 InOut: AsMut<[u8]> + for<'in_out> Extend<&'in_out u8>,
105 {
106 self.key
107 .seal_in_place_append_tag(Some(nonce), aad.as_ref(), in_out)
108 .map(|_| ())
109 }
110
111 #[inline]
128 #[allow(clippy::needless_pass_by_value)]
129 pub fn seal_in_place_separate_tag<A>(
130 &mut self,
131 nonce: Nonce,
132 aad: Aad<A>,
133 in_out: &mut [u8],
134 ) -> Result<Tag, Unspecified>
135 where
136 A: AsRef<[u8]>,
137 {
138 self.key
139 .seal_in_place_separate_tag(Some(nonce), aad.as_ref(), in_out)
140 .map(|(_, tag)| tag)
141 }
142
143 #[inline]
145 #[must_use]
146 pub fn algorithm(&self) -> &'static Algorithm {
147 self.key.algorithm()
148 }
149
150 #[must_use]
152 pub fn tls_protocol_id(&self) -> TlsProtocolId {
153 self.protocol
154 }
155}
156
157#[allow(clippy::missing_fields_in_debug)]
158impl Debug for TlsRecordSealingKey {
159 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
160 f.debug_struct("TlsRecordSealingKey")
161 .field("key", &self.key)
162 .field("protocol", &self.protocol)
163 .finish()
164 }
165}
166
/// AEAD key used to open (authenticate and decrypt) TLS records.
///
/// Built with `TlsRecordOpeningKey::new`, which only accepts AES-128-GCM and AES-256-GCM.
/// Its opening methods take `&self`.
#[allow(clippy::module_name_repetitions)]
pub struct TlsRecordOpeningKey {
    /// Underlying AEAD key, created in `new` from a TLS-specific `AeadCtx` in the open
    /// direction.
    key: UnboundKey,
    /// TLS protocol version the key was constructed for; reported by `tls_protocol_id`.
    protocol: TlsProtocolId,
}
183
184impl TlsRecordOpeningKey {
185 pub fn new(
191 algorithm: &'static Algorithm,
192 protocol: TlsProtocolId,
193 key_bytes: &[u8],
194 ) -> Result<Self, Unspecified> {
195 let ctx = match (algorithm.id, protocol) {
196 (AlgorithmID::AES_128_GCM, TlsProtocolId::TLS12) => AeadCtx::aes_128_gcm_tls12(
197 key_bytes,
198 algorithm.tag_len(),
199 aead_ctx::AeadDirection::Open,
200 ),
201 (AlgorithmID::AES_128_GCM, TlsProtocolId::TLS13) => AeadCtx::aes_128_gcm_tls13(
202 key_bytes,
203 algorithm.tag_len(),
204 aead_ctx::AeadDirection::Open,
205 ),
206 (AlgorithmID::AES_256_GCM, TlsProtocolId::TLS12) => AeadCtx::aes_256_gcm_tls12(
207 key_bytes,
208 algorithm.tag_len(),
209 aead_ctx::AeadDirection::Open,
210 ),
211 (AlgorithmID::AES_256_GCM, TlsProtocolId::TLS13) => AeadCtx::aes_256_gcm_tls13(
212 key_bytes,
213 algorithm.tag_len(),
214 aead_ctx::AeadDirection::Open,
215 ),
216 (
217 AlgorithmID::AES_128_GCM_SIV
218 | AlgorithmID::AES_192_GCM
219 | AlgorithmID::AES_256_GCM_SIV
220 | AlgorithmID::CHACHA20_POLY1305,
221 _,
222 ) => Err(Unspecified),
223 }?;
224 Ok(Self {
225 key: UnboundKey::from(ctx),
226 protocol,
227 })
228 }
229
230 #[inline]
235 #[allow(clippy::needless_pass_by_value)]
236 pub fn open_in_place<'in_out, A>(
237 &self,
238 nonce: Nonce,
239 aad: Aad<A>,
240 in_out: &'in_out mut [u8],
241 ) -> Result<&'in_out mut [u8], Unspecified>
242 where
243 A: AsRef<[u8]>,
244 {
245 self.key.open_within(nonce, aad.as_ref(), in_out, 0..)
246 }
247
248 #[inline]
253 #[allow(clippy::needless_pass_by_value)]
254 pub fn open_within<'in_out, A>(
255 &self,
256 nonce: Nonce,
257 aad: Aad<A>,
258 in_out: &'in_out mut [u8],
259 ciphertext_and_tag: RangeFrom<usize>,
260 ) -> Result<&'in_out mut [u8], Unspecified>
261 where
262 A: AsRef<[u8]>,
263 {
264 self.key
265 .open_within(nonce, aad.as_ref(), in_out, ciphertext_and_tag)
266 }
267
268 #[inline]
270 #[must_use]
271 pub fn algorithm(&self) -> &'static Algorithm {
272 self.key.algorithm()
273 }
274
275 #[must_use]
277 pub fn tls_protocol_id(&self) -> TlsProtocolId {
278 self.protocol
279 }
280}
281
282#[allow(clippy::missing_fields_in_debug)]
283impl Debug for TlsRecordOpeningKey {
284 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
285 f.debug_struct("TlsRecordOpeningKey")
286 .field("key", &self.key)
287 .field("protocol", &self.protocol)
288 .finish()
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::{TlsProtocolId, TlsRecordOpeningKey, TlsRecordSealingKey};
    use crate::aead::{Aad, Nonce, AES_128_GCM, AES_256_GCM, CHACHA20_POLY1305};
    use crate::test::from_hex;
    use paste::paste;

    /// 16-byte key used for the AES-128-GCM cases.
    const TEST_128_BIT_KEY: &[u8] = &[
        0xb0, 0x37, 0x9f, 0xf8, 0xfb, 0x8e, 0xa6, 0x31, 0xf4, 0x1c, 0xe6, 0x3e, 0xb5, 0xc5, 0x20,
        0x7c,
    ];

    /// 32-byte key used for the AES-256-GCM and ChaCha20-Poly1305 cases.
    const TEST_256_BIT_KEY: &[u8] = &[
        0x56, 0xd8, 0x96, 0x68, 0xbd, 0x96, 0xeb, 0xff, 0x5e, 0xa2, 0x0b, 0x34, 0xf2, 0x79, 0x84,
        0x6e, 0x2b, 0x13, 0x01, 0x3d, 0xab, 0x1d, 0xa4, 0x07, 0x5a, 0x16, 0xd5, 0x0b, 0x53, 0xb0,
        0xcc, 0x88,
    ];

    /// A hex-encoded nonce and whether sealing with it is expected to fail.
    struct TlsNonceTestCase {
        nonce: &'static str,
        expect_err: bool,
    }

    /// Nonces applied in order to the same sealing key. The last one is lower than a nonce
    /// already used with that key, and sealing with it is expected to fail.
    const TLS_NONCE_TEST_CASES: &[TlsNonceTestCase] = &[
        TlsNonceTestCase {
            nonce: "9fab40177c900aad9fc28cc3",
            expect_err: false,
        },
        TlsNonceTestCase {
            nonce: "9fab40177c900aad9fc28cc4",
            expect_err: false,
        },
        TlsNonceTestCase {
            nonce: "9fab40177c900aad9fc28cc2",
            expect_err: true,
        },
    ];

    macro_rules! test_tls_aead {
        // Unsupported (algorithm, protocol) pair: both key constructors must fail.
        ($name:ident, $alg:expr, $proto:expr, $key:expr) => {
            paste! {
                #[test]
                fn [<test_ $name _tls_aead_unsupported>]() {
                    assert!(TlsRecordSealingKey::new($alg, $proto, $key).is_err());
                    assert!(TlsRecordOpeningKey::new($alg, $proto, $key).is_err());
                }
            }
        };
        // Supported pair: seal and then open a fixed plaintext under each nonce in
        // `TLS_NONCE_TEST_CASES`, in order, reusing one sealing key and one opening key.
        ($name:ident, $alg:expr, $proto:expr, $key:expr, $expect_tag_len:expr, $expect_nonce_len:expr) => {
            paste! {
                #[test]
                fn [<test_ $name>]() {
                    let mut sealing_key =
                        TlsRecordSealingKey::new($alg, $proto, $key).unwrap();

                    let opening_key =
                        TlsRecordOpeningKey::new($alg, $proto, $key).unwrap();

                    // Both keys report the algorithm and protocol they were constructed with.
                    assert_eq!($alg, sealing_key.algorithm());
                    assert_eq!($alg, opening_key.algorithm());
                    assert_eq!($proto, sealing_key.tls_protocol_id());
                    assert_eq!($proto, opening_key.tls_protocol_id());
                    assert_eq!(*$expect_tag_len, $alg.tag_len());
                    assert_eq!(*$expect_nonce_len, $alg.nonce_len());

                    let plaintext = from_hex("00112233445566778899aabbccddeeff").unwrap();

                    for case in TLS_NONCE_TEST_CASES {
                        let nonce = from_hex(case.nonce).unwrap();
                        let nonce_bytes = nonce.as_slice();

                        let mut in_out = Vec::from(plaintext.as_slice());

                        let result = sealing_key.seal_in_place_append_tag(
                            Nonce::try_assume_unique_for_key(nonce_bytes).unwrap(),
                            Aad::empty(),
                            &mut in_out,
                        );

                        match (result, case.expect_err) {
                            (Ok(()), true) => panic!("expected error for seal_in_place_append_tag"),
                            (Ok(()), false) => {}
                            // Rejected as expected, so there is nothing to open. Use `continue`
                            // rather than `return` so any cases added later still run.
                            (Err(_), true) => continue,
                            (Err(e), false) => panic!("{e}"),
                        }

                        // Sealing must have changed the plaintext bytes.
                        assert_ne!(plaintext, in_out[..plaintext.len()]);

                        // The same sealed record behind a 4-byte prefix, to exercise `open_within`.
                        let mut offset_cipher_text = vec![1, 2, 3, 4];
                        offset_cipher_text.extend_from_slice(&in_out);

                        let opened = opening_key
                            .open_in_place(
                                Nonce::try_assume_unique_for_key(nonce_bytes).unwrap(),
                                Aad::empty(),
                                &mut in_out,
                            )
                            .unwrap();
                        // The returned slice is exactly the plaintext (tag excluded).
                        assert_eq!(plaintext, *opened);
                        assert_eq!(plaintext, in_out[..plaintext.len()]);

                        let opened = opening_key
                            .open_within(
                                Nonce::try_assume_unique_for_key(nonce_bytes).unwrap(),
                                Aad::empty(),
                                &mut offset_cipher_text,
                                4..,
                            )
                            .unwrap();
                        // The plaintext is shifted to the front of the buffer, past the prefix.
                        assert_eq!(plaintext, *opened);
                        assert_eq!(plaintext, offset_cipher_text[..plaintext.len()]);
                    }
                }
            }
        };
    }

    test_tls_aead!(
        aes_128_gcm_tls12,
        &AES_128_GCM,
        TlsProtocolId::TLS12,
        TEST_128_BIT_KEY,
        &16,
        &12
    );
    test_tls_aead!(
        aes_128_gcm_tls13,
        &AES_128_GCM,
        TlsProtocolId::TLS13,
        TEST_128_BIT_KEY,
        &16,
        &12
    );
    test_tls_aead!(
        aes_256_gcm_tls12,
        &AES_256_GCM,
        TlsProtocolId::TLS12,
        TEST_256_BIT_KEY,
        &16,
        &12
    );
    test_tls_aead!(
        aes_256_gcm_tls13,
        &AES_256_GCM,
        TlsProtocolId::TLS13,
        TEST_256_BIT_KEY,
        &16,
        &12
    );
    test_tls_aead!(
        chacha20_poly1305_tls12,
        &CHACHA20_POLY1305,
        TlsProtocolId::TLS12,
        TEST_256_BIT_KEY
    );
    test_tls_aead!(
        chacha20_poly1305_tls13,
        &CHACHA20_POLY1305,
        TlsProtocolId::TLS13,
        TEST_256_BIT_KEY
    );
}