// fearless_simd/core_arch/x86/avx.rs

// Copyright 2024 the Fearless_SIMD Authors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! Access to AVX intrinsics.

use crate::impl_macros::delegate;
#[cfg(target_arch = "x86")]
use core::arch::x86 as arch;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64 as arch;

use arch::*;

/// A token for AVX intrinsics on `x86` and `x86_64`.
///
/// Possessing a value of this type is proof that the AVX target feature is
/// available, which is what makes the delegated intrinsics safe to call.
/// The private field prevents safe construction outside this module.
#[derive(Clone, Copy, Debug)]
pub struct Avx {
    _private: (),
}

20#[expect(
21    clippy::missing_safety_doc,
22    reason = "TODO: https://github.com/linebender/fearless_simd/issues/40"
23)]
24impl Avx {
25    /// Create a SIMD token.
26    ///
27    /// # Safety
28    ///
29    /// The required CPU features must be available.
30    #[inline]
31    pub unsafe fn new_unchecked() -> Self {
32        Self { _private: () }
33    }
34
35    delegate! { arch:
36        fn _mm256_add_pd(a: __m256d, b: __m256d) -> __m256d;
37        fn _mm256_add_ps(a: __m256, b: __m256) -> __m256;
38        fn _mm256_and_pd(a: __m256d, b: __m256d) -> __m256d;
39        fn _mm256_and_ps(a: __m256, b: __m256) -> __m256;
40        fn _mm256_or_pd(a: __m256d, b: __m256d) -> __m256d;
41        fn _mm256_or_ps(a: __m256, b: __m256) -> __m256;
42        fn _mm256_shuffle_pd<const MASK: i32>(a: __m256d, b: __m256d) -> __m256d;
43        fn _mm256_shuffle_ps<const MASK: i32>(a: __m256, b: __m256) -> __m256;
44        fn _mm256_andnot_pd(a: __m256d, b: __m256d) -> __m256d;
45        fn _mm256_andnot_ps(a: __m256, b: __m256) -> __m256;
46        fn _mm256_max_pd(a: __m256d, b: __m256d) -> __m256d;
47        fn _mm256_max_ps(a: __m256, b: __m256) -> __m256;
48        fn _mm256_min_pd(a: __m256d, b: __m256d) -> __m256d;
49        fn _mm256_min_ps(a: __m256, b: __m256) -> __m256;
50        fn _mm256_mul_pd(a: __m256d, b: __m256d) -> __m256d;
51        fn _mm256_mul_ps(a: __m256, b: __m256) -> __m256;
52        fn _mm256_addsub_pd(a: __m256d, b: __m256d) -> __m256d;
53        fn _mm256_addsub_ps(a: __m256, b: __m256) -> __m256;
54        fn _mm256_sub_pd(a: __m256d, b: __m256d) -> __m256d;
55        fn _mm256_sub_ps(a: __m256, b: __m256) -> __m256;
56        fn _mm256_div_ps(a: __m256, b: __m256) -> __m256;
57        fn _mm256_div_pd(a: __m256d, b: __m256d) -> __m256d;
58        fn _mm256_round_pd<const ROUNDING: i32>(a: __m256d) -> __m256d;
59        fn _mm256_ceil_pd(a: __m256d) -> __m256d;
60        fn _mm256_floor_pd(a: __m256d) -> __m256d;
61        fn _mm256_round_ps<const ROUNDING: i32>(a: __m256) -> __m256;
62        fn _mm256_ceil_ps(a: __m256) -> __m256;
63        fn _mm256_floor_ps(a: __m256) -> __m256;
64        fn _mm256_sqrt_ps(a: __m256) -> __m256;
65        fn _mm256_sqrt_pd(a: __m256d) -> __m256d;
66        fn _mm256_blend_pd<const IMM4: i32>(a: __m256d, b: __m256d) -> __m256d;
67        fn _mm256_blend_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256;
68        fn _mm256_blendv_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d;
69        fn _mm256_blendv_ps(a: __m256, b: __m256, c: __m256) -> __m256;
70        fn _mm256_dp_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256;
71        fn _mm256_hadd_pd(a: __m256d, b: __m256d) -> __m256d;
72        fn _mm256_hadd_ps(a: __m256, b: __m256) -> __m256;
73        fn _mm256_hsub_pd(a: __m256d, b: __m256d) -> __m256d;
74        fn _mm256_hsub_ps(a: __m256, b: __m256) -> __m256;
75        fn _mm256_xor_pd(a: __m256d, b: __m256d) -> __m256d;
76        fn _mm256_xor_ps(a: __m256, b: __m256) -> __m256;
77        fn _mm_cmp_pd<const IMM5: i32>(a: __m128d, b: __m128d) -> __m128d;
78        fn _mm256_cmp_pd<const IMM5: i32>(a: __m256d, b: __m256d) -> __m256d;
79        fn _mm_cmp_ps<const IMM5: i32>(a: __m128, b: __m128) -> __m128;
80        fn _mm256_cmp_ps<const IMM5: i32>(a: __m256, b: __m256) -> __m256;
81        fn _mm_cmp_sd<const IMM5: i32>(a: __m128d, b: __m128d) -> __m128d;
82        fn _mm_cmp_ss<const IMM5: i32>(a: __m128, b: __m128) -> __m128;
83        fn _mm256_cvtepi32_pd(a: __m128i) -> __m256d;
84        fn _mm256_cvtepi32_ps(a: __m256i) -> __m256;
85        fn _mm256_cvtpd_ps(a: __m256d) -> __m128;
86        fn _mm256_cvtps_epi32(a: __m256) -> __m256i;
87        fn _mm256_cvtps_pd(a: __m128) -> __m256d;
88        fn _mm256_cvttpd_epi32(a: __m256d) -> __m128i;
89        fn _mm256_cvtpd_epi32(a: __m256d) -> __m128i;
90        fn _mm256_cvttps_epi32(a: __m256) -> __m256i;
91        fn _mm256_extractf128_ps<const IMM1: i32>(a: __m256) -> __m128;
92        fn _mm256_extractf128_pd<const IMM1: i32>(a: __m256d) -> __m128d;
93        fn _mm256_extractf128_si256<const IMM1: i32>(a: __m256i) -> __m128i;
94        fn _mm256_zeroall();
95        fn _mm256_zeroupper();
96        fn _mm256_permutevar_ps(a: __m256, b: __m256i) -> __m256;
97        fn _mm_permutevar_ps(a: __m128, b: __m128i) -> __m128;
98        fn _mm256_permute_ps<const IMM8: i32>(a: __m256) -> __m256;
99        fn _mm_permute_ps<const IMM8: i32>(a: __m128) -> __m128;
100        fn _mm256_permutevar_pd(a: __m256d, b: __m256i) -> __m256d;
101        fn _mm_permutevar_pd(a: __m128d, b: __m128i) -> __m128d;
102        fn _mm256_permute_pd<const IMM4: i32>(a: __m256d) -> __m256d;
103        fn _mm_permute_pd<const IMM2: i32>(a: __m128d) -> __m128d;
104        fn _mm256_permute2f128_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256;
105        fn _mm256_permute2f128_pd<const IMM8: i32>(a: __m256d, b: __m256d) -> __m256d;
106        fn _mm256_permute2f128_si256<const IMM8: i32>(a: __m256i, b: __m256i) -> __m256i;
107        fn _mm256_broadcast_ss(f: &f32) -> __m256;
108        fn _mm_broadcast_ss(f: &f32) -> __m128;
109        fn _mm256_broadcast_sd(f: &f64) -> __m256d;
110        fn _mm256_broadcast_ps(a: &__m128) -> __m256;
111        fn _mm256_broadcast_pd(a: &__m128d) -> __m256d;
112        fn _mm256_insertf128_ps<const IMM1: i32>(a: __m256, b: __m128) -> __m256;
113        fn _mm256_insertf128_pd<const IMM1: i32>(a: __m256d, b: __m128d) -> __m256d;
114        fn _mm256_insertf128_si256<const IMM1: i32>(a: __m256i, b: __m128i) -> __m256i;
115        fn _mm256_insert_epi8<const INDEX: i32>(a: __m256i, i: i8) -> __m256i;
116        fn _mm256_insert_epi16<const INDEX: i32>(a: __m256i, i: i16) -> __m256i;
117        fn _mm256_insert_epi32<const INDEX: i32>(a: __m256i, i: i32) -> __m256i;
118        unsafe fn _mm256_load_pd(mem_addr: *const f64) -> __m256d;
119        unsafe fn _mm256_store_pd(mem_addr: *mut f64, a: __m256d);
120        unsafe fn _mm256_load_ps(mem_addr: *const f32) -> __m256;
121        unsafe fn _mm256_store_ps(mem_addr: *mut f32, a: __m256);
122        unsafe fn _mm256_loadu_pd(mem_addr: *const f64) -> __m256d;
123        unsafe fn _mm256_storeu_pd(mem_addr: *mut f64, a: __m256d);
124        unsafe fn _mm256_loadu_ps(mem_addr: *const f32) -> __m256;
125        unsafe fn _mm256_storeu_ps(mem_addr: *mut f32, a: __m256);
126        unsafe fn _mm256_load_si256(mem_addr: *const __m256i) -> __m256i;
127        unsafe fn _mm256_store_si256(mem_addr: *mut __m256i, a: __m256i);
128        unsafe fn _mm256_loadu_si256(mem_addr: *const __m256i) -> __m256i;
129        unsafe fn _mm256_storeu_si256(mem_addr: *mut __m256i, a: __m256i);
130        unsafe fn _mm256_maskload_pd(mem_addr: *const f64, mask: __m256i) -> __m256d;
131        unsafe fn _mm256_maskstore_pd(mem_addr: *mut f64, mask: __m256i, a: __m256d);
132        unsafe fn _mm_maskload_pd(mem_addr: *const f64, mask: __m128i) -> __m128d;
133        unsafe fn _mm_maskstore_pd(mem_addr: *mut f64, mask: __m128i, a: __m128d);
134        unsafe fn _mm256_maskload_ps(mem_addr: *const f32, mask: __m256i) -> __m256;
135        unsafe fn _mm256_maskstore_ps(mem_addr: *mut f32, mask: __m256i, a: __m256);
136        unsafe fn _mm_maskload_ps(mem_addr: *const f32, mask: __m128i) -> __m128;
137        unsafe fn _mm_maskstore_ps(mem_addr: *mut f32, mask: __m128i, a: __m128);
138        fn _mm256_movehdup_ps(a: __m256) -> __m256;
139        fn _mm256_moveldup_ps(a: __m256) -> __m256;
140        fn _mm256_movedup_pd(a: __m256d) -> __m256d;
141        unsafe fn _mm256_lddqu_si256(mem_addr: *const __m256i) -> __m256i;
142        unsafe fn _mm256_stream_si256(mem_addr: *mut __m256i, a: __m256i);
143        unsafe fn _mm256_stream_pd(mem_addr: *mut f64, a: __m256d);
144        unsafe fn _mm256_stream_ps(mem_addr: *mut f32, a: __m256);
145        fn _mm256_rcp_ps(a: __m256) -> __m256;
146        fn _mm256_rsqrt_ps(a: __m256) -> __m256;
147        fn _mm256_unpackhi_pd(a: __m256d, b: __m256d) -> __m256d;
148        fn _mm256_unpackhi_ps(a: __m256, b: __m256) -> __m256;
149        fn _mm256_unpacklo_pd(a: __m256d, b: __m256d) -> __m256d;
150        fn _mm256_unpacklo_ps(a: __m256, b: __m256) -> __m256;
151        fn _mm256_testz_si256(a: __m256i, b: __m256i) -> i32;
152        fn _mm256_testc_si256(a: __m256i, b: __m256i) -> i32;
153        fn _mm256_testnzc_si256(a: __m256i, b: __m256i) -> i32;
154        fn _mm256_testz_pd(a: __m256d, b: __m256d) -> i32;
155        fn _mm256_testc_pd(a: __m256d, b: __m256d) -> i32;
156        fn _mm256_testnzc_pd(a: __m256d, b: __m256d) -> i32;
157        fn _mm_testz_pd(a: __m128d, b: __m128d) -> i32;
158        fn _mm_testc_pd(a: __m128d, b: __m128d) -> i32;
159        fn _mm_testnzc_pd(a: __m128d, b: __m128d) -> i32;
160        fn _mm256_testz_ps(a: __m256, b: __m256) -> i32;
161        fn _mm256_testc_ps(a: __m256, b: __m256) -> i32;
162        fn _mm256_testnzc_ps(a: __m256, b: __m256) -> i32;
163        fn _mm_testz_ps(a: __m128, b: __m128) -> i32;
164        fn _mm_testc_ps(a: __m128, b: __m128) -> i32;
165        fn _mm_testnzc_ps(a: __m128, b: __m128) -> i32;
166        fn _mm256_movemask_pd(a: __m256d) -> i32;
167        fn _mm256_movemask_ps(a: __m256) -> i32;
168        fn _mm256_setzero_pd() -> __m256d;
169        fn _mm256_setzero_ps() -> __m256;
170        fn _mm256_setzero_si256() -> __m256i;
171        fn _mm256_set_pd(a: f64, b: f64, c: f64, d: f64) -> __m256d;
172        fn _mm256_set_ps(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> __m256;
173        fn _mm256_set_epi8(
174            e00: i8,
175            e01: i8,
176            e02: i8,
177            e03: i8,
178            e04: i8,
179            e05: i8,
180            e06: i8,
181            e07: i8,
182            e08: i8,
183            e09: i8,
184            e10: i8,
185            e11: i8,
186            e12: i8,
187            e13: i8,
188            e14: i8,
189            e15: i8,
190            e16: i8,
191            e17: i8,
192            e18: i8,
193            e19: i8,
194            e20: i8,
195            e21: i8,
196            e22: i8,
197            e23: i8,
198            e24: i8,
199            e25: i8,
200            e26: i8,
201            e27: i8,
202            e28: i8,
203            e29: i8,
204            e30: i8,
205            e31: i8,
206        ) -> __m256i;
207        fn _mm256_set_epi16(
208            e00: i16,
209            e01: i16,
210            e02: i16,
211            e03: i16,
212            e04: i16,
213            e05: i16,
214            e06: i16,
215            e07: i16,
216            e08: i16,
217            e09: i16,
218            e10: i16,
219            e11: i16,
220            e12: i16,
221            e13: i16,
222            e14: i16,
223            e15: i16,
224        ) -> __m256i;
225        fn _mm256_set_epi32(
226            e0: i32,
227            e1: i32,
228            e2: i32,
229            e3: i32,
230            e4: i32,
231            e5: i32,
232            e6: i32,
233            e7: i32,
234        ) -> __m256i;
235        fn _mm256_set_epi64x(a: i64, b: i64, c: i64, d: i64) -> __m256i;
236        fn _mm256_setr_pd(a: f64, b: f64, c: f64, d: f64) -> __m256d;
237        fn _mm256_setr_ps(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32)
238        -> __m256;
239        fn _mm256_setr_epi8(
240            e00: i8,
241            e01: i8,
242            e02: i8,
243            e03: i8,
244            e04: i8,
245            e05: i8,
246            e06: i8,
247            e07: i8,
248            e08: i8,
249            e09: i8,
250            e10: i8,
251            e11: i8,
252            e12: i8,
253            e13: i8,
254            e14: i8,
255            e15: i8,
256            e16: i8,
257            e17: i8,
258            e18: i8,
259            e19: i8,
260            e20: i8,
261            e21: i8,
262            e22: i8,
263            e23: i8,
264            e24: i8,
265            e25: i8,
266            e26: i8,
267            e27: i8,
268            e28: i8,
269            e29: i8,
270            e30: i8,
271            e31: i8,
272        ) -> __m256i;
273        fn _mm256_setr_epi16(
274            e00: i16,
275            e01: i16,
276            e02: i16,
277            e03: i16,
278            e04: i16,
279            e05: i16,
280            e06: i16,
281            e07: i16,
282            e08: i16,
283            e09: i16,
284            e10: i16,
285            e11: i16,
286            e12: i16,
287            e13: i16,
288            e14: i16,
289            e15: i16,
290        ) -> __m256i;
291        fn _mm256_setr_epi32(
292            e0: i32,
293            e1: i32,
294            e2: i32,
295            e3: i32,
296            e4: i32,
297            e5: i32,
298            e6: i32,
299            e7: i32,
300        ) -> __m256i;
301        fn _mm256_setr_epi64x(a: i64, b: i64, c: i64, d: i64) -> __m256i;
302        fn _mm256_set1_pd(a: f64) -> __m256d;
303        fn _mm256_set1_ps(a: f32) -> __m256;
304        fn _mm256_set1_epi8(a: i8) -> __m256i;
305        fn _mm256_set1_epi16(a: i16) -> __m256i;
306        fn _mm256_set1_epi32(a: i32) -> __m256i;
307        fn _mm256_set1_epi64x(a: i64) -> __m256i;
308        fn _mm256_castpd_ps(a: __m256d) -> __m256;
309        fn _mm256_castps_pd(a: __m256) -> __m256d;
310        fn _mm256_castps_si256(a: __m256) -> __m256i;
311        fn _mm256_castsi256_ps(a: __m256i) -> __m256;
312        fn _mm256_castpd_si256(a: __m256d) -> __m256i;
313        fn _mm256_castsi256_pd(a: __m256i) -> __m256d;
314        fn _mm256_castps256_ps128(a: __m256) -> __m128;
315        fn _mm256_castpd256_pd128(a: __m256d) -> __m128d;
316        fn _mm256_castsi256_si128(a: __m256i) -> __m128i;
317        fn _mm256_castps128_ps256(a: __m128) -> __m256;
318        fn _mm256_castpd128_pd256(a: __m128d) -> __m256d;
319        fn _mm256_castsi128_si256(a: __m128i) -> __m256i;
320        fn _mm256_zextps128_ps256(a: __m128) -> __m256;
321        fn _mm256_zextsi128_si256(a: __m128i) -> __m256i;
322        fn _mm256_zextpd128_pd256(a: __m128d) -> __m256d;
323        fn _mm256_undefined_ps() -> __m256;
324        fn _mm256_undefined_pd() -> __m256d;
325        fn _mm256_undefined_si256() -> __m256i;
326        fn _mm256_set_m128(hi: __m128, lo: __m128) -> __m256;
327        fn _mm256_set_m128d(hi: __m128d, lo: __m128d) -> __m256d;
328        fn _mm256_set_m128i(hi: __m128i, lo: __m128i) -> __m256i;
329        fn _mm256_setr_m128(lo: __m128, hi: __m128) -> __m256;
330        fn _mm256_setr_m128d(lo: __m128d, hi: __m128d) -> __m256d;
331        fn _mm256_setr_m128i(lo: __m128i, hi: __m128i) -> __m256i;
332        unsafe fn _mm256_loadu2_m128(hiaddr: *const f32, loaddr: *const f32) -> __m256;
333        unsafe fn _mm256_loadu2_m128d(hiaddr: *const f64, loaddr: *const f64) -> __m256d;
334        unsafe fn _mm256_loadu2_m128i(hiaddr: *const __m128i, loaddr: *const __m128i) -> __m256i;
335        unsafe fn _mm256_storeu2_m128(hiaddr: *mut f32, loaddr: *mut f32, a: __m256);
336        unsafe fn _mm256_storeu2_m128d(hiaddr: *mut f64, loaddr: *mut f64, a: __m256d);
337        unsafe fn _mm256_storeu2_m128i(hiaddr: *mut __m128i, loaddr: *mut __m128i, a: __m256i);
338        fn _mm256_cvtss_f32(a: __m256) -> f32;
339    }
340}