1#![cfg(feature = "sse_luts")]
30use crate::conversions::interpolator::BarycentricWeight;
31use crate::math::FusedMultiplyAdd;
32#[cfg(target_arch = "x86")]
33use std::arch::x86::*;
34#[cfg(target_arch = "x86_64")]
35use std::arch::x86_64::*;
36use std::ops::{Add, Mul, Sub};
37
/// Four `f32` values stored with a forced 16-byte alignment.
///
/// Slices of this type are read with the aligned load `_mm_load_ps` in
/// `TetrahedralSseFetchVector::fetch`, which is why the alignment is pinned
/// with `repr(align(16))`.
#[repr(align(16), C)]
pub(crate) struct SseAlignedF32(pub(crate) [f32; 4]);
40
/// Data-less marker type selecting the tetrahedral interpolation routine.
///
/// Only compiled when the `options` feature is enabled. `GRID_SIZE` is handed
/// on to `TetrahedralSseFetchVector`, where it is used as the per-axis stride
/// when computing the linear table index.
#[cfg(feature = "options")]
pub(crate) struct TetrahedralSse<const GRID_SIZE: usize> {}
43
/// Data-less marker type selecting the pyramidal interpolation routine.
///
/// Only compiled when the `options` feature is enabled. `GRID_SIZE` is handed
/// on to `TetrahedralSseFetchVector`, where it is used as the per-axis stride
/// when computing the linear table index.
#[cfg(feature = "options")]
pub(crate) struct PyramidalSse<const GRID_SIZE: usize> {}
46
/// Data-less marker type selecting the prismatic interpolation routine.
///
/// Only compiled when the `options` feature is enabled. `GRID_SIZE` is handed
/// on to `TetrahedralSseFetchVector`, where it is used as the per-axis stride
/// when computing the linear table index.
#[cfg(feature = "options")]
pub(crate) struct PrismaticSse<const GRID_SIZE: usize> {}
49
/// Data-less marker type selecting the trilinear interpolation routine.
///
/// Unlike the other interpolator markers in this file it is not gated on the
/// `options` feature. `GRID_SIZE` is handed on to `TetrahedralSseFetchVector`,
/// where it is used as the per-axis stride when computing the linear table index.
pub(crate) struct TrilinearSse<const GRID_SIZE: usize> {}
51
/// Source of table samples addressed by three integer grid coordinates.
///
/// `T` is the value type produced for a single grid node.
trait Fetcher<T> {
    /// Returns the sample stored at grid position `(x, y, z)`.
    fn fetch(&self, x: i32, y: i32, z: i32) -> T;
}
55
/// `Copy` wrapper around one `__m128` register.
///
/// `repr(transparent)` keeps the layout identical to the wrapped `__m128`.
/// The operator impls below forward to the packed single-precision SSE
/// intrinsics, so every operation acts on all four `f32` lanes at once.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub(crate) struct SseVector {
    /// The wrapped register.
    pub(crate) v: __m128,
}

impl From<f32> for SseVector {
    /// Broadcasts the scalar `v` into all four lanes (`_mm_set1_ps`).
    #[inline(always)]
    fn from(v: f32) -> Self {
        let lanes = unsafe { _mm_set1_ps(v) };
        Self { v: lanes }
    }
}

impl Sub<SseVector> for SseVector {
    type Output = Self;

    /// Lane-wise `self - rhs` (`_mm_sub_ps`).
    #[inline(always)]
    fn sub(self, rhs: SseVector) -> Self::Output {
        let lanes = unsafe { _mm_sub_ps(self.v, rhs.v) };
        Self { v: lanes }
    }
}

impl Add<SseVector> for SseVector {
    type Output = Self;

    /// Lane-wise `self + rhs` (`_mm_add_ps`).
    #[inline(always)]
    fn add(self, rhs: SseVector) -> Self::Output {
        let lanes = unsafe { _mm_add_ps(self.v, rhs.v) };
        Self { v: lanes }
    }
}

impl Mul<SseVector> for SseVector {
    type Output = Self;

    /// Lane-wise `self * rhs` (`_mm_mul_ps`).
    #[inline(always)]
    fn mul(self, rhs: SseVector) -> Self::Output {
        let lanes = unsafe { _mm_mul_ps(self.v, rhs.v) };
        Self { v: lanes }
    }
}
100
101impl FusedMultiplyAdd<SseVector> for SseVector {
102 #[inline(always)]
103 fn mla(&self, b: SseVector, c: SseVector) -> SseVector {
104 SseVector {
105 v: unsafe { _mm_add_ps(self.v, _mm_mul_ps(b.v, c.v)) },
106 }
107 }
108}
109
/// `Fetcher` that reads samples out of a flat slice of `SseAlignedF32`
/// entries, addressing it as a 3D table with `GRID_SIZE` nodes per axis.
struct TetrahedralSseFetchVector<'a, const GRID_SIZE: usize> {
    /// Flat sample storage; node `(x, y, z)` lives at index
    /// `x * GRID_SIZE * GRID_SIZE + y * GRID_SIZE + z`.
    cube: &'a [SseAlignedF32],
}
113
114impl<const GRID_SIZE: usize> Fetcher<SseVector> for TetrahedralSseFetchVector<'_, GRID_SIZE> {
117 #[inline(always)]
118 fn fetch(&self, x: i32, y: i32, z: i32) -> SseVector {
119 let offset = (x as u32 * (GRID_SIZE as u32 * GRID_SIZE as u32)
120 + y as u32 * GRID_SIZE as u32
121 + z as u32) as usize;
122 let jx = unsafe { self.cube.get_unchecked(offset..) };
123 SseVector {
124 v: unsafe { _mm_load_ps(jx.as_ptr() as *const _) },
125 }
126 }
127}
128
/// Three-input table interpolation that yields an `SseVector`.
pub(crate) trait SseMdInterpolation {
    /// Interpolates `table` at the position described by `lut[in_r]`,
    /// `lut[in_g]` and `lut[in_b]`.
    ///
    /// * `table` - flat sample storage that the interpolation reads from.
    /// * `in_r`, `in_g`, `in_b` - indices into `lut`, one per input channel.
    /// * `lut` - per-index entries providing grid coordinates and a weight.
    ///
    /// NOTE(review): the implementations generated by `define_inter_sse!`
    /// forward to routines that index `lut` with `get_unchecked` and read
    /// `table` without bounds checks, so out-of-range arguments are not
    /// detected — callers presumably guarantee validity; confirm at call sites.
    fn inter3_sse(
        &self,
        table: &[SseAlignedF32],
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
    ) -> SseVector;
}
139
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> TetrahedralSse<GRID_SIZE> {
    /// Tetrahedral interpolation of a single sample.
    ///
    /// `lut[in_r]`, `lut[in_g]` and `lut[in_b]` each supply two grid
    /// coordinates (`x`, `x_n`) and a weight (`w`) for one axis. The relative
    /// order of the three weights selects one of six corner paths from
    /// `(x, y, z)` to `(x_n, y_n, z_n)`; the result is
    /// `c000 + d1 * rx + d2 * ry + d3 * rz`, where `d1..d3` are the corner
    /// differences along the chosen path.
    ///
    /// Corner naming: `cABC` is the sample fetched with the `x_n` coordinate on
    /// every axis whose digit is `1` and the `x` coordinate where it is `0`.
    ///
    /// # Safety
    ///
    /// * The CPU must support SSE4.1 (required by `#[target_feature]`).
    /// * `in_r`, `in_g` and `in_b` must be valid indices into `lut`; they are
    ///   read with `get_unchecked`.
    /// * NOTE(review): the coordinates stored in `lut` are passed to `r.fetch`
    ///   unchanged, so they presumably must be valid for the fetcher's table —
    ///   confirm against the code that builds `lut`.
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        let (axis_r, axis_g, axis_b) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        let (x, y, z): (i32, i32, i32) = (axis_r.x, axis_g.x, axis_b.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (axis_r.x_n, axis_g.x_n, axis_b.x_n);
        let (rx, ry, rz) = (axis_r.w, axis_g.w, axis_b.w);

        let c000 = r.fetch(x, y, z);

        // Each arm yields the three path differences paired with rx, ry, rz.
        let (d1, d2, d3) = if rx >= ry {
            if ry >= rz {
                // rx >= ry >= rz
                let c100 = r.fetch(x_n, y, z);
                let c110 = r.fetch(x_n, y_n, z);
                let c111 = r.fetch(x_n, y_n, z_n);
                (c100 - c000, c110 - c100, c111 - c110)
            } else if rx >= rz {
                // rx >= rz > ry
                let c100 = r.fetch(x_n, y, z);
                let c101 = r.fetch(x_n, y, z_n);
                let c111 = r.fetch(x_n, y_n, z_n);
                (c100 - c000, c111 - c101, c101 - c100)
            } else {
                // rz > rx >= ry
                let c001 = r.fetch(x, y, z_n);
                let c101 = r.fetch(x_n, y, z_n);
                let c111 = r.fetch(x_n, y_n, z_n);
                (c101 - c001, c111 - c101, c001 - c000)
            }
        } else if rx >= rz {
            // ry > rx >= rz
            let c010 = r.fetch(x, y_n, z);
            let c110 = r.fetch(x_n, y_n, z);
            let c111 = r.fetch(x_n, y_n, z_n);
            (c110 - c010, c010 - c000, c111 - c110)
        } else if ry >= rz {
            // ry >= rz > rx
            let c010 = r.fetch(x, y_n, z);
            let c011 = r.fetch(x, y_n, z_n);
            let c111 = r.fetch(x_n, y_n, z_n);
            (c111 - c011, c010 - c000, c011 - c010)
        } else {
            // rz > ry > rx
            let c001 = r.fetch(x, y, z_n);
            let c011 = r.fetch(x, y_n, z_n);
            let c111 = r.fetch(x_n, y_n, z_n);
            (c111 - c011, c011 - c001, c001 - c000)
        };

        c000.mla(d1, SseVector::from(rx))
            .mla(d2, SseVector::from(ry))
            .mla(d3, SseVector::from(rz))
    }
}
210
/// Implements `SseMdInterpolation` for `$interpolator<GRID_SIZE>`.
///
/// The generated `inter3_sse` wraps `table` in a `TetrahedralSseFetchVector`
/// with the same `GRID_SIZE` and forwards all arguments to the type's
/// inherent `interpolate` method.
macro_rules! define_inter_sse {
    ($interpolator: ident) => {
        impl<const GRID_SIZE: usize> SseMdInterpolation for $interpolator<GRID_SIZE> {
            fn inter3_sse(
                &self,
                table: &[SseAlignedF32],
                in_r: usize,
                in_g: usize,
                in_b: usize,
                lut: &[BarycentricWeight<f32>],
            ) -> SseVector {
                let fetcher = TetrahedralSseFetchVector::<GRID_SIZE> { cube: table };
                // `interpolate` is an `unsafe fn` compiled with `sse4.1` enabled.
                // NOTE(review): nothing here checks CPU support or index ranges;
                // presumably the caller has established both — confirm.
                unsafe { self.interpolate(in_r, in_g, in_b, lut, fetcher) }
            }
        }
    };
}
235
// Generate the `SseMdInterpolation` impl for every interpolator marker type.
// The first three types only exist when the `options` feature is enabled, so
// their impls are gated the same way; `TrilinearSse` is not gated on `options`.
#[cfg(feature = "options")]
define_inter_sse!(TetrahedralSse);
#[cfg(feature = "options")]
define_inter_sse!(PyramidalSse);
#[cfg(feature = "options")]
define_inter_sse!(PrismaticSse);
define_inter_sse!(TrilinearSse);
243
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PyramidalSse<GRID_SIZE> {
    /// Pyramidal interpolation of a single sample.
    ///
    /// `lut[in_r]`, `lut[in_g]` and `lut[in_b]` each supply two grid
    /// coordinates (`x`, `x_n`) and a weight (`w`) for one axis. Depending on
    /// which weight is the smallest, one of three corner sets is fetched and
    /// the result is
    /// `c000 + k1 * db + k2 * dr + k3 * dg + k4 * cross`,
    /// where `cross` is the product of two of the weights.
    ///
    /// Corner naming: `cABC` is the sample fetched with the `x_n` coordinate on
    /// every axis whose digit is `1` and the `x` coordinate where it is `0`.
    ///
    /// # Safety
    ///
    /// * The CPU must support SSE4.1 (required by `#[target_feature]`).
    /// * `in_r`, `in_g` and `in_b` must be valid indices into `lut`; they are
    ///   read with `get_unchecked`.
    /// * NOTE(review): the coordinates stored in `lut` are passed to `r.fetch`
    ///   unchanged, so they presumably must be valid for the fetcher's table —
    ///   confirm against the code that builds `lut`.
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        let (axis_r, axis_g, axis_b) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        let (x, y, z): (i32, i32, i32) = (axis_r.x, axis_g.x, axis_b.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (axis_r.x_n, axis_g.x_n, axis_b.x_n);
        let (dr, dg, db) = (axis_r.w, axis_g.w, axis_b.w);

        let c000 = r.fetch(x, y, z);

        // Each arm yields the coefficients for db, dr, dg and the cross term,
        // followed by the scalar cross weight itself.
        let (k1, k2, k3, k4, cross) = if dr > db && dg > db {
            let c111 = r.fetch(x_n, y_n, z_n);
            let c110 = r.fetch(x_n, y_n, z);
            let c100 = r.fetch(x_n, y, z);
            let c010 = r.fetch(x, y_n, z);
            (
                c111 - c110,
                c100 - c000,
                c010 - c000,
                c000 - c010 - c100 + c110,
                dr * dg,
            )
        } else if db > dr && dg > dr {
            let c001 = r.fetch(x, y, z_n);
            let c111 = r.fetch(x_n, y_n, z_n);
            let c011 = r.fetch(x, y_n, z_n);
            let c010 = r.fetch(x, y_n, z);
            (
                c001 - c000,
                c111 - c011,
                c010 - c000,
                c000 - c010 - c001 + c011,
                dg * db,
            )
        } else {
            let c001 = r.fetch(x, y, z_n);
            let c100 = r.fetch(x_n, y, z);
            let c101 = r.fetch(x_n, y, z_n);
            let c111 = r.fetch(x_n, y_n, z_n);
            (
                c001 - c000,
                c100 - c000,
                c111 - c101,
                c000 - c100 - c001 + c101,
                db * dr,
            )
        };

        c000.mla(k1, SseVector::from(db))
            .mla(k2, SseVector::from(dr))
            .mla(k3, SseVector::from(dg))
            .mla(k4, SseVector::from(cross))
    }
}
321
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PrismaticSse<GRID_SIZE> {
    /// Prismatic interpolation of a single sample.
    ///
    /// `lut[in_r]`, `lut[in_g]` and `lut[in_b]` each supply two grid
    /// coordinates (`x`, `x_n`) and a weight (`w`) for one axis. The
    /// comparison `db > dr` selects one of two corner sets; the result is
    /// `c000 + k1 * db + k2 * dr + k3 * dg + k4 * (dg * db) + k5 * (dr * dg)`.
    ///
    /// Corner naming: `cABC` is the sample fetched with the `x_n` coordinate on
    /// every axis whose digit is `1` and the `x` coordinate where it is `0`.
    ///
    /// # Safety
    ///
    /// * The CPU must support SSE4.1 (required by `#[target_feature]`).
    /// * `in_r`, `in_g` and `in_b` must be valid indices into `lut`; they are
    ///   read with `get_unchecked`.
    /// * NOTE(review): the coordinates stored in `lut` are passed to `r.fetch`
    ///   unchanged, so they presumably must be valid for the fetcher's table —
    ///   confirm against the code that builds `lut`.
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        let (axis_r, axis_g, axis_b) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        let (x, y, z): (i32, i32, i32) = (axis_r.x, axis_g.x, axis_b.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (axis_r.x_n, axis_g.x_n, axis_b.x_n);
        let (dr, dg, db) = (axis_r.w, axis_g.w, axis_b.w);

        let c000 = r.fetch(x, y, z);

        // Each arm yields the coefficients for db, dr, dg, dg*db and dr*dg.
        let (k1, k2, k3, k4, k5) = if db > dr {
            let c001 = r.fetch(x, y, z_n);
            let c101 = r.fetch(x_n, y, z_n);
            let c010 = r.fetch(x, y_n, z);
            let c011 = r.fetch(x, y_n, z_n);
            let c111 = r.fetch(x_n, y_n, z_n);
            (
                c001 - c000,
                c101 - c001,
                c010 - c000,
                c000 - c010 - c001 + c011,
                c001 - c011 - c101 + c111,
            )
        } else {
            let c100 = r.fetch(x_n, y, z);
            let c101 = r.fetch(x_n, y, z_n);
            let c010 = r.fetch(x, y_n, z);
            let c110 = r.fetch(x_n, y_n, z);
            let c111 = r.fetch(x_n, y_n, z_n);
            (
                c101 - c100,
                c100 - c000,
                c010 - c000,
                c100 - c110 - c101 + c111,
                c000 - c010 - c100 + c110,
            )
        };

        c000.mla(k1, SseVector::from(db))
            .mla(k2, SseVector::from(dr))
            .mla(k3, SseVector::from(dg))
            .mla(k4, SseVector::from(dg * db))
            .mla(k5, SseVector::from(dr * dg))
    }
}
390
391impl<const GRID_SIZE: usize> TrilinearSse<GRID_SIZE> {
392 #[target_feature(enable = "sse4.1")]
393 unsafe fn interpolate(
394 &self,
395 in_r: usize,
396 in_g: usize,
397 in_b: usize,
398 lut: &[BarycentricWeight<f32>],
399 r: impl Fetcher<SseVector>,
400 ) -> SseVector {
401 let lut_r = unsafe { *lut.get_unchecked(in_r) };
402 let lut_g = unsafe { *lut.get_unchecked(in_g) };
403 let lut_b = unsafe { *lut.get_unchecked(in_b) };
404
405 let x: i32 = lut_r.x;
406 let y: i32 = lut_g.x;
407 let z: i32 = lut_b.x;
408
409 let x_n: i32 = lut_r.x_n;
410 let y_n: i32 = lut_g.x_n;
411 let z_n: i32 = lut_b.x_n;
412
413 let dr = lut_r.w;
414 let dg = lut_g.w;
415 let db = lut_b.w;
416
417 let w0 = SseVector::from(dr);
418 let w1 = SseVector::from(dg);
419 let w2 = SseVector::from(db);
420
421 let c000 = r.fetch(x, y, z);
422 let c100 = r.fetch(x_n, y, z);
423 let c010 = r.fetch(x, y_n, z);
424 let c110 = r.fetch(x_n, y_n, z);
425 let c001 = r.fetch(x, y, z_n);
426 let c101 = r.fetch(x_n, y, z_n);
427 let c011 = r.fetch(x, y_n, z_n);
428 let c111 = r.fetch(x_n, y_n, z_n);
429
430 let dx = SseVector::from(1.0 - dr);
431
432 let c00 = (c000 * dx).mla(c100, w0);
433 let c10 = (c010 * dx).mla(c110, w0);
434 let c01 = (c001 * dx).mla(c101, w0);
435 let c11 = (c011 * dx).mla(c111, w0);
436
437 let dy = SseVector::from(1.0 - dg);
438
439 let c0 = (c00 * dy).mla(c10, w1);
440 let c1 = (c01 * dy).mla(c11, w1);
441
442 let dz = SseVector::from(1.0 - db);
443
444 (c0 * dz).mla(c1, w2)
445 }
446}