Skip to main content

curve25519_dalek/backend/vector/
packed_simd.rs

1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// See LICENSE for licensing information.
5
6// Nightly and stable currently disagree on the requirement of unsafe blocks when `unsafe_target_feature`
7// gets used.
8// See: https://github.com/rust-lang/rust/issues/132856
9#![allow(unused_unsafe)]
10
11//! This module defines wrappers over platform-specific SIMD types to make them
12//! more convenient to use.
13//!
14//! UNSAFETY: Everything in this module assumes that we're running on hardware
15//!           which supports at least AVX2. This invariant *must* be enforced
16//!           by the callers of this code.
17use core::ops::{Add, AddAssign, BitAnd, BitAndAssign, BitXor, BitXorAssign, Sub};
18
19use curve25519_dalek_derive::unsafe_target_feature;
20
macro_rules! impl_shared {
    (
        $name:ident,
        $lane:ident,
        $vadd:ident,
        $vsub:ident,
        $vshl:ident,
        $vshr:ident,
        $vextract:ident
    ) => {
        /// A 256-bit AVX2 register interpreted as packed unsigned integer lanes.
        ///
        /// `#[repr(transparent)]` gives it exactly the layout of the wrapped `__m256i`.
        #[allow(non_camel_case_types)]
        #[derive(Copy, Clone, Debug)]
        #[repr(transparent)]
        pub struct $name(core::arch::x86_64::__m256i);

        // Wrap a raw register without touching its bits.
        #[unsafe_target_feature("avx2")]
        impl From<core::arch::x86_64::__m256i> for $name {
            #[inline]
            fn from(raw: core::arch::x86_64::__m256i) -> $name {
                $name(raw)
            }
        }

        // Unwrap back to the raw register without touching its bits.
        #[unsafe_target_feature("avx2")]
        impl From<$name> for core::arch::x86_64::__m256i {
            #[inline]
            fn from(wrapped: $name) -> core::arch::x86_64::__m256i {
                wrapped.0
            }
        }

        #[unsafe_target_feature("avx2")]
        impl PartialEq for $name {
            #[inline]
            fn eq(&self, rhs: &$name) -> bool {
                unsafe {
                    // Compare all 32 byte lanes at once: a byte of `same` is 0xFF where the
                    // two inputs agree at that position and 0x00 where they differ.
                    let same = core::arch::x86_64::_mm256_cmpeq_epi8(self.0, rhs.0);

                    // Collapse to a scalar we can branch on: the most significant bit of each
                    // of the 32 bytes is packed into one `i32`. The vectors are equal exactly
                    // when every byte matched, i.e. all 32 bits are set, which as an `i32`
                    // is 0xFFFF_FFFF == -1.
                    let mask: i32 = core::arch::x86_64::_mm256_movemask_epi8(same);
                    mask == -1
                }
            }
        }

        impl Eq for $name {}

        // Lane-wise wrapping addition.
        #[unsafe_target_feature("avx2")]
        impl Add for $name {
            type Output = Self;

            #[inline]
            fn add(self, rhs: $name) -> Self {
                unsafe { $name(core::arch::x86_64::$vadd(self.0, rhs.0)) }
            }
        }

        // Lane-wise wrapping subtraction.
        #[unsafe_target_feature("avx2")]
        impl Sub for $name {
            type Output = Self;

            #[inline]
            fn sub(self, rhs: $name) -> Self {
                unsafe { $name(core::arch::x86_64::$vsub(self.0, rhs.0)) }
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl AddAssign for $name {
            #[inline]
            fn add_assign(&mut self, rhs: $name) {
                let sum = *self + rhs;
                *self = sum;
            }
        }

        #[unsafe_target_feature("avx2")]
        impl BitAnd for $name {
            type Output = Self;

            #[inline]
            fn bitand(self, rhs: $name) -> Self {
                unsafe { $name(core::arch::x86_64::_mm256_and_si256(self.0, rhs.0)) }
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl BitAndAssign for $name {
            #[inline]
            fn bitand_assign(&mut self, rhs: $name) {
                let masked = *self & rhs;
                *self = masked;
            }
        }

        #[unsafe_target_feature("avx2")]
        impl BitXor for $name {
            type Output = Self;

            #[inline]
            fn bitxor(self, rhs: $name) -> Self {
                unsafe { $name(core::arch::x86_64::_mm256_xor_si256(self.0, rhs.0)) }
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl BitXorAssign for $name {
            #[inline]
            fn bitxor_assign(&mut self, rhs: $name) {
                let flipped = *self ^ rhs;
                *self = flipped;
            }
        }

        #[unsafe_target_feature("avx2")]
        #[allow(dead_code)]
        impl $name {
            /// Shifts every lane left by the compile-time constant `N` bits.
            #[inline]
            pub fn shl<const N: i32>(self) -> Self {
                unsafe { $name(core::arch::x86_64::$vshl(self.0, N)) }
            }

            /// Shifts every lane right by the compile-time constant `N` bits.
            #[inline]
            pub fn shr<const N: i32>(self) -> Self {
                unsafe { $name(core::arch::x86_64::$vshr(self.0, N)) }
            }

            /// Returns lane `N` (a compile-time constant index) as an unsigned scalar.
            #[inline]
            pub fn extract<const N: i32>(self) -> $lane {
                let raw_lane = unsafe { core::arch::x86_64::$vextract(self.0, N) };
                // The intrinsic yields the signed lane type; reinterpret its bits as unsigned.
                raw_lane as $lane
            }
        }
    };
}
172
/// Generates `From<$from>` for each listed destination vector type.
///
/// The conversion is a pure reinterpretation of the same 256-bit register: no bits move,
/// only the lane view changes.
macro_rules! impl_conv {
    ($from:ident => $($to:ident),+) => {
        $(
            #[unsafe_target_feature("avx2")]
            impl From<$from> for $to {
                #[inline]
                fn from(src: $from) -> $to {
                    let raw = src.0;
                    $to(raw)
                }
            }
        )+
    };
}
186
// We define SIMD functionality over packed unsigned integer types. However, all the integer
// intrinsics deal with signed integers. So we cast unsigned to signed, pack it into SIMD, do
// add/sub/shl/shr arithmetic, and finally cast back to unsigned at the end. Why is this equivalent
// to doing the same thing on unsigned integers? Shl/shr is clear: casting does not change the
// bits of the integer, and the shift intrinsics we use (`slli`/`srli`) are logical shifts that
// fill with zeros, exactly like unsigned `<<`/`>>`. But what about add/sub? This is due to the
// following:
//
//     1) Rust uses two's complement to represent signed integers. So we're assured that the values
//        we cast into SIMD and extract out at the end are two's complement.
//
//        https://doc.rust-lang.org/reference/types/numeric.html
//
//     2) Wrapping add/sub is compatible between two's complement signed and unsigned integers.
//        That is, for all x,y: u64 (or any unsigned integer type),
//
//            x.wrapping_add(y) == (x as i64).wrapping_add(y as i64) as u64, and
//            x.wrapping_sub(y) == (x as i64).wrapping_sub(y as i64) as u64
//
//        https://julesjacobs.com/2019/03/20/why-twos-complement-works.html
//
//     3) The add/sub functions we use for SIMD are indeed wrapping. The docs indicate that
//        `_mm256_add_epiN`/`_mm256_sub_epiN` compile to vpaddX/vpsubX instructions, where
//        X = w, d, or q depending on the bitwidth. From the x86 docs:
//
//            When an individual result is too large to be represented in X bits (overflow), the
//            result is wrapped around and the low X bits are written to the destination operand
//            (that is, the carry is ignored).
//
//        https://www.felixcloutier.com/x86/paddb:paddw:paddd:paddq
//        https://www.felixcloutier.com/x86/psubb:psubw:psubd
//        https://www.felixcloutier.com/x86/psubq
217
218impl_shared!(
219    u64x4,
220    u64,
221    _mm256_add_epi64,
222    _mm256_sub_epi64,
223    _mm256_slli_epi64,
224    _mm256_srli_epi64,
225    _mm256_extract_epi64
226);
227impl_shared!(
228    u32x8,
229    u32,
230    _mm256_add_epi32,
231    _mm256_sub_epi32,
232    _mm256_slli_epi32,
233    _mm256_srli_epi32,
234    _mm256_extract_epi32
235);
236
237impl_conv!(u64x4 => u32x8);
238
239#[allow(dead_code)]
240impl u64x4 {
241    /// A constified variant of `new`.
242    ///
243    /// Should only be called from `const` contexts. At runtime `new` is going to be faster.
244    #[inline]
245    pub const fn new_const(x0: u64, x1: u64, x2: u64, x3: u64) -> Self {
246        // SAFETY: Transmuting between an array and a SIMD type is safe
247        // https://rust-lang.github.io/unsafe-code-guidelines/layout/packed-simd-vectors.html
248        unsafe {
249            Self(core::mem::transmute::<[u64; 4], core::arch::x86_64::__m256i>([x0, x1, x2, x3]))
250        }
251    }
252
253    /// A constified variant of `splat`.
254    ///
255    /// Should only be called from `const` contexts. At runtime `splat` is going to be faster.
256    #[inline]
257    pub const fn splat_const<const N: u64>() -> Self {
258        Self::new_const(N, N, N, N)
259    }
260
261    /// Constructs a new instance.
262    #[unsafe_target_feature("avx2")]
263    #[inline]
264    pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 {
265        unsafe {
266            // _mm256_set_epi64 sets the underlying vector in reverse order of the args
267            u64x4(core::arch::x86_64::_mm256_set_epi64x(
268                x3 as i64, x2 as i64, x1 as i64, x0 as i64,
269            ))
270        }
271    }
272
273    /// Constructs a new instance with all of the elements initialized to the given value.
274    #[unsafe_target_feature("avx2")]
275    #[inline]
276    pub fn splat(x: u64) -> u64x4 {
277        unsafe { u64x4(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) }
278    }
279}
280
281#[allow(dead_code)]
282impl u32x8 {
283    /// A constified variant of `new`.
284    ///
285    /// Should only be called from `const` contexts. At runtime `new` is going to be faster.
286    #[allow(clippy::too_many_arguments)]
287    #[inline]
288    pub const fn new_const(
289        x0: u32,
290        x1: u32,
291        x2: u32,
292        x3: u32,
293        x4: u32,
294        x5: u32,
295        x6: u32,
296        x7: u32,
297    ) -> Self {
298        // SAFETY: Transmuting between an array and a SIMD type is safe
299        // https://rust-lang.github.io/unsafe-code-guidelines/layout/packed-simd-vectors.html
300        unsafe {
301            Self(
302                core::mem::transmute::<[u32; 8], core::arch::x86_64::__m256i>([
303                    x0, x1, x2, x3, x4, x5, x6, x7,
304                ]),
305            )
306        }
307    }
308
309    /// A constified variant of `splat`.
310    ///
311    /// Should only be called from `const` contexts. At runtime `splat` is going to be faster.
312    #[inline]
313    pub const fn splat_const<const N: u32>() -> Self {
314        Self::new_const(N, N, N, N, N, N, N, N)
315    }
316
317    /// Constructs a new instance.
318    #[allow(clippy::too_many_arguments)]
319    #[unsafe_target_feature("avx2")]
320    #[inline]
321    pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 {
322        unsafe {
323            // _mm256_set_epi32 sets the underlying vector in reverse order of the args
324            u32x8(core::arch::x86_64::_mm256_set_epi32(
325                x7 as i32, x6 as i32, x5 as i32, x4 as i32, x3 as i32, x2 as i32, x1 as i32,
326                x0 as i32,
327            ))
328        }
329    }
330
331    /// Constructs a new instance with all of the elements initialized to the given value.
332    #[unsafe_target_feature("avx2")]
333    #[inline]
334    pub fn splat(x: u32) -> u32x8 {
335        unsafe { u32x8(core::arch::x86_64::_mm256_set1_epi32(x as i32)) }
336    }
337}
338
339#[unsafe_target_feature("avx2")]
340impl u32x8 {
341    /// Multiplies the low unsigned 32-bits from each packed 64-bit element
342    /// and returns the unsigned 64-bit results.
343    ///
344    /// (This ignores the upper 32-bits from each packed 64-bits!)
345    #[inline]
346    pub fn mul32(self, rhs: u32x8) -> u64x4 {
347        // NOTE: This ignores the upper 32-bits from each packed 64-bits.
348        unsafe { core::arch::x86_64::_mm256_mul_epu32(self.0, rhs.0).into() }
349    }
350}