curve25519_dalek/backend/vector/packed_simd.rs
1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// See LICENSE for licensing information.
5
6// Nightly and stable currently disagree on the requirement of unsafe blocks when `unsafe_target_feature`
7// gets used.
8// See: https://github.com/rust-lang/rust/issues/132856
9#![allow(unused_unsafe)]
10
11//! This module defines wrappers over platform-specific SIMD types to make them
12//! more convenient to use.
13//!
14//! UNSAFETY: Everything in this module assumes that we're running on hardware
15//! which supports at least AVX2. This invariant *must* be enforced
16//! by the callers of this code.
17use core::ops::{Add, AddAssign, BitAnd, BitAndAssign, BitXor, BitXorAssign, Sub};
18
19use curve25519_dalek_derive::unsafe_target_feature;
20
macro_rules! impl_shared {
    (
        $ty:ident,
        $lane_ty:ident,
        $add_intrinsic:ident,
        $sub_intrinsic:ident,
        $shl_intrinsic:ident,
        $shr_intrinsic:ident,
        $extract_intrinsic:ident
    ) => {
        /// Newtype around a raw `__m256i` value; the lane width is fixed by the intrinsics
        /// this macro is instantiated with.
        #[allow(non_camel_case_types)]
        #[derive(Copy, Clone, Debug)]
        #[repr(transparent)]
        pub struct $ty(core::arch::x86_64::__m256i);

        #[unsafe_target_feature("avx2")]
        impl From<$ty> for core::arch::x86_64::__m256i {
            #[inline]
            fn from(wrapped: $ty) -> core::arch::x86_64::__m256i {
                // Peel off the newtype and hand back the raw register value.
                let $ty(raw) = wrapped;
                raw
            }
        }

        #[unsafe_target_feature("avx2")]
        impl From<core::arch::x86_64::__m256i> for $ty {
            #[inline]
            fn from(raw: core::arch::x86_64::__m256i) -> $ty {
                $ty(raw)
            }
        }

        #[unsafe_target_feature("avx2")]
        impl PartialEq for $ty {
            #[inline]
            fn eq(&self, other: &$ty) -> bool {
                unsafe {
                    // Byte-wise comparison of the two 256-bit values: each output byte is
                    // 0xFF where the corresponding input bytes match and 0x00 where they
                    // differ. The operands are therefore equal if (and only if) every
                    // output byte is 0xFF.
                    let bytes_eq = core::arch::x86_64::_mm256_cmpeq_epi8(self.0, other.0);

                    // Reduce the 256-bit comparison result to something we can branch on:
                    // the most significant bit of each of the 32 bytes is gathered into an
                    // `i32`. If all bytes matched, all of those bits are 1, i.e. the mask
                    // is 0xFFFFFFFF, which is -1 when represented as an `i32`.
                    let mask = core::arch::x86_64::_mm256_movemask_epi8(bytes_eq);
                    mask == -1
                }
            }
        }

        impl Eq for $ty {}

        #[unsafe_target_feature("avx2")]
        impl Add for $ty {
            type Output = Self;

            #[inline]
            fn add(self, other: $ty) -> Self {
                let sum = unsafe { core::arch::x86_64::$add_intrinsic(self.0, other.0) };
                $ty(sum)
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl AddAssign for $ty {
            #[inline]
            fn add_assign(&mut self, other: $ty) {
                // Defined in terms of `Add` so both operators share one implementation.
                *self = *self + other;
            }
        }

        #[unsafe_target_feature("avx2")]
        impl Sub for $ty {
            type Output = Self;

            #[inline]
            fn sub(self, other: $ty) -> Self {
                let difference = unsafe { core::arch::x86_64::$sub_intrinsic(self.0, other.0) };
                $ty(difference)
            }
        }

        #[unsafe_target_feature("avx2")]
        impl BitAnd for $ty {
            type Output = Self;

            #[inline]
            fn bitand(self, other: $ty) -> Self {
                let masked = unsafe { core::arch::x86_64::_mm256_and_si256(self.0, other.0) };
                $ty(masked)
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl BitAndAssign for $ty {
            #[inline]
            fn bitand_assign(&mut self, other: $ty) {
                // Defined in terms of `BitAnd`.
                *self = *self & other;
            }
        }

        #[unsafe_target_feature("avx2")]
        impl BitXor for $ty {
            type Output = Self;

            #[inline]
            fn bitxor(self, other: $ty) -> Self {
                let toggled = unsafe { core::arch::x86_64::_mm256_xor_si256(self.0, other.0) };
                $ty(toggled)
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl BitXorAssign for $ty {
            #[inline]
            fn bitxor_assign(&mut self, other: $ty) {
                // Defined in terms of `BitXor`.
                *self = *self ^ other;
            }
        }

        #[unsafe_target_feature("avx2")]
        #[allow(dead_code)]
        impl $ty {
            /// Applies the left-shift intrinsic with the constant shift amount `N`.
            #[inline]
            pub fn shl<const N: i32>(self) -> Self {
                let shifted = unsafe { core::arch::x86_64::$shl_intrinsic(self.0, N) };
                $ty(shifted)
            }

            /// Applies the right-shift intrinsic with the constant shift amount `N`.
            #[inline]
            pub fn shr<const N: i32>(self) -> Self {
                let shifted = unsafe { core::arch::x86_64::$shr_intrinsic(self.0, N) };
                $ty(shifted)
            }

            /// Extracts the lane at constant index `N`, cast to the unsigned lane type.
            #[inline]
            pub fn extract<const N: i32>(self) -> $lane_ty {
                // The intrinsic yields a signed integer; the cast only reinterprets it.
                let lane = unsafe { core::arch::x86_64::$extract_intrinsic(self.0, N) };
                lane as $lane_ty
            }
        }
    };
}
172
// Generates `From<$src>` for each listed destination type by re-wrapping the same
// underlying register value; no bits are changed by the conversion.
macro_rules! impl_conv {
    ($src:ident => $($dst:ident),+) => {
        $(
            #[unsafe_target_feature("avx2")]
            impl From<$src> for $dst {
                #[inline]
                fn from(value: $src) -> $dst {
                    // Unwrap the source newtype and wrap its contents in the destination.
                    let $src(raw) = value;
                    $dst(raw)
                }
            }
        )+
    };
}
186
187// We define SIMD functionality over packed unsigned integer types. However, all the integer
188// intrinsics deal with signed integers. So we cast unsigned to signed, pack it into SIMD, do
189// add/sub/shl/shr arithmetic, and finally cast back to unsigned at the end. Why is this equivalent
190// to doing the same thing on unsigned integers? Shl/shr is clear, because casting does not change
191// the bits of the integer. But what about add/sub? This is due to the following:
192//
193// 1) Rust uses two's complement to represent signed integers. So we're assured that the values
194// we cast into SIMD and extract out at the end are two's complement.
195//
196// https://doc.rust-lang.org/reference/types/numeric.html
197//
198// 2) Wrapping add/sub is compatible between two's complement signed and unsigned integers.
199// That is, for all x,y: u64 (or any unsigned integer type),
200//
201// x.wrapping_add(y) == (x as i64).wrapping_add(y as i64) as u64, and
202// x.wrapping_sub(y) == (x as i64).wrapping_sub(y as i64) as u64
203//
204// https://julesjacobs.com/2019/03/20/why-twos-complement-works.html
205//
// 3) The add/sub functions we use for SIMD are indeed wrapping. The docs indicate that
//    _mm256_add_epi{16,32,64}/_mm256_sub_epi{16,32,64} compile to vpaddX/vpsubX instructions
//    where X = w, d, or q depending on the bitwidth. From x86 docs:
209//
210// When an individual result is too large to be represented in X bits (overflow), the
211// result is wrapped around and the low X bits are written to the destination operand
212// (that is, the carry is ignored).
213//
214// https://www.felixcloutier.com/x86/paddb:paddw:paddd:paddq
215// https://www.felixcloutier.com/x86/psubb:psubw:psubd
216// https://www.felixcloutier.com/x86/psubq
217
218impl_shared!(
219 u64x4,
220 u64,
221 _mm256_add_epi64,
222 _mm256_sub_epi64,
223 _mm256_slli_epi64,
224 _mm256_srli_epi64,
225 _mm256_extract_epi64
226);
227impl_shared!(
228 u32x8,
229 u32,
230 _mm256_add_epi32,
231 _mm256_sub_epi32,
232 _mm256_slli_epi32,
233 _mm256_srli_epi32,
234 _mm256_extract_epi32
235);
236
237impl_conv!(u64x4 => u32x8);
238
239#[allow(dead_code)]
240impl u64x4 {
241 /// A constified variant of `new`.
242 ///
243 /// Should only be called from `const` contexts. At runtime `new` is going to be faster.
244 #[inline]
245 pub const fn new_const(x0: u64, x1: u64, x2: u64, x3: u64) -> Self {
246 // SAFETY: Transmuting between an array and a SIMD type is safe
247 // https://rust-lang.github.io/unsafe-code-guidelines/layout/packed-simd-vectors.html
248 unsafe {
249 Self(core::mem::transmute::<[u64; 4], core::arch::x86_64::__m256i>([x0, x1, x2, x3]))
250 }
251 }
252
253 /// A constified variant of `splat`.
254 ///
255 /// Should only be called from `const` contexts. At runtime `splat` is going to be faster.
256 #[inline]
257 pub const fn splat_const<const N: u64>() -> Self {
258 Self::new_const(N, N, N, N)
259 }
260
261 /// Constructs a new instance.
262 #[unsafe_target_feature("avx2")]
263 #[inline]
264 pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 {
265 unsafe {
266 // _mm256_set_epi64 sets the underlying vector in reverse order of the args
267 u64x4(core::arch::x86_64::_mm256_set_epi64x(
268 x3 as i64, x2 as i64, x1 as i64, x0 as i64,
269 ))
270 }
271 }
272
273 /// Constructs a new instance with all of the elements initialized to the given value.
274 #[unsafe_target_feature("avx2")]
275 #[inline]
276 pub fn splat(x: u64) -> u64x4 {
277 unsafe { u64x4(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) }
278 }
279}
280
281#[allow(dead_code)]
282impl u32x8 {
283 /// A constified variant of `new`.
284 ///
285 /// Should only be called from `const` contexts. At runtime `new` is going to be faster.
286 #[allow(clippy::too_many_arguments)]
287 #[inline]
288 pub const fn new_const(
289 x0: u32,
290 x1: u32,
291 x2: u32,
292 x3: u32,
293 x4: u32,
294 x5: u32,
295 x6: u32,
296 x7: u32,
297 ) -> Self {
298 // SAFETY: Transmuting between an array and a SIMD type is safe
299 // https://rust-lang.github.io/unsafe-code-guidelines/layout/packed-simd-vectors.html
300 unsafe {
301 Self(
302 core::mem::transmute::<[u32; 8], core::arch::x86_64::__m256i>([
303 x0, x1, x2, x3, x4, x5, x6, x7,
304 ]),
305 )
306 }
307 }
308
309 /// A constified variant of `splat`.
310 ///
311 /// Should only be called from `const` contexts. At runtime `splat` is going to be faster.
312 #[inline]
313 pub const fn splat_const<const N: u32>() -> Self {
314 Self::new_const(N, N, N, N, N, N, N, N)
315 }
316
317 /// Constructs a new instance.
318 #[allow(clippy::too_many_arguments)]
319 #[unsafe_target_feature("avx2")]
320 #[inline]
321 pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 {
322 unsafe {
323 // _mm256_set_epi32 sets the underlying vector in reverse order of the args
324 u32x8(core::arch::x86_64::_mm256_set_epi32(
325 x7 as i32, x6 as i32, x5 as i32, x4 as i32, x3 as i32, x2 as i32, x1 as i32,
326 x0 as i32,
327 ))
328 }
329 }
330
331 /// Constructs a new instance with all of the elements initialized to the given value.
332 #[unsafe_target_feature("avx2")]
333 #[inline]
334 pub fn splat(x: u32) -> u32x8 {
335 unsafe { u32x8(core::arch::x86_64::_mm256_set1_epi32(x as i32)) }
336 }
337}
338
339#[unsafe_target_feature("avx2")]
340impl u32x8 {
341 /// Multiplies the low unsigned 32-bits from each packed 64-bit element
342 /// and returns the unsigned 64-bit results.
343 ///
344 /// (This ignores the upper 32-bits from each packed 64-bits!)
345 #[inline]
346 pub fn mul32(self, rhs: u32x8) -> u64x4 {
347 // NOTE: This ignores the upper 32-bits from each packed 64-bits.
348 unsafe { core::arch::x86_64::_mm256_mul_epu32(self.0, rhs.0).into() }
349 }
350}