1use crate::aws_lc::{AES_set_decrypt_key, AES_set_encrypt_key, AES_KEY};
5use crate::cipher::block::Block;
6use crate::cipher::chacha::ChaCha20Key;
7use crate::cipher::{AES_128_KEY_LEN, AES_192_KEY_LEN, AES_256_KEY_LEN};
8use crate::error::Unspecified;
9use core::mem::{size_of, MaybeUninit};
10use core::ptr::copy_nonoverlapping;
11use std::os::raw::c_uint;
14use zeroize::Zeroize;
15
16pub(crate) enum SymmetricCipherKey {
17 Aes128 { enc_key: AES_KEY, dec_key: AES_KEY },
18 Aes192 { enc_key: AES_KEY, dec_key: AES_KEY },
19 Aes256 { enc_key: AES_KEY, dec_key: AES_KEY },
20 ChaCha20 { raw_key: ChaCha20Key },
21}
22
23unsafe impl Send for SymmetricCipherKey {}
24
25unsafe impl Sync for SymmetricCipherKey {}
27
28impl Drop for SymmetricCipherKey {
29 fn drop(&mut self) {
30 match self {
32 SymmetricCipherKey::Aes128 { enc_key, dec_key }
33 | SymmetricCipherKey::Aes192 { enc_key, dec_key }
34 | SymmetricCipherKey::Aes256 { enc_key, dec_key } => unsafe {
35 let enc_bytes: &mut [u8; size_of::<AES_KEY>()] = (enc_key as *mut AES_KEY)
36 .cast::<[u8; size_of::<AES_KEY>()]>()
37 .as_mut()
38 .unwrap();
39 enc_bytes.zeroize();
40 let dec_bytes: &mut [u8; size_of::<AES_KEY>()] = (dec_key as *mut AES_KEY)
41 .cast::<[u8; size_of::<AES_KEY>()]>()
42 .as_mut()
43 .unwrap();
44 dec_bytes.zeroize();
45 },
46 SymmetricCipherKey::ChaCha20 { .. } => {}
47 }
48 }
49}
50
51impl SymmetricCipherKey {
52 fn aes(key_bytes: &[u8]) -> Result<(AES_KEY, AES_KEY), Unspecified> {
53 let mut enc_key = MaybeUninit::<AES_KEY>::uninit();
54 let mut dec_key = MaybeUninit::<AES_KEY>::uninit();
55 #[allow(clippy::cast_possible_truncation)]
56 if unsafe {
57 0 != AES_set_encrypt_key(
58 key_bytes.as_ptr(),
59 (key_bytes.len() * 8) as c_uint,
60 enc_key.as_mut_ptr(),
61 )
62 } {
63 return Err(Unspecified);
64 }
65
66 #[allow(clippy::cast_possible_truncation)]
67 if unsafe {
68 0 != AES_set_decrypt_key(
69 key_bytes.as_ptr(),
70 (key_bytes.len() * 8) as c_uint,
71 dec_key.as_mut_ptr(),
72 )
73 } {
74 return Err(Unspecified);
75 }
76 unsafe { Ok((enc_key.assume_init(), dec_key.assume_init())) }
77 }
78
79 pub(crate) fn aes128(key_bytes: &[u8]) -> Result<Self, Unspecified> {
80 if key_bytes.len() != AES_128_KEY_LEN {
81 return Err(Unspecified);
82 }
83 let (enc_key, dec_key) = SymmetricCipherKey::aes(key_bytes)?;
84 Ok(SymmetricCipherKey::Aes128 { enc_key, dec_key })
85 }
86
87 pub(crate) fn aes192(key_bytes: &[u8]) -> Result<Self, Unspecified> {
88 if key_bytes.len() != AES_192_KEY_LEN {
89 return Err(Unspecified);
90 }
91 let (enc_key, dec_key) = SymmetricCipherKey::aes(key_bytes)?;
92 Ok(SymmetricCipherKey::Aes192 { enc_key, dec_key })
93 }
94
95 pub(crate) fn aes256(key_bytes: &[u8]) -> Result<Self, Unspecified> {
96 if key_bytes.len() != AES_256_KEY_LEN {
97 return Err(Unspecified);
98 }
99 let (enc_key, dec_key) = SymmetricCipherKey::aes(key_bytes)?;
100 Ok(SymmetricCipherKey::Aes256 { enc_key, dec_key })
101 }
102
103 pub(crate) fn chacha20(key_bytes: &[u8]) -> Result<Self, Unspecified> {
104 if key_bytes.len() != 32 {
105 return Err(Unspecified);
106 }
107 let mut kb = MaybeUninit::<[u8; 32]>::uninit();
108 unsafe {
109 copy_nonoverlapping(key_bytes.as_ptr(), kb.as_mut_ptr().cast(), 32);
110 Ok(SymmetricCipherKey::ChaCha20 {
111 raw_key: ChaCha20Key(kb.assume_init()),
112 })
113 }
114 }
115
116 #[allow(dead_code)]
117 #[inline]
118 pub(crate) fn encrypt_block(&self, block: Block) -> Block {
119 match self {
120 SymmetricCipherKey::Aes128 { enc_key, .. }
121 | SymmetricCipherKey::Aes192 { enc_key, .. }
122 | SymmetricCipherKey::Aes256 { enc_key, .. } => {
123 super::aes::encrypt_block(enc_key, block)
124 }
125 SymmetricCipherKey::ChaCha20 { .. } => panic!("Unsupported algorithm!"),
126 }
127 }
128}
129
#[cfg(test)]
mod tests {
    use crate::cipher::block::{Block, BLOCK_LEN};
    use crate::cipher::key::SymmetricCipherKey;
    use crate::test::from_hex;

    /// Encrypts the hex-encoded block `plaintext_hex` with `cipher_key` and
    /// asserts the output equals the hex-encoded `ciphertext_hex`.
    fn assert_block_encrypts_to(
        cipher_key: &SymmetricCipherKey,
        plaintext_hex: &str,
        ciphertext_hex: &str,
    ) {
        let plaintext = from_hex(plaintext_hex).unwrap();
        let expected = from_hex(ciphertext_hex).unwrap();
        let block_bytes = <[u8; BLOCK_LEN]>::try_from(plaintext).unwrap();

        let actual = cipher_key.encrypt_block(Block::from(block_bytes));

        assert_eq!(expected.as_slice(), actual.as_ref());
    }

    /// Known-answer check for single-block encryption under a 16-byte key.
    #[test]
    fn test_encrypt_block_aes_128() {
        let key_bytes = from_hex("000102030405060708090a0b0c0d0e0f").unwrap();
        let aes128 = SymmetricCipherKey::aes128(key_bytes.as_slice()).unwrap();

        assert_block_encrypts_to(
            &aes128,
            "00112233445566778899aabbccddeeff",
            "69c4e0d86a7b0430d8cdb78070b4c55a",
        );
    }

    /// Known-answer check for single-block encryption under a 32-byte key.
    #[test]
    fn test_encrypt_block_aes_256() {
        let key_bytes =
            from_hex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f").unwrap();
        let aes256 = SymmetricCipherKey::aes256(key_bytes.as_slice()).unwrap();

        assert_block_encrypts_to(
            &aes256,
            "00112233445566778899aabbccddeeff",
            "8ea2b7ca516745bfeafc49904b496089",
        );
    }
}