1use bytemuck::cast;
9
10use super::{i32x8, u32x8};
11
cfg_if::cfg_if! {
    if #[cfg(all(feature = "simd", target_feature = "avx"))] {
        #[cfg(target_arch = "x86")]
        use core::arch::x86::*;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;

        /// Eight `f32` lanes backed by a single 256-bit AVX register.
        #[derive(Clone, Copy, Debug)]
        #[repr(C, align(32))]
        pub struct f32x8(__m256);
    } else {
        use super::f32x4;

        /// Eight `f32` lanes emulated as two 128-bit `f32x4` halves
        /// (low lanes 0-3 in `.0`, high lanes 4-7 in `.1`).
        #[derive(Clone, Copy, Debug)]
        #[repr(C, align(32))]
        pub struct f32x8(pub f32x4, pub f32x4);
    }
}
30
// SAFETY: `f32x8` is `#[repr(C, align(32))]` and holds only `f32` lane data
// in both representations; the all-zero bit pattern is a valid value
// (eight `0.0` lanes).
unsafe impl bytemuck::Zeroable for f32x8 {}
// SAFETY: every bit pattern is a valid `f32`, so any 32-byte pattern is a
// valid `f32x8`. Assumes the fallback `f32x4` halves are themselves padding-
// free `Pod` (TODO confirm in `super::f32x4`).
unsafe impl bytemuck::Pod for f32x8 {}
33
34impl Default for f32x8 {
35 fn default() -> Self {
36 Self::splat(0.0)
37 }
38}
39
impl f32x8 {
    /// Returns a vector with all eight lanes set to `n`.
    pub fn splat(n: f32) -> Self {
        cast([n, n, n, n, n, n, n, n])
    }

    /// Lane-wise `floor` (largest integer value not greater than the lane).
    pub fn floor(self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "simd128"))] {
                // On WASM simd128 the two-halves representation is active and
                // each `f32x4` half has a native floor.
                Self(self.0.floor(), self.1.floor())
            } else {
                // Truncate toward zero, then subtract 1.0 from lanes where the
                // truncated value overshot, i.e. negative non-integer lanes.
                // NOTE(review): this round-trips through i32, so lanes outside
                // the i32 range are not handled faithfully — confirm callers
                // keep magnitudes small.
                let roundtrip: f32x8 = cast(self.trunc_int().to_f32x8());
                roundtrip
                    - roundtrip
                        .cmp_gt(self)
                        .blend(f32x8::splat(1.0), f32x8::default())
            }
        }
    }

    /// Lane-wise fractional part, computed as `self - self.floor()`.
    pub fn fract(self) -> Self {
        self - self.floor()
    }

    /// Clamps every lane to the `[0.0, 1.0]` range.
    pub fn normalize(self) -> Self {
        self.max(f32x8::default()).min(f32x8::splat(1.0))
    }

    /// Reinterprets the lane bits as `i32` lanes (no numeric conversion).
    pub fn to_i32x8_bitcast(self) -> i32x8 {
        bytemuck::cast(self)
    }

    /// Reinterprets the lane bits as `u32` lanes (no numeric conversion).
    pub fn to_u32x8_bitcast(self) -> u32x8 {
        bytemuck::cast(self)
    }

    /// Lane-wise `==`. Result lanes are all-ones bits where equal, all-zeros
    /// otherwise. On the AVX path the ordered predicate is used, so NaN lanes
    /// compare unequal.
    pub fn cmp_eq(self, rhs: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_cmp_ps(self.0, rhs.0, _CMP_EQ_OQ) })
            } else {
                Self(self.0.cmp_eq(rhs.0), self.1.cmp_eq(rhs.1))
            }
        }
    }

    /// Lane-wise `!=` mask. On the AVX path the unordered predicate is used,
    /// so NaN lanes compare as not-equal (mask set).
    pub fn cmp_ne(self, rhs: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_cmp_ps(self.0, rhs.0, _CMP_NEQ_UQ) })
            } else {
                Self(self.0.cmp_ne(rhs.0), self.1.cmp_ne(rhs.1))
            }
        }
    }

    /// Lane-wise `>=` mask (ordered: NaN lanes yield an unset mask on AVX).
    pub fn cmp_ge(self, rhs: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_cmp_ps(self.0, rhs.0, _CMP_GE_OQ) })
            } else {
                Self(self.0.cmp_ge(rhs.0), self.1.cmp_ge(rhs.1))
            }
        }
    }

    /// Lane-wise `>` mask (ordered: NaN lanes yield an unset mask on AVX).
    pub fn cmp_gt(self, rhs: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_cmp_ps(self.0, rhs.0, _CMP_GT_OQ) })
            } else {
                Self(self.0.cmp_gt(rhs.0), self.1.cmp_gt(rhs.1))
            }
        }
    }

    /// Lane-wise `<=` mask (ordered: NaN lanes yield an unset mask on AVX).
    pub fn cmp_le(self, rhs: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_cmp_ps(self.0, rhs.0, _CMP_LE_OQ) })
            } else {
                Self(self.0.cmp_le(rhs.0), self.1.cmp_le(rhs.1))
            }
        }
    }

    /// Lane-wise `<` mask (ordered: NaN lanes yield an unset mask on AVX).
    pub fn cmp_lt(self, rhs: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_cmp_ps(self.0, rhs.0, _CMP_LT_OQ) })
            } else {
                Self(self.0.cmp_lt(rhs.0), self.1.cmp_lt(rhs.1))
            }
        }
    }

    /// Per-lane select: takes `t`'s lane where the corresponding lane of
    /// `self` (a mask from the `cmp_*` methods) is set, `f`'s lane otherwise.
    /// The AVX `blendv` selects by the mask lane's high (sign) bit only.
    #[inline]
    pub fn blend(self, t: Self, f: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_blendv_ps(f.0, t.0, self.0) })
            } else {
                Self(self.0.blend(t.0, f.0), self.1.blend(t.1, f.1))
            }
        }
    }

    /// Lane-wise absolute value: clears the sign bit with a `0x7fff_ffff` mask.
    pub fn abs(self) -> Self {
        let non_sign_bits = f32x8::splat(f32::from_bits(i32::MAX as u32));
        self & non_sign_bits
    }

    /// Lane-wise maximum.
    /// NOTE(review): the AVX path follows x86 `MAXPS` semantics — for a NaN
    /// lane the second operand (`rhs`) is returned; confirm the fallback
    /// `f32x4::max` matches if NaN behavior matters.
    pub fn max(self, rhs: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_max_ps(self.0, rhs.0) })
            } else {
                Self(self.0.max(rhs.0), self.1.max(rhs.1))
            }
        }
    }

    /// Lane-wise minimum (see the NaN note on [`f32x8::max`] — x86 `MINPS`
    /// returns the second operand for NaN lanes on the AVX path).
    pub fn min(self, rhs: Self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_min_ps(self.0, rhs.0) })
            } else {
                Self(self.0.min(rhs.0), self.1.min(rhs.1))
            }
        }
    }

    /// Returns a lane mask that is all-ones where the lane is finite
    /// (neither infinity nor NaN).
    pub fn is_finite(self) -> Self {
        // A lane is non-finite exactly when its exponent bits are all ones.
        // Shift the sign bit out so the exponent occupies the top 8 bits,
        // then compare the masked exponent field against all-ones and negate.
        let shifted_exp_mask = u32x8::splat(0xFF000000);
        let u: u32x8 = cast(self);
        let shift_u = u.shl::<1>();
        let out = !(shift_u & shifted_exp_mask).cmp_eq(shifted_exp_mask);
        cast(out)
    }

    /// Lane-wise round to nearest integer value. The AVX path uses
    /// round-to-nearest-even without raising exceptions.
    pub fn round(self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_round_ps(self.0, _MM_FROUND_NO_EXC | _MM_FROUND_TO_NEAREST_INT) })
            } else {
                Self(self.0.round(), self.1.round())
            }
        }
    }

    /// Converts each lane to `i32`, rounding to nearest.
    /// NOTE(review): on the AVX path, out-of-range/NaN lanes produce the x86
    /// "integer indefinite" value (`i32::MIN`) — confirm the fallback agrees.
    pub fn round_int(self) -> i32x8 {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                cast(unsafe { _mm256_cvtps_epi32(self.0) })
            } else {
                i32x8(self.0.round_int(), self.1.round_int())
            }
        }
    }

    /// Converts each lane to `i32`, truncating toward zero (same out-of-range
    /// caveat as [`f32x8::round_int`]).
    pub fn trunc_int(self) -> i32x8 {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                cast(unsafe { _mm256_cvttps_epi32(self.0) })
            } else {
                i32x8(self.0.trunc_int(), self.1.trunc_int())
            }
        }
    }

    /// Fast *approximate* lane-wise reciprocal. The AVX `rcp` instruction has
    /// a maximum relative error of 1.5 * 2^-12 — do not use where full
    /// precision is required.
    pub fn recip_fast(self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_rcp_ps(self.0) })
            } else {
                Self(self.0.recip_fast(), self.1.recip_fast())
            }
        }
    }

    /// Fast *approximate* lane-wise reciprocal square root (AVX `rsqrt`,
    /// maximum relative error 1.5 * 2^-12).
    pub fn recip_sqrt(self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_rsqrt_ps(self.0) })
            } else {
                Self(self.0.recip_sqrt(), self.1.recip_sqrt())
            }
        }
    }

    /// Lane-wise square root (fully precise, unlike [`f32x8::recip_sqrt`]).
    pub fn sqrt(self) -> Self {
        cfg_if::cfg_if! {
            if #[cfg(all(feature = "simd", target_feature = "avx"))] {
                Self(unsafe { _mm256_sqrt_ps(self.0) })
            } else {
                Self(self.0.sqrt(), self.1.sqrt())
            }
        }
    }

    /// Approximate lane-wise `self.powf(exp)`, computed as
    /// `exp2(exp * log2(self))` with float bit tricks (the constants appear
    /// to follow the "fastapprox" fastlog2/fastpow2 polynomials — this is
    /// NOT exact). Lanes equal to `0.0` or `1.0` bypass the approximation
    /// and are returned unchanged.
    pub fn powf(self, exp: f32) -> Self {
        let x = self;
        // Bitcast-to-int scaled by 2^-23: puts the biased exponent (plus a
        // mantissa fraction) on a log2-like scale.
        let e = x.to_i32x8_bitcast().to_f32x8() * f32x8::splat(1.0f32 / ((1 << 23) as f32));
        // Mantissa re-tagged with exponent 0x3f000000, i.e. forced into [0.5, 1.0).
        let m = (x.to_u32x8_bitcast() & u32x8::splat(0x007fffff) | u32x8::splat(0x3f000000))
            .to_f32x8_bitcast();

        // Rational correction yielding an approximate log2(x).
        let log2_x = e
            - f32x8::splat(124.225514990)
            - f32x8::splat(1.498030302) * m
            - f32x8::splat(1.725879990) / (f32x8::splat(0.3520887068) + m);

        // pow(x, exp) == exp2(exp * log2(x)).
        let x = log2_x * f32x8::splat(exp);

        // Fractional part of the exponent, used by the exp2 correction terms.
        let f = x - x.floor();

        // Build the approximate exp2 result as a *bit pattern* expressed in
        // float arithmetic, then scale into the exponent field.
        let mut a = x + f32x8::splat(121.274057500);
        a = a - f * f32x8::splat(1.490129070);
        a += f32x8::splat(27.728023300) / (f32x8::splat(4.84252568) - f);
        a *= f32x8::splat((1 << 23) as f32);

        // Ceiling for the synthesized bit pattern: the bits of +infinity.
        let inf_bits = f32x8::splat(f32::INFINITY.to_bits() as f32);

        // Clamp to [0, +inf bits], round to integer bits, and reinterpret as f32.
        let x = a
            .max(f32x8::splat(0.0))
            .min(inf_bits)
            .round_int()
            .to_f32x8_bitcast();

        // 0.0^exp and 1.0^exp lanes return the original lane value exactly.
        let skip = self.cmp_eq(f32x8::splat(0.0)) | self.cmp_eq(f32x8::splat(1.0));
        skip.blend(self, x)
    }
}
284
impl From<[f32; 8]> for f32x8 {
    /// Builds a vector from eight lanes (a bit-for-bit reinterpretation,
    /// no numeric conversion).
    fn from(v: [f32; 8]) -> Self {
        cast(v)
    }
}
290
impl From<f32x8> for [f32; 8] {
    /// Extracts the eight lanes as an array (a bit-for-bit reinterpretation,
    /// no numeric conversion).
    fn from(v: f32x8) -> Self {
        cast(v)
    }
}
296
297impl core::ops::Add for f32x8 {
298 type Output = Self;
299
300 fn add(self, rhs: Self) -> Self::Output {
301 cfg_if::cfg_if! {
302 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
303 Self(unsafe { _mm256_add_ps(self.0, rhs.0) })
304 } else {
305 Self(self.0 + rhs.0, self.1 + rhs.1)
306 }
307 }
308 }
309}
310
311impl core::ops::AddAssign for f32x8 {
312 fn add_assign(&mut self, rhs: f32x8) {
313 *self = *self + rhs;
314 }
315}
316
317impl core::ops::Sub for f32x8 {
318 type Output = Self;
319
320 fn sub(self, rhs: Self) -> Self::Output {
321 cfg_if::cfg_if! {
322 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
323 Self(unsafe { _mm256_sub_ps(self.0, rhs.0) })
324 } else {
325 Self(self.0 - rhs.0, self.1 - rhs.1)
326 }
327 }
328 }
329}
330
331impl core::ops::Mul for f32x8 {
332 type Output = Self;
333
334 fn mul(self, rhs: Self) -> Self::Output {
335 cfg_if::cfg_if! {
336 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
337 Self(unsafe { _mm256_mul_ps(self.0, rhs.0) })
338 } else {
339 Self(self.0 * rhs.0, self.1 * rhs.1)
340 }
341 }
342 }
343}
344
345impl core::ops::MulAssign for f32x8 {
346 fn mul_assign(&mut self, rhs: f32x8) {
347 *self = *self * rhs;
348 }
349}
350
351impl core::ops::Div for f32x8 {
352 type Output = Self;
353
354 fn div(self, rhs: Self) -> Self::Output {
355 cfg_if::cfg_if! {
356 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
357 Self(unsafe { _mm256_div_ps(self.0, rhs.0) })
358 } else {
359 Self(self.0 / rhs.0, self.1 / rhs.1)
360 }
361 }
362 }
363}
364
365impl core::ops::BitAnd for f32x8 {
366 type Output = Self;
367
368 #[inline(always)]
369 fn bitand(self, rhs: Self) -> Self::Output {
370 cfg_if::cfg_if! {
371 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
372 Self(unsafe { _mm256_and_ps(self.0, rhs.0) })
373 } else {
374 Self(self.0 & rhs.0, self.1 & rhs.1)
375 }
376 }
377 }
378}
379
380impl core::ops::BitOr for f32x8 {
381 type Output = Self;
382
383 #[inline(always)]
384 fn bitor(self, rhs: Self) -> Self::Output {
385 cfg_if::cfg_if! {
386 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
387 Self(unsafe { _mm256_or_ps(self.0, rhs.0) })
388 } else {
389 Self(self.0 | rhs.0, self.1 | rhs.1)
390 }
391 }
392 }
393}
394
395impl core::ops::BitXor for f32x8 {
396 type Output = Self;
397
398 #[inline(always)]
399 fn bitxor(self, rhs: Self) -> Self::Output {
400 cfg_if::cfg_if! {
401 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
402 Self(unsafe { _mm256_xor_ps(self.0, rhs.0) })
403 } else {
404 Self(self.0 ^ rhs.0, self.1 ^ rhs.1)
405 }
406 }
407 }
408}
409
410impl core::ops::Neg for f32x8 {
411 type Output = Self;
412
413 fn neg(self) -> Self {
414 Self::default() - self
415 }
416}
417
418impl core::ops::Not for f32x8 {
419 type Output = Self;
420
421 fn not(self) -> Self {
422 cfg_if::cfg_if! {
423 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
424 let all_bits = unsafe { _mm256_set1_ps(f32::from_bits(u32::MAX)) };
425 Self(unsafe { _mm256_xor_ps(self.0, all_bits) })
426 } else {
427 Self(!self.0, !self.1)
428 }
429 }
430 }
431}
432
433impl core::cmp::PartialEq for f32x8 {
434 fn eq(&self, rhs: &Self) -> bool {
435 cfg_if::cfg_if! {
436 if #[cfg(all(feature = "simd", target_feature = "avx"))] {
437 let mask = unsafe { _mm256_cmp_ps(self.0, rhs.0, _CMP_EQ_OQ) };
438 unsafe { _mm256_movemask_ps(mask) == 0b1111_1111 }
439 } else {
440 self.0 == rhs.0 && self.1 == rhs.1
441 }
442 }
443 }
444}