Skip to main content

aes/x86/ni/
encdec.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3use crate::Block;
4use crate::x86::arch::*;
5use cipher::{
6    array::{Array, ArraySize},
7    inout::InOut,
8};
9
10#[target_feature(enable = "aes")]
11pub(crate) unsafe fn encrypt<const KEYS: usize>(
12    keys: &[__m128i; KEYS],
13    block: InOut<'_, '_, Block>,
14) {
15    assert!(KEYS == 11 || KEYS == 13 || KEYS == 15);
16
17    let (block_in, block_out) = block.into_raw();
18    let mut b = _mm_loadu_si128(block_in.cast());
19    b = _mm_xor_si128(b, keys[0]);
20    for &key in &keys[1..KEYS - 1] {
21        b = _mm_aesenc_si128(b, key);
22    }
23    b = _mm_aesenclast_si128(b, keys[KEYS - 1]);
24    _mm_storeu_si128(block_out.cast(), b);
25}
26
27#[target_feature(enable = "aes")]
28pub(crate) unsafe fn decrypt<const KEYS: usize>(
29    keys: &[__m128i; KEYS],
30    block: InOut<'_, '_, Block>,
31) {
32    assert!(KEYS == 11 || KEYS == 13 || KEYS == 15);
33
34    let (block_in, block_out) = block.into_raw();
35    let mut b = _mm_loadu_si128(block_in.cast());
36    b = _mm_xor_si128(b, keys[0]);
37    for &key in &keys[1..KEYS - 1] {
38        b = _mm_aesdec_si128(b, key);
39    }
40    b = _mm_aesdeclast_si128(b, keys[KEYS - 1]);
41    _mm_storeu_si128(block_out.cast(), b);
42}
43
44#[target_feature(enable = "aes")]
45pub(crate) unsafe fn encrypt_par<const KEYS: usize, ParBlocks: ArraySize>(
46    keys: &[__m128i; KEYS],
47    blocks: InOut<'_, '_, Array<Block, ParBlocks>>,
48) {
49    assert!(KEYS == 11 || KEYS == 13 || KEYS == 15);
50
51    let (blocks_in, blocks_out) = blocks.into_raw();
52    let mut b = load(blocks_in);
53
54    // Loop over keys is intentionally not used here to force inlining
55    xor(&mut b, keys[0]);
56    aesenc(&mut b, keys[1]);
57    aesenc(&mut b, keys[2]);
58    aesenc(&mut b, keys[3]);
59    aesenc(&mut b, keys[4]);
60    aesenc(&mut b, keys[5]);
61    aesenc(&mut b, keys[6]);
62    aesenc(&mut b, keys[7]);
63    aesenc(&mut b, keys[8]);
64    aesenc(&mut b, keys[9]);
65    if KEYS >= 13 {
66        aesenc(&mut b, keys[10]);
67        aesenc(&mut b, keys[11]);
68    }
69    if KEYS == 15 {
70        aesenc(&mut b, keys[12]);
71        aesenc(&mut b, keys[13]);
72    }
73    aesenclast(&mut b, keys[KEYS - 1]);
74    store(blocks_out, b);
75}
76
77#[target_feature(enable = "aes")]
78pub(crate) unsafe fn decrypt_par<const KEYS: usize, ParBlocks: ArraySize>(
79    keys: &[__m128i; KEYS],
80    blocks: InOut<'_, '_, Array<Block, ParBlocks>>,
81) {
82    assert!(KEYS == 11 || KEYS == 13 || KEYS == 15);
83
84    let (blocks_in, blocks_out) = blocks.into_raw();
85    let mut b = load(blocks_in);
86
87    // Loop over keys is intentionally not used here to force inlining
88    xor(&mut b, keys[0]);
89    aesdec(&mut b, keys[1]);
90    aesdec(&mut b, keys[2]);
91    aesdec(&mut b, keys[3]);
92    aesdec(&mut b, keys[4]);
93    aesdec(&mut b, keys[5]);
94    aesdec(&mut b, keys[6]);
95    aesdec(&mut b, keys[7]);
96    aesdec(&mut b, keys[8]);
97    aesdec(&mut b, keys[9]);
98    if KEYS >= 13 {
99        aesdec(&mut b, keys[10]);
100        aesdec(&mut b, keys[11]);
101    }
102    if KEYS == 15 {
103        aesdec(&mut b, keys[12]);
104        aesdec(&mut b, keys[13]);
105    }
106    aesdeclast(&mut b, keys[KEYS - 1]);
107    store(blocks_out, b);
108}
109
110#[target_feature(enable = "sse2")]
111pub(crate) unsafe fn load<N: ArraySize>(blocks: *const Array<Block, N>) -> Array<__m128i, N> {
112    let p = blocks.cast::<__m128i>();
113    let mut res: Array<__m128i, N> = core::mem::zeroed();
114    for i in 0..N::USIZE {
115        res[i] = _mm_loadu_si128(p.add(i));
116    }
117    res
118}
119
120#[target_feature(enable = "sse2")]
121pub(crate) unsafe fn store<N: ArraySize>(blocks: *mut Array<Block, N>, b: Array<__m128i, N>) {
122    let p = blocks.cast::<__m128i>();
123    for i in 0..N::USIZE {
124        _mm_storeu_si128(p.add(i), b[i]);
125    }
126}
127
128#[target_feature(enable = "sse2")]
129pub(crate) unsafe fn xor<N: ArraySize>(blocks: &mut Array<__m128i, N>, key: __m128i) {
130    for block in blocks {
131        *block = _mm_xor_si128(*block, key);
132    }
133}
134
135#[target_feature(enable = "aes")]
136pub(crate) unsafe fn aesenc<N: ArraySize>(blocks: &mut Array<__m128i, N>, key: __m128i) {
137    for block in blocks {
138        *block = _mm_aesenc_si128(*block, key);
139    }
140}
141
142#[target_feature(enable = "aes")]
143pub(crate) unsafe fn aesenclast<N: ArraySize>(blocks: &mut Array<__m128i, N>, key: __m128i) {
144    for block in blocks {
145        *block = _mm_aesenclast_si128(*block, key);
146    }
147}
148
149#[target_feature(enable = "aes")]
150pub(crate) unsafe fn aesdec<N: ArraySize>(blocks: &mut Array<__m128i, N>, key: __m128i) {
151    for block in blocks {
152        *block = _mm_aesdec_si128(*block, key);
153    }
154}
155
156#[target_feature(enable = "aes")]
157pub(crate) unsafe fn aesdeclast<N: ArraySize>(blocks: &mut Array<__m128i, N>, key: __m128i) {
158    for block in blocks {
159        *block = _mm_aesdeclast_si128(*block, key);
160    }
161}