Skip to main content

moxcms/conversions/avx/
interpolator.rs

/*
 * // Copyright (c) Radzivon Bartoshyk 3/2025. All rights reserved.
 * //
 * // Redistribution and use in source and binary forms, with or without modification,
 * // are permitted provided that the following conditions are met:
 * //
 * // 1.  Redistributions of source code must retain the above copyright notice, this
 * // list of conditions and the following disclaimer.
 * //
 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
 * // this list of conditions and the following disclaimer in the documentation
 * // and/or other materials provided with the distribution.
 * //
 * // 3.  Neither the name of the copyright holder nor the names of its
 * // contributors may be used to endorse or promote products derived from
 * // this software without specific prior written permission.
 * //
 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
29#![cfg(feature = "avx_luts")]
30use crate::conversions::interpolator::BarycentricWeight;
31use crate::math::{FusedMultiplyAdd, FusedMultiplyNegAdd};
32use std::arch::x86_64::*;
33use std::ops::{Add, Mul, Sub};
34
/// A 16-byte-aligned group of four `f32` values (one LUT entry), so the
/// fetchers below can use the aligned `_mm_load_ps` load.
#[repr(align(16), C)]
pub(crate) struct SseAlignedF32(pub(crate) [f32; 4]);
37
/// Zero-sized marker selecting AVX2/FMA tetrahedral interpolation (single table).
#[cfg(feature = "options")]
pub(crate) struct TetrahedralAvxFma<const GRID_SIZE: usize> {}

/// Zero-sized marker selecting AVX2/FMA pyramidal interpolation (single table).
#[cfg(feature = "options")]
pub(crate) struct PyramidalAvxFma<const GRID_SIZE: usize> {}

/// Zero-sized marker selecting AVX2/FMA prismatic interpolation (single table).
#[cfg(feature = "options")]
pub(crate) struct PrismaticAvxFma<const GRID_SIZE: usize> {}

/// Zero-sized marker selecting AVX2/FMA trilinear interpolation (single table).
pub(crate) struct TrilinearAvxFma<const GRID_SIZE: usize> {}

/// Zero-sized marker: prismatic interpolation over two tables at once.
#[cfg(feature = "options")]
pub(crate) struct PrismaticAvxFmaDouble<const GRID_SIZE: usize> {}

/// Zero-sized marker: trilinear interpolation over two tables at once.
pub(crate) struct TrilinearAvxFmaDouble<const GRID_SIZE: usize> {}

/// Zero-sized marker: pyramidal interpolation over two tables at once.
#[cfg(feature = "options")]
pub(crate) struct PyramidAvxFmaDouble<const GRID_SIZE: usize> {}

/// Zero-sized marker: tetrahedral interpolation over two tables at once.
#[cfg(feature = "options")]
pub(crate) struct TetrahedralAvxFmaDouble<const GRID_SIZE: usize> {}
59
/// 3D LUT interpolation of one (r, g, b) sample against two tables at once,
/// returning one 4-lane result per table.
pub(crate) trait AvxMdInterpolationDouble {
    fn inter3_sse(
        &self,
        table0: &[SseAlignedF32],
        table1: &[SseAlignedF32],
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
    ) -> (AvxVectorSse, AvxVectorSse);
}
71
/// 3D LUT interpolation of one (r, g, b) sample against a single table.
/// `lut` maps each channel index to its grid cell and fractional weight.
pub(crate) trait AvxMdInterpolation {
    fn inter3_sse(
        &self,
        table: &[SseAlignedF32],
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
    ) -> AvxVectorSse;
}
82
/// Loads one lattice point of the LUT cube at integer grid coordinates.
trait Fetcher<T> {
    fn fetch(&self, x: i32, y: i32, z: i32) -> T;
}
86
/// Four `f32` lanes in a 128-bit SSE register (one LUT sample).
#[derive(Copy, Clone)]
#[repr(transparent)]
pub(crate) struct AvxVectorSse {
    pub(crate) v: __m128,
}

/// Eight `f32` lanes in a 256-bit AVX register; the "double" kernels keep
/// one LUT sample per 128-bit half.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub(crate) struct AvxVector {
    pub(crate) v: __m256,
}
98
99impl AvxVector {
100    #[inline(always)]
101    pub(crate) fn from_sse(lo: AvxVectorSse, hi: AvxVectorSse) -> AvxVector {
102        unsafe {
103            AvxVector {
104                v: _mm256_insertf128_ps::<1>(_mm256_castps128_ps256(lo.v), hi.v),
105            }
106        }
107    }
108
109    #[inline(always)]
110    pub(crate) fn split(self) -> (AvxVectorSse, AvxVectorSse) {
111        unsafe {
112            (
113                AvxVectorSse {
114                    v: _mm256_castps256_ps128(self.v),
115                },
116                AvxVectorSse {
117                    v: _mm256_extractf128_ps::<1>(self.v),
118                },
119            )
120        }
121    }
122}
123
124impl From<f32> for AvxVectorSse {
125    #[inline(always)]
126    fn from(v: f32) -> Self {
127        AvxVectorSse {
128            v: unsafe { _mm_set1_ps(v) },
129        }
130    }
131}
132
133impl From<f32> for AvxVector {
134    #[inline(always)]
135    fn from(v: f32) -> Self {
136        AvxVector {
137            v: unsafe { _mm256_set1_ps(v) },
138        }
139    }
140}
141
142impl Sub<AvxVectorSse> for AvxVectorSse {
143    type Output = Self;
144    #[inline(always)]
145    fn sub(self, rhs: AvxVectorSse) -> Self::Output {
146        AvxVectorSse {
147            v: unsafe { _mm_sub_ps(self.v, rhs.v) },
148        }
149    }
150}
151
152impl Sub<AvxVector> for AvxVector {
153    type Output = Self;
154    #[inline(always)]
155    fn sub(self, rhs: AvxVector) -> Self::Output {
156        AvxVector {
157            v: unsafe { _mm256_sub_ps(self.v, rhs.v) },
158        }
159    }
160}
161
162impl Add<AvxVectorSse> for AvxVectorSse {
163    type Output = Self;
164    #[inline(always)]
165    fn add(self, rhs: AvxVectorSse) -> Self::Output {
166        AvxVectorSse {
167            v: unsafe { _mm_add_ps(self.v, rhs.v) },
168        }
169    }
170}
171
172impl Mul<AvxVectorSse> for AvxVectorSse {
173    type Output = Self;
174    #[inline(always)]
175    fn mul(self, rhs: AvxVectorSse) -> Self::Output {
176        AvxVectorSse {
177            v: unsafe { _mm_mul_ps(self.v, rhs.v) },
178        }
179    }
180}
181
182impl AvxVector {
183    #[inline(always)]
184    pub(crate) fn neg_mla(self, b: AvxVector, c: AvxVector) -> Self {
185        Self {
186            v: unsafe { _mm256_fnmadd_ps(b.v, c.v, self.v) },
187        }
188    }
189}
190
191impl FusedMultiplyNegAdd<AvxVectorSse> for AvxVectorSse {
192    #[inline(always)]
193    fn neg_mla(&self, b: AvxVectorSse, c: AvxVectorSse) -> Self {
194        Self {
195            v: unsafe { _mm_fnmadd_ps(b.v, c.v, self.v) },
196        }
197    }
198}
199
200impl Add<AvxVector> for AvxVector {
201    type Output = Self;
202    #[inline(always)]
203    fn add(self, rhs: AvxVector) -> Self::Output {
204        AvxVector {
205            v: unsafe { _mm256_add_ps(self.v, rhs.v) },
206        }
207    }
208}
209
210impl Mul<AvxVector> for AvxVector {
211    type Output = Self;
212    #[inline(always)]
213    fn mul(self, rhs: AvxVector) -> Self::Output {
214        AvxVector {
215            v: unsafe { _mm256_mul_ps(self.v, rhs.v) },
216        }
217    }
218}
219
220impl FusedMultiplyAdd<AvxVectorSse> for AvxVectorSse {
221    #[inline(always)]
222    fn mla(&self, b: AvxVectorSse, c: AvxVectorSse) -> AvxVectorSse {
223        AvxVectorSse {
224            v: unsafe { _mm_fmadd_ps(b.v, c.v, self.v) },
225        }
226    }
227}
228
229impl FusedMultiplyAdd<AvxVector> for AvxVector {
230    #[inline(always)]
231    fn mla(&self, b: AvxVector, c: AvxVector) -> AvxVector {
232        AvxVector {
233            v: unsafe { _mm256_fmadd_ps(b.v, c.v, self.v) },
234        }
235    }
236}
237
/// Fetcher over a single LUT cube; yields one aligned 4-lane sample.
struct TetrahedralAvxSseFetchVector<'a, const GRID_SIZE: usize> {
    cube: &'a [SseAlignedF32],
}

/// Fetcher over two LUT cubes; yields both samples for the same lattice
/// point packed into the low/high halves of a 256-bit vector.
struct TetrahedralAvxFetchVector<'a, const GRID_SIZE: usize> {
    cube0: &'a [SseAlignedF32],
    cube1: &'a [SseAlignedF32],
}
246
/// LUT size here is always fixed size (GRID_SIZE^3) and its use
/// is hardened at [crate::conversions::avx::assert_barycentric_lut_size_precondition].
impl<const GRID_SIZE: usize> Fetcher<AvxVector> for TetrahedralAvxFetchVector<'_, GRID_SIZE> {
    #[inline(always)]
    fn fetch(&self, x: i32, y: i32, z: i32) -> AvxVector {
        // Row-major offset into the GRID_SIZE^3 lattice (x major, z minor).
        let offset = (x as u32 * (GRID_SIZE as u32 * GRID_SIZE as u32)
            + y as u32 * GRID_SIZE as u32
            + z as u32) as usize;
        // SAFETY: offset is in bounds per the size precondition cited above.
        let jx0 = unsafe { self.cube0.get_unchecked(offset..) };
        let jx1 = unsafe { self.cube1.get_unchecked(offset..) };
        AvxVector {
            // SAFETY: aligned loads — `SseAlignedF32` is `repr(align(16))`.
            // cube0's sample lands in lanes 0..4, cube1's in lanes 4..8.
            v: unsafe {
                _mm256_insertf128_ps::<1>(
                    _mm256_castps128_ps256(_mm_load_ps(jx0.as_ptr() as *const f32)),
                    _mm_load_ps(jx1.as_ptr() as *const f32),
                )
            },
        }
    }
}
267
/// Single-cube fetch; bounds are hardened by the same GRID_SIZE^3 LUT
/// precondition as the paired fetcher above.
impl<const GRID_SIZE: usize> Fetcher<AvxVectorSse> for TetrahedralAvxSseFetchVector<'_, GRID_SIZE> {
    #[inline(always)]
    fn fetch(&self, x: i32, y: i32, z: i32) -> AvxVectorSse {
        // Row-major offset into the GRID_SIZE^3 lattice.
        let offset = (x as u32 * (GRID_SIZE as u32 * GRID_SIZE as u32)
            + y as u32 * GRID_SIZE as u32
            + z as u32) as usize;
        // SAFETY: offset is in bounds per the LUT size precondition.
        let jx = unsafe { self.cube.get_unchecked(offset..) };
        AvxVectorSse {
            // SAFETY: aligned load — `SseAlignedF32` is `repr(align(16))`.
            v: unsafe { _mm_load_ps(jx.as_ptr() as *const f32) },
        }
    }
}
280
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> TetrahedralAvxFma<GRID_SIZE> {
    /// Tetrahedral interpolation of one sample from a single LUT.
    ///
    /// Each channel index is mapped through `lut` to a lower grid index
    /// (`x`), upper index (`x_n`) and fractional weight (`w`). The cube
    /// cell is split into six tetrahedra, selected by the ordering of the
    /// three weights; the result is c0 + c1*rx + c2*ry + c3*rz.
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<AvxVectorSse>,
    ) -> AvxVectorSse {
        // SAFETY assumption: in_r/in_g/in_b are valid indices into `lut`
        // (caller-enforced; TODO confirm against call sites).
        let lut_r = unsafe { lut.get_unchecked(in_r) };
        let lut_g = unsafe { lut.get_unchecked(in_g) };
        let lut_b = unsafe { lut.get_unchecked(in_b) };

        let x: i32 = lut_r.x;
        let y: i32 = lut_g.x;
        let z: i32 = lut_b.x;

        let x_n: i32 = lut_r.x_n;
        let y_n: i32 = lut_g.x_n;
        let z_n: i32 = lut_b.x_n;

        let rx = lut_r.w;
        let ry = lut_g.w;
        let rz = lut_b.w;

        // Base corner of the cell, shared by every tetrahedron.
        let c0 = r.fetch(x, y, z);

        let c2;
        let c1;
        let c3;
        // Pick the tetrahedron from the ordering of the weights.
        if rx >= ry {
            if ry >= rz {
                //rx >= ry && ry >= rz
                c1 = r.fetch(x_n, y, z) - c0;
                c2 = r.fetch(x_n, y_n, z) - r.fetch(x_n, y, z);
                c3 = r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y_n, z);
            } else if rx >= rz {
                //rx >= rz && rz >= ry
                c1 = r.fetch(x_n, y, z) - c0;
                c2 = r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y, z_n);
                c3 = r.fetch(x_n, y, z_n) - r.fetch(x_n, y, z);
            } else {
                //rz > rx && rx >= ry
                c1 = r.fetch(x_n, y, z_n) - r.fetch(x, y, z_n);
                c2 = r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y, z_n);
                c3 = r.fetch(x, y, z_n) - c0;
            }
        } else if rx >= rz {
            //ry > rx && rx >= rz
            c1 = r.fetch(x_n, y_n, z) - r.fetch(x, y_n, z);
            c2 = r.fetch(x, y_n, z) - c0;
            c3 = r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y_n, z);
        } else if ry >= rz {
            //ry >= rz && rz > rx
            c1 = r.fetch(x_n, y_n, z_n) - r.fetch(x, y_n, z_n);
            c2 = r.fetch(x, y_n, z) - c0;
            c3 = r.fetch(x, y_n, z_n) - r.fetch(x, y_n, z);
        } else {
            //rz > ry && ry > rx
            c1 = r.fetch(x_n, y_n, z_n) - r.fetch(x, y_n, z_n);
            c2 = r.fetch(x, y_n, z_n) - r.fetch(x, y, z_n);
            c3 = r.fetch(x, y, z_n) - c0;
        }
        // Accumulate the three weighted deltas with fused multiply-adds.
        let s0 = c0.mla(c1, AvxVectorSse::from(rx));
        let s1 = s0.mla(c2, AvxVectorSse::from(ry));
        s1.mla(c3, AvxVectorSse::from(rz))
    }
}
351
// Implements `AvxMdInterpolation` for a single-table interpolator by
// forwarding to its inherent `interpolate` with a single-cube fetcher.
macro_rules! define_interp_avx {
    ($interpolator: ident) => {
        impl<const GRID_SIZE: usize> AvxMdInterpolation for $interpolator<GRID_SIZE> {
            fn inter3_sse(
                &self,
                table: &[SseAlignedF32],
                in_r: usize,
                in_g: usize,
                in_b: usize,
                lut: &[BarycentricWeight<f32>],
            ) -> AvxVectorSse {
                // SAFETY: `interpolate` is a target_feature fn; callers of this
                // trait are expected to have verified AVX2/FMA availability.
                unsafe {
                    self.interpolate(
                        in_r,
                        in_g,
                        in_b,
                        lut,
                        TetrahedralAvxSseFetchVector::<GRID_SIZE> { cube: table },
                    )
                }
            }
        }
    };
}
376
// Implements `AvxMdInterpolationDouble` for interpolators whose inherent
// `interpolate` takes two independent single-cube fetchers.
#[cfg(feature = "options")]
macro_rules! define_interp_avx_d {
    ($interpolator: ident) => {
        impl<const GRID_SIZE: usize> AvxMdInterpolationDouble for $interpolator<GRID_SIZE> {
            fn inter3_sse(
                &self,
                table0: &[SseAlignedF32],
                table1: &[SseAlignedF32],
                in_r: usize,
                in_g: usize,
                in_b: usize,
                lut: &[BarycentricWeight<f32>],
            ) -> (AvxVectorSse, AvxVectorSse) {
                // SAFETY: `interpolate` is a target_feature fn; callers of this
                // trait are expected to have verified AVX2/FMA availability.
                unsafe {
                    self.interpolate(
                        in_r,
                        in_g,
                        in_b,
                        lut,
                        TetrahedralAvxSseFetchVector::<GRID_SIZE> { cube: table0 },
                        TetrahedralAvxSseFetchVector::<GRID_SIZE> { cube: table1 },
                    )
                }
            }
        }
    };
}
404
// Wire the trait impls onto the single- and double-table interpolators.
#[cfg(feature = "options")]
define_interp_avx!(TetrahedralAvxFma);
#[cfg(feature = "options")]
define_interp_avx!(PyramidalAvxFma);
#[cfg(feature = "options")]
define_interp_avx!(PrismaticAvxFma);
define_interp_avx!(TrilinearAvxFma);
#[cfg(feature = "options")]
define_interp_avx_d!(PrismaticAvxFmaDouble);
#[cfg(feature = "options")]
define_interp_avx_d!(PyramidAvxFmaDouble);
416
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> AvxMdInterpolationDouble for TetrahedralAvxFmaDouble<GRID_SIZE> {
    /// Forwards to the AVX2/FMA tetrahedral kernel through a paired-table
    /// fetcher that loads both cubes into one 256-bit vector per point.
    fn inter3_sse(
        &self,
        table0: &[SseAlignedF32],
        table1: &[SseAlignedF32],
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
    ) -> (AvxVectorSse, AvxVectorSse) {
        let fetcher = TetrahedralAvxFetchVector::<GRID_SIZE> {
            cube0: table0,
            cube1: table1,
        };
        // SAFETY: callers of this trait are expected to have verified
        // AVX2/FMA availability before use.
        unsafe { self.interpolate(in_r, in_g, in_b, lut, fetcher) }
    }
}
442
443impl<const GRID_SIZE: usize> AvxMdInterpolationDouble for TrilinearAvxFmaDouble<GRID_SIZE> {
444    fn inter3_sse(
445        &self,
446        table0: &[SseAlignedF32],
447        table1: &[SseAlignedF32],
448        in_r: usize,
449        in_g: usize,
450        in_b: usize,
451        lut: &[BarycentricWeight<f32>],
452    ) -> (AvxVectorSse, AvxVectorSse) {
453        unsafe {
454            self.interpolate(
455                in_r,
456                in_g,
457                in_b,
458                lut,
459                TetrahedralAvxFetchVector::<GRID_SIZE> {
460                    cube0: table0,
461                    cube1: table1,
462                },
463            )
464        }
465    }
466}
467
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PyramidalAvxFma<GRID_SIZE> {
    /// Pyramidal interpolation of one sample from a single LUT.
    ///
    /// The cube cell is split into three pyramids; the branch picks the
    /// pyramid containing the point by comparing the channel weights, then
    /// evaluates c0 + c1*db + c2*dr + c3*dg + c4*(cross term) with FMAs.
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<AvxVectorSse>,
    ) -> AvxVectorSse {
        // SAFETY assumption: indices are valid for `lut` (caller-enforced).
        let lut_r = unsafe { lut.get_unchecked(in_r) };
        let lut_g = unsafe { lut.get_unchecked(in_g) };
        let lut_b = unsafe { lut.get_unchecked(in_b) };

        let x: i32 = lut_r.x;
        let y: i32 = lut_g.x;
        let z: i32 = lut_b.x;

        let x_n: i32 = lut_r.x_n;
        let y_n: i32 = lut_g.x_n;
        let z_n: i32 = lut_b.x_n;

        let dr = lut_r.w;
        let dg = lut_g.w;
        let db = lut_b.w;

        // Base corner of the cell.
        let c0 = r.fetch(x, y, z);

        let w0 = AvxVectorSse::from(db);
        let w1 = AvxVectorSse::from(dr);
        let w2 = AvxVectorSse::from(dg);

        if dr > db && dg > db {
            // Pyramid with apex along the blue axis.
            let w3 = AvxVectorSse::from(dr * dg);
            let x0 = r.fetch(x_n, y_n, z_n);
            let x1 = r.fetch(x_n, y_n, z);
            let x2 = r.fetch(x_n, y, z);
            let x3 = r.fetch(x, y_n, z);

            let c1 = x0 - x1;
            let c2 = x2 - c0;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x2 + x1;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            s2.mla(c4, w3)
        } else if db > dr && dg > dr {
            // Pyramid with apex along the red axis.
            let w3 = AvxVectorSse::from(dg * db);

            let x0 = r.fetch(x, y, z_n);
            let x1 = r.fetch(x_n, y_n, z_n);
            let x2 = r.fetch(x, y_n, z_n);
            let x3 = r.fetch(x, y_n, z);

            let c1 = x0 - c0;
            let c2 = x1 - x2;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x0 + x2;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            s2.mla(c4, w3)
        } else {
            // Remaining pyramid (apex along the green axis).
            let w3 = AvxVectorSse::from(db * dr);

            let x0 = r.fetch(x, y, z_n);
            let x1 = r.fetch(x_n, y, z);
            let x2 = r.fetch(x_n, y, z_n);
            let x3 = r.fetch(x_n, y_n, z_n);

            let c1 = x0 - c0;
            let c2 = x1 - c0;
            let c3 = x3 - x2;
            let c4 = c0 - x1 - x0 + x2;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            s2.mla(c4, w3)
        }
    }
}
554
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PrismaticAvxFma<GRID_SIZE> {
    /// Prismatic interpolation of one sample from a single LUT.
    ///
    /// The cube cell is split into two triangular prisms along the
    /// `db > dr` diagonal; five deltas are accumulated with FMAs.
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<AvxVectorSse>,
    ) -> AvxVectorSse {
        // SAFETY assumption: indices are valid for `lut` (caller-enforced).
        let lut_r = unsafe { lut.get_unchecked(in_r) };
        let lut_g = unsafe { lut.get_unchecked(in_g) };
        let lut_b = unsafe { lut.get_unchecked(in_b) };

        let x: i32 = lut_r.x;
        let y: i32 = lut_g.x;
        let z: i32 = lut_b.x;

        let x_n: i32 = lut_r.x_n;
        let y_n: i32 = lut_g.x_n;
        let z_n: i32 = lut_b.x_n;

        let dr = lut_r.w;
        let dg = lut_g.w;
        let db = lut_b.w;

        // Base corner of the cell.
        let c0 = r.fetch(x, y, z);

        let w0 = AvxVectorSse::from(db);
        let w1 = AvxVectorSse::from(dr);
        let w2 = AvxVectorSse::from(dg);
        let w3 = AvxVectorSse::from(dg * db);
        let w4 = AvxVectorSse::from(dr * dg);

        if db > dr {
            // Prism on the blue-major side of the diagonal.
            let x0 = r.fetch(x, y, z_n);
            let x1 = r.fetch(x_n, y, z_n);
            let x2 = r.fetch(x, y_n, z);
            let x3 = r.fetch(x, y_n, z_n);
            let x4 = r.fetch(x_n, y_n, z_n);

            let c1 = x0 - c0;
            let c2 = x1 - x0;
            let c3 = x2 - c0;
            let c4 = c0 - x2 - x0 + x3;
            let c5 = x0 - x3 - x1 + x4;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            let s3 = s2.mla(c4, w3);
            s3.mla(c5, w4)
        } else {
            // Prism on the red-major side of the diagonal.
            let x0 = r.fetch(x_n, y, z);
            let x1 = r.fetch(x_n, y, z_n);
            let x2 = r.fetch(x, y_n, z);
            let x3 = r.fetch(x_n, y_n, z);
            let x4 = r.fetch(x_n, y_n, z_n);

            let c1 = x1 - x0;
            let c2 = x0 - c0;
            let c3 = x2 - c0;
            let c4 = x0 - x3 - x1 + x4;
            let c5 = c0 - x2 - x0 + x3;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            let s3 = s2.mla(c4, w3);
            s3.mla(c5, w4)
        }
    }
}
629
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PrismaticAvxFmaDouble<GRID_SIZE> {
    /// Prismatic interpolation of one sample against two LUTs at once.
    ///
    /// Mirrors the single-table `PrismaticAvxFma::interpolate`, but every
    /// corner is fetched from both tables (`r0`, `r1`) and the two samples
    /// are packed into the halves of a 256-bit vector so one FMA chain
    /// evaluates both results; `split()` returns them per table.
    ///
    /// Fix: the second table's base corner was fetched from `r0` instead of
    /// `r1`, corrupting every result for table1 (compare the correct
    /// pattern in `PyramidAvxFmaDouble::interpolate`).
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r0: impl Fetcher<AvxVectorSse>,
        r1: impl Fetcher<AvxVectorSse>,
    ) -> (AvxVectorSse, AvxVectorSse) {
        // SAFETY assumption: indices are valid for `lut` (caller-enforced).
        let lut_r = unsafe { lut.get_unchecked(in_r) };
        let lut_g = unsafe { lut.get_unchecked(in_g) };
        let lut_b = unsafe { lut.get_unchecked(in_b) };

        let x: i32 = lut_r.x;
        let y: i32 = lut_g.x;
        let z: i32 = lut_b.x;

        let x_n: i32 = lut_r.x_n;
        let y_n: i32 = lut_g.x_n;
        let z_n: i32 = lut_b.x_n;

        let dr = lut_r.w;
        let dg = lut_g.w;
        let db = lut_b.w;

        // Base corner from each table; was erroneously `r0.fetch` for both.
        let c0_0 = r0.fetch(x, y, z);
        let c0_1 = r1.fetch(x, y, z);

        let w0 = AvxVector::from(db);
        let w1 = AvxVector::from(dr);
        let w2 = AvxVector::from(dg);
        let w3 = AvxVector::from(dg * db);
        let w4 = AvxVector::from(dr * dg);

        let c0 = AvxVector::from_sse(c0_0, c0_1);

        if db > dr {
            // Prism on the blue-major side of the diagonal.
            let x0_0 = r0.fetch(x, y, z_n);
            let x1_0 = r0.fetch(x_n, y, z_n);
            let x2_0 = r0.fetch(x, y_n, z);
            let x3_0 = r0.fetch(x, y_n, z_n);
            let x4_0 = r0.fetch(x_n, y_n, z_n);

            let x0_1 = r1.fetch(x, y, z_n);
            let x1_1 = r1.fetch(x_n, y, z_n);
            let x2_1 = r1.fetch(x, y_n, z);
            let x3_1 = r1.fetch(x, y_n, z_n);
            let x4_1 = r1.fetch(x_n, y_n, z_n);

            let x0 = AvxVector::from_sse(x0_0, x0_1);
            let x1 = AvxVector::from_sse(x1_0, x1_1);
            let x2 = AvxVector::from_sse(x2_0, x2_1);
            let x3 = AvxVector::from_sse(x3_0, x3_1);
            let x4 = AvxVector::from_sse(x4_0, x4_1);

            let c1 = x0 - c0;
            let c2 = x1 - x0;
            let c3 = x2 - c0;
            let c4 = c0 - x2 - x0 + x3;
            let c5 = x0 - x3 - x1 + x4;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            let s3 = s2.mla(c4, w3);
            s3.mla(c5, w4).split()
        } else {
            // Prism on the red-major side of the diagonal.
            let x0_0 = r0.fetch(x_n, y, z);
            let x1_0 = r0.fetch(x_n, y, z_n);
            let x2_0 = r0.fetch(x, y_n, z);
            let x3_0 = r0.fetch(x_n, y_n, z);
            let x4_0 = r0.fetch(x_n, y_n, z_n);

            let x0_1 = r1.fetch(x_n, y, z);
            let x1_1 = r1.fetch(x_n, y, z_n);
            let x2_1 = r1.fetch(x, y_n, z);
            let x3_1 = r1.fetch(x_n, y_n, z);
            let x4_1 = r1.fetch(x_n, y_n, z_n);

            let x0 = AvxVector::from_sse(x0_0, x0_1);
            let x1 = AvxVector::from_sse(x1_0, x1_1);
            let x2 = AvxVector::from_sse(x2_0, x2_1);
            let x3 = AvxVector::from_sse(x3_0, x3_1);
            let x4 = AvxVector::from_sse(x4_0, x4_1);

            let c1 = x1 - x0;
            let c2 = x0 - c0;
            let c3 = x2 - c0;
            let c4 = x0 - x3 - x1 + x4;
            let c5 = c0 - x2 - x0 + x3;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            let s3 = s2.mla(c4, w3);
            s3.mla(c5, w4).split()
        }
    }
}
732
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PyramidAvxFmaDouble<GRID_SIZE> {
    /// Pyramidal interpolation of one sample against two LUTs at once.
    ///
    /// Mirrors `PyramidalAvxFma::interpolate`, but every corner is fetched
    /// from both tables (`r0`, `r1`) and packed into the halves of a
    /// 256-bit vector so a single FMA chain serves both; `split()` returns
    /// the per-table results.
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r0: impl Fetcher<AvxVectorSse>,
        r1: impl Fetcher<AvxVectorSse>,
    ) -> (AvxVectorSse, AvxVectorSse) {
        // SAFETY assumption: indices are valid for `lut` (caller-enforced).
        let lut_r = unsafe { lut.get_unchecked(in_r) };
        let lut_g = unsafe { lut.get_unchecked(in_g) };
        let lut_b = unsafe { lut.get_unchecked(in_b) };

        let x: i32 = lut_r.x;
        let y: i32 = lut_g.x;
        let z: i32 = lut_b.x;

        let x_n: i32 = lut_r.x_n;
        let y_n: i32 = lut_g.x_n;
        let z_n: i32 = lut_b.x_n;

        let dr = lut_r.w;
        let dg = lut_g.w;
        let db = lut_b.w;

        // Base corner from each table.
        let c0_0 = r0.fetch(x, y, z);
        let c0_1 = r1.fetch(x, y, z);

        let w0 = AvxVector::from(db);
        let w1 = AvxVector::from(dr);
        let w2 = AvxVector::from(dg);

        let c0 = AvxVector::from_sse(c0_0, c0_1);

        if dr > db && dg > db {
            // Pyramid with apex along the blue axis.
            let w3 = AvxVector::from(dr * dg);

            let x0_0 = r0.fetch(x_n, y_n, z_n);
            let x1_0 = r0.fetch(x_n, y_n, z);
            let x2_0 = r0.fetch(x_n, y, z);
            let x3_0 = r0.fetch(x, y_n, z);

            let x0_1 = r1.fetch(x_n, y_n, z_n);
            let x1_1 = r1.fetch(x_n, y_n, z);
            let x2_1 = r1.fetch(x_n, y, z);
            let x3_1 = r1.fetch(x, y_n, z);

            let x0 = AvxVector::from_sse(x0_0, x0_1);
            let x1 = AvxVector::from_sse(x1_0, x1_1);
            let x2 = AvxVector::from_sse(x2_0, x2_1);
            let x3 = AvxVector::from_sse(x3_0, x3_1);

            let c1 = x0 - x1;
            let c2 = x2 - c0;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x2 + x1;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            s2.mla(c4, w3).split()
        } else if db > dr && dg > dr {
            // Pyramid with apex along the red axis.
            let w3 = AvxVector::from(dg * db);

            let x0_0 = r0.fetch(x, y, z_n);
            let x1_0 = r0.fetch(x_n, y_n, z_n);
            let x2_0 = r0.fetch(x, y_n, z_n);
            let x3_0 = r0.fetch(x, y_n, z);

            let x0_1 = r1.fetch(x, y, z_n);
            let x1_1 = r1.fetch(x_n, y_n, z_n);
            let x2_1 = r1.fetch(x, y_n, z_n);
            let x3_1 = r1.fetch(x, y_n, z);

            let x0 = AvxVector::from_sse(x0_0, x0_1);
            let x1 = AvxVector::from_sse(x1_0, x1_1);
            let x2 = AvxVector::from_sse(x2_0, x2_1);
            let x3 = AvxVector::from_sse(x3_0, x3_1);

            let c1 = x0 - c0;
            let c2 = x1 - x2;
            let c3 = x3 - c0;
            let c4 = c0 - x3 - x0 + x2;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            s2.mla(c4, w3).split()
        } else {
            // Remaining pyramid (apex along the green axis).
            let w3 = AvxVector::from(db * dr);

            let x0_0 = r0.fetch(x, y, z_n);
            let x1_0 = r0.fetch(x_n, y, z);
            let x2_0 = r0.fetch(x_n, y, z_n);
            let x3_0 = r0.fetch(x_n, y_n, z_n);

            let x0_1 = r1.fetch(x, y, z_n);
            let x1_1 = r1.fetch(x_n, y, z);
            let x2_1 = r1.fetch(x_n, y, z_n);
            let x3_1 = r1.fetch(x_n, y_n, z_n);

            let x0 = AvxVector::from_sse(x0_0, x0_1);
            let x1 = AvxVector::from_sse(x1_0, x1_1);
            let x2 = AvxVector::from_sse(x2_0, x2_1);
            let x3 = AvxVector::from_sse(x3_0, x3_1);

            let c1 = x0 - c0;
            let c2 = x1 - c0;
            let c3 = x3 - x2;
            let c4 = c0 - x1 - x0 + x2;

            let s0 = c0.mla(c1, w0);
            let s1 = s0.mla(c2, w1);
            let s2 = s1.mla(c3, w2);
            s2.mla(c4, w3).split()
        }
    }
}
854
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> TetrahedralAvxFmaDouble<GRID_SIZE> {
    /// Tetrahedral interpolation of one sample against two LUTs at once.
    ///
    /// `rv` yields both tables' samples packed into one 256-bit vector per
    /// lattice point, so the six-tetrahedra branch and the FMA chain run
    /// once for both; `split()` returns the per-table results.
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        rv: impl Fetcher<AvxVector>,
    ) -> (AvxVectorSse, AvxVectorSse) {
        // SAFETY assumption: indices are valid for `lut` (caller-enforced).
        let lut_r = unsafe { lut.get_unchecked(in_r) };
        let lut_g = unsafe { lut.get_unchecked(in_g) };
        let lut_b = unsafe { lut.get_unchecked(in_b) };

        let x: i32 = lut_r.x;
        let y: i32 = lut_g.x;
        let z: i32 = lut_b.x;

        let x_n: i32 = lut_r.x_n;
        let y_n: i32 = lut_g.x_n;
        let z_n: i32 = lut_b.x_n;

        let rx = lut_r.w;
        let ry = lut_g.w;
        let rz = lut_b.w;

        // Base corner of the cell, shared by every tetrahedron.
        let c0 = rv.fetch(x, y, z);

        let w0 = AvxVector::from(rx);
        let w1 = AvxVector::from(ry);
        let w2 = AvxVector::from(rz);

        let c2;
        let c1;
        let c3;
        // Pick the tetrahedron from the ordering of the weights.
        if rx >= ry {
            if ry >= rz {
                //rx >= ry && ry >= rz
                c1 = rv.fetch(x_n, y, z) - c0;
                c2 = rv.fetch(x_n, y_n, z) - rv.fetch(x_n, y, z);
                c3 = rv.fetch(x_n, y_n, z_n) - rv.fetch(x_n, y_n, z);
            } else if rx >= rz {
                //rx >= rz && rz >= ry
                c1 = rv.fetch(x_n, y, z) - c0;
                c2 = rv.fetch(x_n, y_n, z_n) - rv.fetch(x_n, y, z_n);
                c3 = rv.fetch(x_n, y, z_n) - rv.fetch(x_n, y, z);
            } else {
                //rz > rx && rx >= ry
                c1 = rv.fetch(x_n, y, z_n) - rv.fetch(x, y, z_n);
                c2 = rv.fetch(x_n, y_n, z_n) - rv.fetch(x_n, y, z_n);
                c3 = rv.fetch(x, y, z_n) - c0;
            }
        } else if rx >= rz {
            //ry > rx && rx >= rz
            c1 = rv.fetch(x_n, y_n, z) - rv.fetch(x, y_n, z);
            c2 = rv.fetch(x, y_n, z) - c0;
            c3 = rv.fetch(x_n, y_n, z_n) - rv.fetch(x_n, y_n, z);
        } else if ry >= rz {
            //ry >= rz && rz > rx
            c1 = rv.fetch(x_n, y_n, z_n) - rv.fetch(x, y_n, z_n);
            c2 = rv.fetch(x, y_n, z) - c0;
            c3 = rv.fetch(x, y_n, z_n) - rv.fetch(x, y_n, z);
        } else {
            //rz > ry && ry > rx
            c1 = rv.fetch(x_n, y_n, z_n) - rv.fetch(x, y_n, z_n);
            c2 = rv.fetch(x, y_n, z_n) - rv.fetch(x, y, z_n);
            c3 = rv.fetch(x, y, z_n) - c0;
        }
        // Accumulate the three weighted deltas, then split per table.
        let s0 = c0.mla(c1, w0);
        let s1 = s0.mla(c2, w1);
        s1.mla(c3, w2).split()
    }
}
929
930impl<const GRID_SIZE: usize> TrilinearAvxFmaDouble<GRID_SIZE> {
931    #[target_feature(enable = "avx2", enable = "fma")]
932    unsafe fn interpolate(
933        &self,
934        in_r: usize,
935        in_g: usize,
936        in_b: usize,
937        lut: &[BarycentricWeight<f32>],
938        rv: impl Fetcher<AvxVector>,
939    ) -> (AvxVectorSse, AvxVectorSse) {
940        let lut_r = unsafe { lut.get_unchecked(in_r) };
941        let lut_g = unsafe { lut.get_unchecked(in_g) };
942        let lut_b = unsafe { lut.get_unchecked(in_b) };
943
944        let x: i32 = lut_r.x;
945        let y: i32 = lut_g.x;
946        let z: i32 = lut_b.x;
947
948        let x_n: i32 = lut_r.x_n;
949        let y_n: i32 = lut_g.x_n;
950        let z_n: i32 = lut_b.x_n;
951
952        let rx = lut_r.w;
953        let ry = lut_g.w;
954        let rz = lut_b.w;
955
956        let w0 = AvxVector::from(rx);
957        let w1 = AvxVector::from(ry);
958        let w2 = AvxVector::from(rz);
959
960        let c000 = rv.fetch(x, y, z);
961        let c100 = rv.fetch(x_n, y, z);
962        let c010 = rv.fetch(x, y_n, z);
963        let c110 = rv.fetch(x_n, y_n, z);
964        let c001 = rv.fetch(x, y, z_n);
965        let c101 = rv.fetch(x_n, y, z_n);
966        let c011 = rv.fetch(x, y_n, z_n);
967        let c111 = rv.fetch(x_n, y_n, z_n);
968
969        let dx = AvxVector::from(rx);
970
971        let c00 = c000.neg_mla(c000, dx).mla(c100, w0);
972        let c10 = c010.neg_mla(c010, dx).mla(c110, w0);
973        let c01 = c001.neg_mla(c001, dx).mla(c101, w0);
974        let c11 = c011.neg_mla(c011, dx).mla(c111, w0);
975
976        let dy = AvxVector::from(ry);
977
978        let c0 = c00.neg_mla(c00, dy).mla(c10, w1);
979        let c1 = c01.neg_mla(c01, dy).mla(c11, w1);
980
981        let dz = AvxVector::from(rz);
982
983        c0.neg_mla(c0, dz).mla(c1, w2).split()
984    }
985}
986
987impl<const GRID_SIZE: usize> TrilinearAvxFma<GRID_SIZE> {
988    #[target_feature(enable = "avx2", enable = "fma")]
989    unsafe fn interpolate(
990        &self,
991        in_r: usize,
992        in_g: usize,
993        in_b: usize,
994        lut: &[BarycentricWeight<f32>],
995        r: impl Fetcher<AvxVectorSse>,
996    ) -> AvxVectorSse {
997        let lut_r = unsafe { lut.get_unchecked(in_r) };
998        let lut_g = unsafe { lut.get_unchecked(in_g) };
999        let lut_b = unsafe { lut.get_unchecked(in_b) };
1000
1001        let x: i32 = lut_r.x;
1002        let y: i32 = lut_g.x;
1003        let z: i32 = lut_b.x;
1004
1005        let x_n: i32 = lut_r.x_n;
1006        let y_n: i32 = lut_g.x_n;
1007        let z_n: i32 = lut_b.x_n;
1008
1009        let dr = lut_r.w;
1010        let dg = lut_g.w;
1011        let db = lut_b.w;
1012
1013        let w0 = AvxVector::from(dr);
1014        let w1 = AvxVector::from(dg);
1015        let w2 = AvxVectorSse::from(db);
1016
1017        let c000 = r.fetch(x, y, z);
1018        let c100 = r.fetch(x_n, y, z);
1019        let c010 = r.fetch(x, y_n, z);
1020        let c110 = r.fetch(x_n, y_n, z);
1021        let c001 = r.fetch(x, y, z_n);
1022        let c101 = r.fetch(x_n, y, z_n);
1023        let c011 = r.fetch(x, y_n, z_n);
1024        let c111 = r.fetch(x_n, y_n, z_n);
1025
1026        let x000 = AvxVector::from_sse(c000, c001);
1027        let x010 = AvxVector::from_sse(c010, c011);
1028        let x011 = AvxVector::from_sse(c100, c101);
1029        let x111 = AvxVector::from_sse(c110, c111);
1030
1031        let c00 = x000.neg_mla(x000, w0).mla(x011, w0);
1032        let c10 = x010.neg_mla(x010, w0).mla(x111, w0);
1033
1034        let z0 = c00.neg_mla(c00, w1).mla(c10, w1);
1035
1036        let (c0, c1) = z0.split();
1037
1038        c0.neg_mla(c0, w2).mla(c1, w2)
1039    }
1040}