Skip to main content

moxcms/conversions/sse/
interpolator.rs

1/*
2 * // Copyright (c) Radzivon Bartoshyk 3/2025. All rights reserved.
3 * //
4 * // Redistribution and use in source and binary forms, with or without modification,
5 * // are permitted provided that the following conditions are met:
6 * //
7 * // 1.  Redistributions of source code must retain the above copyright notice, this
8 * // list of conditions and the following disclaimer.
9 * //
10 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
11 * // this list of conditions and the following disclaimer in the documentation
12 * // and/or other materials provided with the distribution.
13 * //
14 * // 3.  Neither the name of the copyright holder nor the names of its
15 * // contributors may be used to endorse or promote products derived from
16 * // this software without specific prior written permission.
17 * //
18 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29#![cfg(feature = "sse_luts")]
30use crate::conversions::interpolator::BarycentricWeight;
31use crate::math::FusedMultiplyAdd;
32#[cfg(target_arch = "x86")]
33use std::arch::x86::*;
34#[cfg(target_arch = "x86_64")]
35use std::arch::x86_64::*;
36use std::ops::{Add, Mul, Sub};
37
/// Four `f32` values stored with 16-byte alignment.
///
/// The alignment matches what an aligned SSE load (`_mm_load_ps`) requires, so a
/// reference to one of these can be read directly into an `__m128`.
#[repr(C, align(16))]
pub(crate) struct SseAlignedF32(pub(crate) [f32; 4]);
40
/// Field-less selector type whose inherent `interpolate` implements the
/// tetrahedral kernel. Only built with the `options` feature.
#[cfg(feature = "options")]
pub(crate) struct TetrahedralSse<const GRID_SIZE: usize> {}

/// Field-less selector type whose inherent `interpolate` implements the
/// pyramidal kernel. Only built with the `options` feature.
#[cfg(feature = "options")]
pub(crate) struct PyramidalSse<const GRID_SIZE: usize> {}

/// Field-less selector type whose inherent `interpolate` implements the
/// prismatic kernel. Only built with the `options` feature.
#[cfg(feature = "options")]
pub(crate) struct PrismaticSse<const GRID_SIZE: usize> {}

/// Field-less selector type whose inherent `interpolate` implements the
/// trilinear kernel. Always available (not gated by `options`).
pub(crate) struct TrilinearSse<const GRID_SIZE: usize> {}
51
/// A source of samples addressed by three integer coordinates.
///
/// The interpolation kernels in this file are written against this trait, so
/// they do not depend on how the samples are stored.
trait Fetcher<T> {
    /// Returns the sample located at `(x, y, z)`.
    fn fetch(&self, x: i32, y: i32, z: i32) -> T;
}
55
/// Thin wrapper around one `__m128` register (four `f32` lanes).
///
/// `repr(transparent)` keeps the layout identical to the wrapped `__m128`; the
/// wrapper exists so the arithmetic operator traits can be implemented on it.
#[repr(transparent)]
#[derive(Clone, Copy)]
pub(crate) struct SseVector {
    /// The underlying SSE register value.
    pub(crate) v: __m128,
}
61
62impl From<f32> for SseVector {
63    #[inline(always)]
64    fn from(v: f32) -> Self {
65        SseVector {
66            v: unsafe { _mm_set1_ps(v) },
67        }
68    }
69}
70
71impl Sub<SseVector> for SseVector {
72    type Output = Self;
73    #[inline(always)]
74    fn sub(self, rhs: SseVector) -> Self::Output {
75        SseVector {
76            v: unsafe { _mm_sub_ps(self.v, rhs.v) },
77        }
78    }
79}
80
81impl Add<SseVector> for SseVector {
82    type Output = Self;
83    #[inline(always)]
84    fn add(self, rhs: SseVector) -> Self::Output {
85        SseVector {
86            v: unsafe { _mm_add_ps(self.v, rhs.v) },
87        }
88    }
89}
90
91impl Mul<SseVector> for SseVector {
92    type Output = Self;
93    #[inline(always)]
94    fn mul(self, rhs: SseVector) -> Self::Output {
95        SseVector {
96            v: unsafe { _mm_mul_ps(self.v, rhs.v) },
97        }
98    }
99}
100
101impl FusedMultiplyAdd<SseVector> for SseVector {
102    #[inline(always)]
103    fn mla(&self, b: SseVector, c: SseVector) -> SseVector {
104        SseVector {
105            v: unsafe { _mm_add_ps(self.v, _mm_mul_ps(b.v, c.v)) },
106        }
107    }
108}
109
110struct TetrahedralSseFetchVector<'a, const GRID_SIZE: usize> {
111    cube: &'a [SseAlignedF32],
112}
113
114/// LUT size here is always fixed size (GRID_SIZE^3) and its use
115/// is hardened at [crate::conversions::sse::assert_barycentric_lut_size_precondition].
116impl<const GRID_SIZE: usize> Fetcher<SseVector> for TetrahedralSseFetchVector<'_, GRID_SIZE> {
117    #[inline(always)]
118    fn fetch(&self, x: i32, y: i32, z: i32) -> SseVector {
119        let offset = (x as u32 * (GRID_SIZE as u32 * GRID_SIZE as u32)
120            + y as u32 * GRID_SIZE as u32
121            + z as u32) as usize;
122        let jx = unsafe { self.cube.get_unchecked(offset..) };
123        SseVector {
124            v: unsafe { _mm_load_ps(jx.as_ptr() as *const _) },
125        }
126    }
127}
128
129pub(crate) trait SseMdInterpolation {
130    fn inter3_sse(
131        &self,
132        table: &[SseAlignedF32],
133        in_r: usize,
134        in_g: usize,
135        in_b: usize,
136        lut: &[BarycentricWeight<f32>],
137    ) -> SseVector;
138}
139
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> TetrahedralSse<GRID_SIZE> {
    /// Tetrahedral interpolation of one sample.
    ///
    /// The three weights are ranked against each other and one of six vertex
    /// paths from `(x, y, z)` to `(x_n, y_n, z_n)` is chosen accordingly. The
    /// result is `c0 + c1 * rx + c2 * ry + c3 * rz`, where `c0` is the sample at
    /// `(x, y, z)` and each other coefficient is a difference of two fetched
    /// samples.
    ///
    /// # Safety
    ///
    /// * `in_r`, `in_g` and `in_b` index `lut` without bounds checks and must be
    ///   in range.
    /// * The function is compiled with `sse4.1` enabled, so the running CPU must
    ///   support it.
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        // SAFETY: in-range indices are part of this function's contract.
        let (wr, wg, wb) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        let (x, y, z): (i32, i32, i32) = (wr.x, wg.x, wb.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (wr.x_n, wg.x_n, wb.x_n);
        let (rx, ry, rz) = (wr.w, wg.w, wb.w);

        let c0 = r.fetch(x, y, z);

        // Each arm yields (c1, c2, c3): the coefficients applied to rx, ry, rz.
        let (c1, c2, c3) = if rx >= ry {
            if ry >= rz {
                // rx >= ry >= rz
                (
                    r.fetch(x_n, y, z) - c0,
                    r.fetch(x_n, y_n, z) - r.fetch(x_n, y, z),
                    r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y_n, z),
                )
            } else if rx >= rz {
                // rx >= rz > ry
                (
                    r.fetch(x_n, y, z) - c0,
                    r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y, z_n),
                    r.fetch(x_n, y, z_n) - r.fetch(x_n, y, z),
                )
            } else {
                // rz > rx >= ry
                (
                    r.fetch(x_n, y, z_n) - r.fetch(x, y, z_n),
                    r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y, z_n),
                    r.fetch(x, y, z_n) - c0,
                )
            }
        } else if rx >= rz {
            // ry > rx >= rz
            (
                r.fetch(x_n, y_n, z) - r.fetch(x, y_n, z),
                r.fetch(x, y_n, z) - c0,
                r.fetch(x_n, y_n, z_n) - r.fetch(x_n, y_n, z),
            )
        } else if ry >= rz {
            // ry >= rz > rx
            (
                r.fetch(x_n, y_n, z_n) - r.fetch(x, y_n, z_n),
                r.fetch(x, y_n, z) - c0,
                r.fetch(x, y_n, z_n) - r.fetch(x, y_n, z),
            )
        } else {
            // rz > ry > rx
            (
                r.fetch(x_n, y_n, z_n) - r.fetch(x, y_n, z_n),
                r.fetch(x, y_n, z_n) - r.fetch(x, y, z_n),
                r.fetch(x, y, z_n) - c0,
            )
        };

        // Accumulate in the fixed order rx, ry, rz.
        c0.mla(c1, SseVector::from(rx))
            .mla(c2, SseVector::from(ry))
            .mla(c3, SseVector::from(rz))
    }
}
210
211macro_rules! define_inter_sse {
212    ($interpolator: ident) => {
213        impl<const GRID_SIZE: usize> SseMdInterpolation for $interpolator<GRID_SIZE> {
214            fn inter3_sse(
215                &self,
216                table: &[SseAlignedF32],
217                in_r: usize,
218                in_g: usize,
219                in_b: usize,
220                lut: &[BarycentricWeight<f32>],
221            ) -> SseVector {
222                unsafe {
223                    self.interpolate(
224                        in_r,
225                        in_g,
226                        in_b,
227                        lut,
228                        TetrahedralSseFetchVector::<GRID_SIZE> { cube: table },
229                    )
230                }
231            }
232        }
233    };
234}
235
236#[cfg(feature = "options")]
237define_inter_sse!(TetrahedralSse);
238#[cfg(feature = "options")]
239define_inter_sse!(PyramidalSse);
240#[cfg(feature = "options")]
241define_inter_sse!(PrismaticSse);
242define_inter_sse!(TrilinearSse);
243
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PyramidalSse<GRID_SIZE> {
    /// Pyramidal interpolation of one sample.
    ///
    /// One of three cases is chosen by comparing the weights. Each case fetches
    /// four further samples and builds four coefficients from them. The result is
    /// `c0 + c1 * db + c2 * dr + c3 * dg + c4 * (product of two weights)`; which
    /// two weights form the product depends on the case.
    ///
    /// # Safety
    ///
    /// * `in_r`, `in_g` and `in_b` index `lut` without bounds checks and must be
    ///   in range.
    /// * The function is compiled with `sse4.1` enabled, so the running CPU must
    ///   support it.
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        // SAFETY: in-range indices are part of this function's contract.
        let (wr, wg, wb) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        let (x, y, z): (i32, i32, i32) = (wr.x, wg.x, wb.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (wr.x_n, wg.x_n, wb.x_n);
        let (dr, dg, db) = (wr.w, wg.w, wb.w);

        let c0 = r.fetch(x, y, z);

        // Each arm yields the four coefficients plus the scalar weight for `c4`.
        let (c1, c2, c3, c4, cross) = if dr > db && dg > db {
            let far = r.fetch(x_n, y_n, z_n);
            let xy = r.fetch(x_n, y_n, z);
            let along_x = r.fetch(x_n, y, z);
            let along_y = r.fetch(x, y_n, z);

            (
                far - xy,
                along_x - c0,
                along_y - c0,
                c0 - along_y - along_x + xy,
                dr * dg,
            )
        } else if db > dr && dg > dr {
            let along_z = r.fetch(x, y, z_n);
            let far = r.fetch(x_n, y_n, z_n);
            let yz = r.fetch(x, y_n, z_n);
            let along_y = r.fetch(x, y_n, z);

            (
                along_z - c0,
                far - yz,
                along_y - c0,
                c0 - along_y - along_z + yz,
                dg * db,
            )
        } else {
            let along_z = r.fetch(x, y, z_n);
            let along_x = r.fetch(x_n, y, z);
            let xz = r.fetch(x_n, y, z_n);
            let far = r.fetch(x_n, y_n, z_n);

            (
                along_z - c0,
                along_x - c0,
                far - xz,
                c0 - along_x - along_z + xz,
                db * dr,
            )
        };

        // The accumulation order is the same for every case.
        c0.mla(c1, SseVector::from(db))
            .mla(c2, SseVector::from(dr))
            .mla(c3, SseVector::from(dg))
            .mla(c4, SseVector::from(cross))
    }
}
321
#[cfg(feature = "options")]
impl<const GRID_SIZE: usize> PrismaticSse<GRID_SIZE> {
    /// Prismatic interpolation of one sample.
    ///
    /// One of two cases is chosen by comparing `db` with `dr`. Each case fetches
    /// five further samples and builds five coefficients from them. The result is
    /// `c0 + c1 * db + c2 * dr + c3 * dg + c4 * (dg * db) + c5 * (dr * dg)`.
    ///
    /// # Safety
    ///
    /// * `in_r`, `in_g` and `in_b` index `lut` without bounds checks and must be
    ///   in range.
    /// * The function is compiled with `sse4.1` enabled, so the running CPU must
    ///   support it.
    #[target_feature(enable = "sse4.1")]
    unsafe fn interpolate(
        &self,
        in_r: usize,
        in_g: usize,
        in_b: usize,
        lut: &[BarycentricWeight<f32>],
        r: impl Fetcher<SseVector>,
    ) -> SseVector {
        // SAFETY: in-range indices are part of this function's contract.
        let (wr, wg, wb) = unsafe {
            (
                *lut.get_unchecked(in_r),
                *lut.get_unchecked(in_g),
                *lut.get_unchecked(in_b),
            )
        };

        let (x, y, z): (i32, i32, i32) = (wr.x, wg.x, wb.x);
        let (x_n, y_n, z_n): (i32, i32, i32) = (wr.x_n, wg.x_n, wb.x_n);
        let (dr, dg, db) = (wr.w, wg.w, wb.w);

        let c0 = r.fetch(x, y, z);

        let (c1, c2, c3, c4, c5) = if db > dr {
            let along_z = r.fetch(x, y, z_n);
            let xz = r.fetch(x_n, y, z_n);
            let along_y = r.fetch(x, y_n, z);
            let yz = r.fetch(x, y_n, z_n);
            let far = r.fetch(x_n, y_n, z_n);

            (
                along_z - c0,
                xz - along_z,
                along_y - c0,
                c0 - along_y - along_z + yz,
                along_z - yz - xz + far,
            )
        } else {
            let along_x = r.fetch(x_n, y, z);
            let xz = r.fetch(x_n, y, z_n);
            let along_y = r.fetch(x, y_n, z);
            let xy = r.fetch(x_n, y_n, z);
            let far = r.fetch(x_n, y_n, z_n);

            (
                xz - along_x,
                along_x - c0,
                along_y - c0,
                along_x - xy - xz + far,
                c0 - along_y - along_x + xy,
            )
        };

        // The accumulation order is the same for both cases.
        c0.mla(c1, SseVector::from(db))
            .mla(c2, SseVector::from(dr))
            .mla(c3, SseVector::from(dg))
            .mla(c4, SseVector::from(dg * db))
            .mla(c5, SseVector::from(dr * dg))
    }
}
390
391impl<const GRID_SIZE: usize> TrilinearSse<GRID_SIZE> {
392    #[target_feature(enable = "sse4.1")]
393    unsafe fn interpolate(
394        &self,
395        in_r: usize,
396        in_g: usize,
397        in_b: usize,
398        lut: &[BarycentricWeight<f32>],
399        r: impl Fetcher<SseVector>,
400    ) -> SseVector {
401        let lut_r = unsafe { *lut.get_unchecked(in_r) };
402        let lut_g = unsafe { *lut.get_unchecked(in_g) };
403        let lut_b = unsafe { *lut.get_unchecked(in_b) };
404
405        let x: i32 = lut_r.x;
406        let y: i32 = lut_g.x;
407        let z: i32 = lut_b.x;
408
409        let x_n: i32 = lut_r.x_n;
410        let y_n: i32 = lut_g.x_n;
411        let z_n: i32 = lut_b.x_n;
412
413        let dr = lut_r.w;
414        let dg = lut_g.w;
415        let db = lut_b.w;
416
417        let w0 = SseVector::from(dr);
418        let w1 = SseVector::from(dg);
419        let w2 = SseVector::from(db);
420
421        let c000 = r.fetch(x, y, z);
422        let c100 = r.fetch(x_n, y, z);
423        let c010 = r.fetch(x, y_n, z);
424        let c110 = r.fetch(x_n, y_n, z);
425        let c001 = r.fetch(x, y, z_n);
426        let c101 = r.fetch(x_n, y, z_n);
427        let c011 = r.fetch(x, y_n, z_n);
428        let c111 = r.fetch(x_n, y_n, z_n);
429
430        let dx = SseVector::from(1.0 - dr);
431
432        let c00 = (c000 * dx).mla(c100, w0);
433        let c10 = (c010 * dx).mla(c110, w0);
434        let c01 = (c001 * dx).mla(c101, w0);
435        let c11 = (c011 * dx).mla(c111, w0);
436
437        let dy = SseVector::from(1.0 - dg);
438
439        let c0 = (c00 * dy).mla(c10, w1);
440        let c1 = (c01 * dy).mla(c11, w1);
441
442        let dz = SseVector::from(1.0 - db);
443
444        (c0 * dz).mla(c1, w2)
445    }
446}