Skip to main content

moxcms/conversions/
mab.rs

/*
 * // Copyright (c) Radzivon Bartoshyk 3/2025. All rights reserved.
 * //
 * // Redistribution and use in source and binary forms, with or without modification,
 * // are permitted provided that the following conditions are met:
 * //
 * // 1.  Redistributions of source code must retain the above copyright notice, this
 * // list of conditions and the following disclaimer.
 * //
 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
 * // this list of conditions and the following disclaimer in the documentation
 * // and/or other materials provided with the distribution.
 * //
 * // 3.  Neither the name of the copyright holder nor the names of its
 * // contributors may be used to endorse or promote products derived from
 * // this software without specific prior written permission.
 * //
 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
29#![cfg(feature = "lut")]
30use crate::mlaf::mlaf;
31use crate::safe_math::SafeMul;
32use crate::{
33    CmsError, Cube, DataColorSpace, InPlaceStage, InterpolationMethod, LutMultidimensionalType,
34    MalformedSize, Matrix3d, Matrix3f, TransformOptions, Vector3d, Vector3f,
35};
36
37#[allow(unused)]
38struct ACurves3<'a> {
39    curve0: Box<[f32; 65536]>,
40    curve1: Box<[f32; 65536]>,
41    curve2: Box<[f32; 65536]>,
42    clut: &'a [f32],
43    grid_size: [u8; 3],
44    interpolation_method: InterpolationMethod,
45    pcs: DataColorSpace,
46    depth: usize,
47}
48
49#[allow(unused)]
50struct ACurves3Optimized<'a> {
51    clut: &'a [f32],
52    grid_size: [u8; 3],
53    interpolation_method: InterpolationMethod,
54    pcs: DataColorSpace,
55}
56
57#[allow(unused)]
58impl ACurves3<'_> {
59    fn transform_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
60        &self,
61        dst: &mut [f32],
62        fetch: Fetch,
63    ) -> Result<(), CmsError> {
64        let scale_value = (self.depth - 1) as f32;
65
66        for dst in dst.chunks_exact_mut(3) {
67            let a0 = (dst[0] * scale_value).round().min(scale_value) as u16;
68            let a1 = (dst[1] * scale_value).round().min(scale_value) as u16;
69            let a2 = (dst[2] * scale_value).round().min(scale_value) as u16;
70            let b0 = self.curve0[a0 as usize];
71            let b1 = self.curve1[a1 as usize];
72            let b2 = self.curve2[a2 as usize];
73            let interpolated = fetch(b0, b1, b2);
74            dst[0] = interpolated.v[0];
75            dst[1] = interpolated.v[1];
76            dst[2] = interpolated.v[2];
77        }
78        Ok(())
79    }
80}
81
82#[allow(unused)]
83impl ACurves3Optimized<'_> {
84    fn transform_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
85        &self,
86        dst: &mut [f32],
87        fetch: Fetch,
88    ) -> Result<(), CmsError> {
89        for dst in dst.chunks_exact_mut(3) {
90            let a0 = dst[0];
91            let a1 = dst[1];
92            let a2 = dst[2];
93            let interpolated = fetch(a0, a1, a2);
94            dst[0] = interpolated.v[0];
95            dst[1] = interpolated.v[1];
96            dst[2] = interpolated.v[2];
97        }
98        Ok(())
99    }
100}
101
102impl InPlaceStage for ACurves3<'_> {
103    fn transform(&self, dst: &mut [f32]) -> Result<(), CmsError> {
104        let lut = Cube::new_cube(self.clut, self.grid_size, 3)?;
105
106        // If PCS is LAB then linear interpolation should be used
107        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
108            return self.transform_impl(dst, |x, y, z| lut.trilinear_vec3(x, y, z));
109        }
110
111        match self.interpolation_method {
112            #[cfg(feature = "options")]
113            InterpolationMethod::Tetrahedral => {
114                self.transform_impl(dst, |x, y, z| lut.tetra_vec3(x, y, z))?;
115            }
116            #[cfg(feature = "options")]
117            InterpolationMethod::Pyramid => {
118                self.transform_impl(dst, |x, y, z| lut.pyramid_vec3(x, y, z))?;
119            }
120            #[cfg(feature = "options")]
121            InterpolationMethod::Prism => {
122                self.transform_impl(dst, |x, y, z| lut.prism_vec3(x, y, z))?;
123            }
124            InterpolationMethod::Linear => {
125                self.transform_impl(dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
126            }
127        }
128        Ok(())
129    }
130}
131
132impl InPlaceStage for ACurves3Optimized<'_> {
133    fn transform(&self, dst: &mut [f32]) -> Result<(), CmsError> {
134        let lut = Cube::new_cube(self.clut, self.grid_size, 3)?;
135
136        // If PCS is LAB then linear interpolation should be used
137        if self.pcs == DataColorSpace::Lab {
138            return self.transform_impl(dst, |x, y, z| lut.trilinear_vec3(x, y, z));
139        }
140
141        match self.interpolation_method {
142            #[cfg(feature = "options")]
143            InterpolationMethod::Tetrahedral => {
144                self.transform_impl(dst, |x, y, z| lut.tetra_vec3(x, y, z))?;
145            }
146            #[cfg(feature = "options")]
147            InterpolationMethod::Pyramid => {
148                self.transform_impl(dst, |x, y, z| lut.pyramid_vec3(x, y, z))?;
149            }
150            #[cfg(feature = "options")]
151            InterpolationMethod::Prism => {
152                self.transform_impl(dst, |x, y, z| lut.prism_vec3(x, y, z))?;
153            }
154            InterpolationMethod::Linear => {
155                self.transform_impl(dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
156            }
157        }
158        Ok(())
159    }
160}
161
162#[allow(unused)]
163struct ACurves3Inverse<'a> {
164    curve0: Box<[f32; 65536]>,
165    curve1: Box<[f32; 65536]>,
166    curve2: Box<[f32; 65536]>,
167    clut: &'a [f32],
168    grid_size: [u8; 3],
169    interpolation_method: InterpolationMethod,
170    pcs: DataColorSpace,
171    depth: usize,
172}
173
174#[allow(unused)]
175impl ACurves3Inverse<'_> {
176    fn transform_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
177        &self,
178        dst: &mut [f32],
179        fetch: Fetch,
180    ) -> Result<(), CmsError> {
181        let scale_value = (self.depth as u32 - 1u32) as f32;
182
183        for dst in dst.chunks_exact_mut(3) {
184            let interpolated = fetch(dst[0], dst[1], dst[2]);
185            let a0 = (interpolated.v[0] * scale_value).round().min(scale_value) as u16;
186            let a1 = (interpolated.v[1] * scale_value).round().min(scale_value) as u16;
187            let a2 = (interpolated.v[2] * scale_value).round().min(scale_value) as u16;
188            let b0 = self.curve0[a0 as usize];
189            let b1 = self.curve1[a1 as usize];
190            let b2 = self.curve2[a2 as usize];
191            dst[0] = b0;
192            dst[1] = b1;
193            dst[2] = b2;
194        }
195        Ok(())
196    }
197}
198
199impl InPlaceStage for ACurves3Inverse<'_> {
200    fn transform(&self, dst: &mut [f32]) -> Result<(), CmsError> {
201        let lut = Cube::new_cube(self.clut, self.grid_size, 3)?;
202
203        // If PCS is LAB then linear interpolation should be used
204        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
205            return self.transform_impl(dst, |x, y, z| lut.trilinear_vec3(x, y, z));
206        }
207
208        match self.interpolation_method {
209            #[cfg(feature = "options")]
210            InterpolationMethod::Tetrahedral => {
211                self.transform_impl(dst, |x, y, z| lut.tetra_vec3(x, y, z))?;
212            }
213            #[cfg(feature = "options")]
214            InterpolationMethod::Pyramid => {
215                self.transform_impl(dst, |x, y, z| lut.pyramid_vec3(x, y, z))?;
216            }
217            #[cfg(feature = "options")]
218            InterpolationMethod::Prism => {
219                self.transform_impl(dst, |x, y, z| lut.prism_vec3(x, y, z))?;
220            }
221            InterpolationMethod::Linear => {
222                self.transform_impl(dst, |x, y, z| lut.trilinear_vec3(x, y, z))?;
223            }
224        }
225        Ok(())
226    }
227}
228
229pub(crate) struct MCurves3 {
230    pub(crate) curve0: Box<[f32; 65536]>,
231    pub(crate) curve1: Box<[f32; 65536]>,
232    pub(crate) curve2: Box<[f32; 65536]>,
233    pub(crate) matrix: Matrix3f,
234    pub(crate) bias: Vector3f,
235    pub(crate) inverse: bool,
236    pub(crate) depth: usize,
237}
238
239impl MCurves3 {
240    fn execute_matrix_stage(&self, dst: &mut [f32]) {
241        let m = self.matrix;
242        let b = self.bias;
243
244        if !m.test_equality(Matrix3f::IDENTITY) || !b.eq(&Vector3f::default()) {
245            for dst in dst.chunks_exact_mut(3) {
246                let x = dst[0];
247                let y = dst[1];
248                let z = dst[2];
249                dst[0] = mlaf(mlaf(mlaf(b.v[0], x, m.v[0][0]), y, m.v[0][1]), z, m.v[0][2]);
250                dst[1] = mlaf(mlaf(mlaf(b.v[1], x, m.v[1][0]), y, m.v[1][1]), z, m.v[1][2]);
251                dst[2] = mlaf(mlaf(mlaf(b.v[2], x, m.v[2][0]), y, m.v[2][1]), z, m.v[2][2]);
252            }
253        }
254    }
255}
256
257impl InPlaceStage for MCurves3 {
258    fn transform(&self, dst: &mut [f32]) -> Result<(), CmsError> {
259        let scale_value = (self.depth - 1) as f32;
260
261        if self.inverse {
262            self.execute_matrix_stage(dst);
263        }
264
265        for dst in dst.chunks_exact_mut(3) {
266            let a0 = (dst[0] * scale_value).round().min(scale_value) as u16;
267            let a1 = (dst[1] * scale_value).round().min(scale_value) as u16;
268            let a2 = (dst[2] * scale_value).round().min(scale_value) as u16;
269            let b0 = self.curve0[a0 as usize];
270            let b1 = self.curve1[a1 as usize];
271            let b2 = self.curve2[a2 as usize];
272            dst[0] = b0;
273            dst[1] = b1;
274            dst[2] = b2;
275        }
276
277        if !self.inverse {
278            self.execute_matrix_stage(dst);
279        }
280
281        Ok(())
282    }
283}
284
285pub(crate) struct BCurves3<const DEPTH: usize> {
286    pub(crate) curve0: Box<[f32; 65536]>,
287    pub(crate) curve1: Box<[f32; 65536]>,
288    pub(crate) curve2: Box<[f32; 65536]>,
289}
290
291impl<const DEPTH: usize> InPlaceStage for BCurves3<DEPTH> {
292    fn transform(&self, dst: &mut [f32]) -> Result<(), CmsError> {
293        let scale_value = (DEPTH - 1) as f32;
294
295        for dst in dst.chunks_exact_mut(3) {
296            let a0 = (dst[0] * scale_value).round().min(scale_value) as u16;
297            let a1 = (dst[1] * scale_value).round().min(scale_value) as u16;
298            let a2 = (dst[2] * scale_value).round().min(scale_value) as u16;
299            let b0 = self.curve0[a0 as usize];
300            let b1 = self.curve1[a1 as usize];
301            let b2 = self.curve2[a2 as usize];
302            dst[0] = b0;
303            dst[1] = b1;
304            dst[2] = b2;
305        }
306
307        Ok(())
308    }
309}
310
311pub(crate) fn prepare_mab_3x3(
312    mab: &LutMultidimensionalType,
313    lut: &mut [f32],
314    options: TransformOptions,
315    pcs: DataColorSpace,
316) -> Result<(), CmsError> {
317    const LERP_DEPTH: usize = 65536;
318    const BP: usize = 13;
319    const DEPTH: usize = 8192;
320
321    if mab.num_input_channels != 3 || mab.num_output_channels != 3 {
322        return Err(CmsError::UnsupportedProfileConnection);
323    }
324    if mab.a_curves.len() == 3 && mab.clut.is_some() {
325        let clut = &mab.clut.as_ref().map(|x| x.to_clut_f32()).unwrap();
326        let lut_grid = (mab.grid_points[0] as usize)
327            .safe_mul(mab.grid_points[1] as usize)?
328            .safe_mul(mab.grid_points[2] as usize)?
329            .safe_mul(mab.num_output_channels as usize)?;
330        if clut.len() != lut_grid {
331            return Err(CmsError::MalformedCurveLutTable(MalformedSize {
332                size: clut.len(),
333                expected: lut_grid,
334            }));
335        }
336
337        let all_curves_linear = mab.a_curves.iter().all(|curve| curve.is_linear());
338        let grid_size = [mab.grid_points[0], mab.grid_points[1], mab.grid_points[2]];
339
340        if all_curves_linear {
341            let l = ACurves3Optimized {
342                clut,
343                grid_size,
344                interpolation_method: options.interpolation_method,
345                pcs,
346            };
347            l.transform(lut)?;
348        } else {
349            let curves: Result<Vec<_>, _> = mab
350                .a_curves
351                .iter()
352                .map(|c| {
353                    c.build_linearize_table::<u16, LERP_DEPTH, BP>()
354                        .ok_or(CmsError::InvalidTrcCurve)
355                })
356                .collect();
357
358            let [curve0, curve1, curve2] =
359                curves?.try_into().map_err(|_| CmsError::InvalidTrcCurve)?;
360            let l = ACurves3 {
361                curve0,
362                curve1,
363                curve2,
364                clut,
365                grid_size,
366                interpolation_method: options.interpolation_method,
367                pcs,
368                depth: DEPTH,
369            };
370            l.transform(lut)?;
371        }
372    }
373
374    if mab.m_curves.len() == 3 {
375        let all_curves_linear = mab.m_curves.iter().all(|curve| curve.is_linear());
376        if !all_curves_linear
377            || !mab.matrix.test_equality(Matrix3d::IDENTITY)
378            || mab.bias.ne(&Vector3d::default())
379        {
380            let curves: Result<Vec<_>, _> = mab
381                .m_curves
382                .iter()
383                .map(|c| {
384                    c.build_linearize_table::<u16, LERP_DEPTH, BP>()
385                        .ok_or(CmsError::InvalidTrcCurve)
386                })
387                .collect();
388
389            let [curve0, curve1, curve2] =
390                curves?.try_into().map_err(|_| CmsError::InvalidTrcCurve)?;
391            let matrix = mab.matrix.to_f32();
392            let bias: Vector3f = mab.bias.cast();
393            let m_curves = MCurves3 {
394                curve0,
395                curve1,
396                curve2,
397                matrix,
398                bias,
399                inverse: false,
400                depth: DEPTH,
401            };
402            m_curves.transform(lut)?;
403        }
404    }
405
406    if mab.b_curves.len() == 3 {
407        const LERP_DEPTH: usize = 65536;
408        const BP: usize = 13;
409        const DEPTH: usize = 8192;
410        let all_curves_linear = mab.b_curves.iter().all(|curve| curve.is_linear());
411        if !all_curves_linear {
412            let curves: Result<Vec<_>, _> = mab
413                .b_curves
414                .iter()
415                .map(|c| {
416                    c.build_linearize_table::<u16, LERP_DEPTH, BP>()
417                        .ok_or(CmsError::InvalidTrcCurve)
418                })
419                .collect();
420
421            let [curve0, curve1, curve2] =
422                curves?.try_into().map_err(|_| CmsError::InvalidTrcCurve)?;
423
424            let b_curves = BCurves3::<DEPTH> {
425                curve0,
426                curve1,
427                curve2,
428            };
429            b_curves.transform(lut)?;
430        }
431    } else {
432        return Err(CmsError::InvalidAtoBLut);
433    }
434
435    Ok(())
436}
437
438pub(crate) fn prepare_mba_3x3(
439    mab: &LutMultidimensionalType,
440    lut: &mut [f32],
441    options: TransformOptions,
442    pcs: DataColorSpace,
443) -> Result<(), CmsError> {
444    if mab.num_input_channels != 3 || mab.num_output_channels != 3 {
445        return Err(CmsError::UnsupportedProfileConnection);
446    }
447    const LERP_DEPTH: usize = 65536;
448    const BP: usize = 13;
449    const DEPTH: usize = 8192;
450
451    if mab.b_curves.len() == 3 {
452        let all_curves_linear = mab.b_curves.iter().all(|curve| curve.is_linear());
453        if !all_curves_linear {
454            let curves: Result<Vec<_>, _> = mab
455                .b_curves
456                .iter()
457                .map(|c| {
458                    c.build_linearize_table::<u16, LERP_DEPTH, BP>()
459                        .ok_or(CmsError::InvalidTrcCurve)
460                })
461                .collect();
462
463            let [curve0, curve1, curve2] =
464                curves?.try_into().map_err(|_| CmsError::InvalidTrcCurve)?;
465            let b_curves = BCurves3::<DEPTH> {
466                curve0,
467                curve1,
468                curve2,
469            };
470            b_curves.transform(lut)?;
471        }
472    } else {
473        return Err(CmsError::InvalidAtoBLut);
474    }
475
476    if mab.m_curves.len() == 3 {
477        let all_curves_linear = mab.m_curves.iter().all(|curve| curve.is_linear());
478        if !all_curves_linear
479            || !mab.matrix.test_equality(Matrix3d::IDENTITY)
480            || mab.bias.ne(&Vector3d::default())
481        {
482            let curves: Result<Vec<_>, _> = mab
483                .m_curves
484                .iter()
485                .map(|c| {
486                    c.build_linearize_table::<u16, LERP_DEPTH, BP>()
487                        .ok_or(CmsError::InvalidTrcCurve)
488                })
489                .collect();
490
491            let [curve0, curve1, curve2] =
492                curves?.try_into().map_err(|_| CmsError::InvalidTrcCurve)?;
493
494            let matrix = mab.matrix.to_f32();
495            let bias: Vector3f = mab.bias.cast();
496            let m_curves = MCurves3 {
497                curve0,
498                curve1,
499                curve2,
500                matrix,
501                bias,
502                inverse: true,
503                depth: DEPTH,
504            };
505            m_curves.transform(lut)?;
506        }
507    }
508
509    if mab.a_curves.len() == 3 && mab.clut.is_some() {
510        let clut = &mab.clut.as_ref().map(|x| x.to_clut_f32()).unwrap();
511        let lut_grid = (mab.grid_points[0] as usize)
512            .safe_mul(mab.grid_points[1] as usize)?
513            .safe_mul(mab.grid_points[2] as usize)?
514            .safe_mul(mab.num_output_channels as usize)?;
515        if clut.len() != lut_grid {
516            return Err(CmsError::MalformedCurveLutTable(MalformedSize {
517                size: clut.len(),
518                expected: lut_grid,
519            }));
520        }
521
522        let all_curves_linear = mab.a_curves.iter().all(|curve| curve.is_linear());
523        let grid_size = [mab.grid_points[0], mab.grid_points[1], mab.grid_points[2]];
524
525        if all_curves_linear {
526            let l = ACurves3Optimized {
527                clut,
528                grid_size,
529                interpolation_method: options.interpolation_method,
530                pcs,
531            };
532            l.transform(lut)?;
533        } else {
534            let curves: Result<Vec<_>, _> = mab
535                .a_curves
536                .iter()
537                .map(|c| {
538                    c.build_linearize_table::<u16, LERP_DEPTH, BP>()
539                        .ok_or(CmsError::InvalidTrcCurve)
540                })
541                .collect();
542
543            let [curve0, curve1, curve2] =
544                curves?.try_into().map_err(|_| CmsError::InvalidTrcCurve)?;
545            let l = ACurves3Inverse {
546                curve0,
547                curve1,
548                curve2,
549                clut,
550                grid_size,
551                interpolation_method: options.interpolation_method,
552                pcs,
553                depth: DEPTH,
554            };
555            l.transform(lut)?;
556        }
557    }
558
559    Ok(())
560}