Skip to main content

aes/x86/ni/
expand.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3use crate::x86::arch::*;
4use core::mem::{transmute, zeroed};
5
6pub(super) type Aes128RoundKeys = [__m128i; 11];
7pub(super) type Aes192RoundKeys = [__m128i; 13];
8pub(super) type Aes256RoundKeys = [__m128i; 15];
9
10pub(crate) mod aes128 {
11    use super::*;
12
13    #[target_feature(enable = "aes")]
14    pub(crate) unsafe fn expand_key(key: &[u8; 16]) -> Aes128RoundKeys {
15        unsafe fn expand_round<const RK: i32>(keys: &mut Aes128RoundKeys, pos: usize) {
16            let mut t1 = keys[pos - 1];
17            let mut t2;
18            let mut t3;
19
20            t2 = _mm_aeskeygenassist_si128(t1, RK);
21            t2 = _mm_shuffle_epi32(t2, 0xff);
22            t3 = _mm_slli_si128(t1, 0x4);
23            t1 = _mm_xor_si128(t1, t3);
24            t3 = _mm_slli_si128(t3, 0x4);
25            t1 = _mm_xor_si128(t1, t3);
26            t3 = _mm_slli_si128(t3, 0x4);
27            t1 = _mm_xor_si128(t1, t3);
28            t1 = _mm_xor_si128(t1, t2);
29
30            keys[pos] = t1;
31        }
32
33        let mut keys: Aes128RoundKeys = zeroed();
34        let k = _mm_loadu_si128(key.as_ptr().cast());
35        keys[0] = k;
36
37        let kr = &mut keys;
38        expand_round::<0x01>(kr, 1);
39        expand_round::<0x02>(kr, 2);
40        expand_round::<0x04>(kr, 3);
41        expand_round::<0x08>(kr, 4);
42        expand_round::<0x10>(kr, 5);
43        expand_round::<0x20>(kr, 6);
44        expand_round::<0x40>(kr, 7);
45        expand_round::<0x80>(kr, 8);
46        expand_round::<0x1B>(kr, 9);
47        expand_round::<0x36>(kr, 10);
48
49        keys
50    }
51}
52
53pub(crate) mod aes192 {
54    use super::*;
55
56    #[target_feature(enable = "aes")]
57    pub(crate) unsafe fn expand_key(key: &[u8; 24]) -> Aes192RoundKeys {
58        unsafe fn shuffle(a: __m128i, b: __m128i, i: usize) -> __m128i {
59            let a: [u64; 2] = transmute(a);
60            let b: [u64; 2] = transmute(b);
61            transmute([a[i], b[0]])
62        }
63
64        #[target_feature(enable = "aes")]
65        unsafe fn expand_round<const RK: i32>(
66            mut t1: __m128i,
67            mut t3: __m128i,
68        ) -> (__m128i, __m128i) {
69            let (mut t2, mut t4);
70
71            t2 = _mm_aeskeygenassist_si128(t3, RK);
72            t2 = _mm_shuffle_epi32(t2, 0x55);
73            t4 = _mm_slli_si128(t1, 0x4);
74            t1 = _mm_xor_si128(t1, t4);
75            t4 = _mm_slli_si128(t4, 0x4);
76            t1 = _mm_xor_si128(t1, t4);
77            t4 = _mm_slli_si128(t4, 0x4);
78            t1 = _mm_xor_si128(t1, t4);
79            t1 = _mm_xor_si128(t1, t2);
80            t2 = _mm_shuffle_epi32(t1, 0xff);
81            t4 = _mm_slli_si128(t3, 0x4);
82            t3 = _mm_xor_si128(t3, t4);
83            t3 = _mm_xor_si128(t3, t2);
84
85            (t1, t3)
86        }
87
88        let mut keys: Aes192RoundKeys = zeroed();
89        // We are being extra pedantic here to remove out-of-bound access.
90        // This should be optimized into movups, movsd sequence.
91        let (k0, k1l) = {
92            let mut t = [0u8; 32];
93            t[..key.len()].copy_from_slice(key);
94            (
95                _mm_loadu_si128(t.as_ptr().cast()),
96                _mm_loadu_si128(t.as_ptr().offset(16).cast()),
97            )
98        };
99
100        keys[0] = k0;
101
102        let (k1_2, k2r) = expand_round::<0x01>(k0, k1l);
103        keys[1] = shuffle(k1l, k1_2, 0);
104        keys[2] = shuffle(k1_2, k2r, 1);
105
106        let (k3, k4l) = expand_round::<0x02>(k1_2, k2r);
107        keys[3] = k3;
108
109        let (k4_5, k5r) = expand_round::<0x04>(k3, k4l);
110        let k4 = shuffle(k4l, k4_5, 0);
111        let k5 = shuffle(k4_5, k5r, 1);
112        keys[4] = k4;
113        keys[5] = k5;
114
115        let (k6, k7l) = expand_round::<0x08>(k4_5, k5r);
116        keys[6] = k6;
117
118        let (k7_8, k8r) = expand_round::<0x10>(k6, k7l);
119        keys[7] = shuffle(k7l, k7_8, 0);
120        keys[8] = shuffle(k7_8, k8r, 1);
121
122        let (k9, k10l) = expand_round::<0x20>(k7_8, k8r);
123        keys[9] = k9;
124
125        let (k10_11, k11r) = expand_round::<0x40>(k9, k10l);
126        keys[10] = shuffle(k10l, k10_11, 0);
127        keys[11] = shuffle(k10_11, k11r, 1);
128
129        let (k12, _) = expand_round::<0x80>(k10_11, k11r);
130        keys[12] = k12;
131
132        keys
133    }
134}
135
136pub(crate) mod aes256 {
137    use super::*;
138
139    #[target_feature(enable = "aes")]
140    pub(crate) unsafe fn expand_key(key: &[u8; 32]) -> Aes256RoundKeys {
141        unsafe fn expand_round<const RK: i32>(keys: &mut Aes256RoundKeys, pos: usize) {
142            let mut t1 = keys[pos - 2];
143            let mut t2;
144            let mut t3 = keys[pos - 1];
145            let mut t4;
146
147            t2 = _mm_aeskeygenassist_si128(t3, RK);
148            t2 = _mm_shuffle_epi32(t2, 0xff);
149            t4 = _mm_slli_si128(t1, 0x4);
150            t1 = _mm_xor_si128(t1, t4);
151            t4 = _mm_slli_si128(t4, 0x4);
152            t1 = _mm_xor_si128(t1, t4);
153            t4 = _mm_slli_si128(t4, 0x4);
154            t1 = _mm_xor_si128(t1, t4);
155            t1 = _mm_xor_si128(t1, t2);
156
157            keys[pos] = t1;
158
159            t4 = _mm_aeskeygenassist_si128(t1, 0x00);
160            t2 = _mm_shuffle_epi32(t4, 0xaa);
161            t4 = _mm_slli_si128(t3, 0x4);
162            t3 = _mm_xor_si128(t3, t4);
163            t4 = _mm_slli_si128(t4, 0x4);
164            t3 = _mm_xor_si128(t3, t4);
165            t4 = _mm_slli_si128(t4, 0x4);
166            t3 = _mm_xor_si128(t3, t4);
167            t3 = _mm_xor_si128(t3, t2);
168
169            keys[pos + 1] = t3;
170        }
171
172        unsafe fn expand_round_last<const RK: i32>(keys: &mut Aes256RoundKeys, pos: usize) {
173            let mut t1 = keys[pos - 2];
174            let mut t2;
175            let t3 = keys[pos - 1];
176            let mut t4;
177
178            t2 = _mm_aeskeygenassist_si128(t3, RK);
179            t2 = _mm_shuffle_epi32(t2, 0xff);
180            t4 = _mm_slli_si128(t1, 0x4);
181            t1 = _mm_xor_si128(t1, t4);
182            t4 = _mm_slli_si128(t4, 0x4);
183            t1 = _mm_xor_si128(t1, t4);
184            t4 = _mm_slli_si128(t4, 0x4);
185            t1 = _mm_xor_si128(t1, t4);
186            t1 = _mm_xor_si128(t1, t2);
187
188            keys[pos] = t1;
189        }
190
191        let mut keys: Aes256RoundKeys = zeroed();
192
193        let kp = key.as_ptr().cast::<__m128i>();
194        keys[0] = _mm_loadu_si128(kp);
195        keys[1] = _mm_loadu_si128(kp.add(1));
196
197        let k = &mut keys;
198        expand_round::<0x01>(k, 2);
199        expand_round::<0x02>(k, 4);
200        expand_round::<0x04>(k, 6);
201        expand_round::<0x08>(k, 8);
202        expand_round::<0x10>(k, 10);
203        expand_round::<0x20>(k, 12);
204        expand_round_last::<0x40>(k, 14);
205
206        keys
207    }
208}
209
/// Builds the decryption round keys from an encryption key schedule.
///
/// The output is `keys` in reverse order. Every key except the first and last
/// output entry is passed through `_mm_aesimc_si128` (InvMixColumns). This is the
/// schedule layout used with `aesdec`/`aesdeclast`; presumably that is how the
/// callers consume it — confirm at call sites.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `aes` target feature.
///
/// # Panics
///
/// Panics if `N == 0`.
#[target_feature(enable = "aes")]
pub(crate) unsafe fn inv_keys<const N: usize>(keys: &[__m128i; N]) -> [__m128i; N] {
    let last = N - 1;
    let mut reversed: [__m128i; N] = zeroed();

    // The outermost keys are swapped verbatim; only the inner ones get InvMixColumns.
    reversed[0] = keys[last];
    for (i, slot) in reversed.iter_mut().enumerate().take(last).skip(1) {
        *slot = _mm_aesimc_si128(keys[last - i]);
    }
    reversed[last] = keys[0];

    reversed
}