Skip to main content

moxcms/conversions/
lut3x3.rs

1/*
2 * // Copyright (c) Radzivon Bartoshyk 3/2025. All rights reserved.
3 * //
4 * // Redistribution and use in source and binary forms, with or without modification,
5 * // are permitted provided that the following conditions are met:
6 * //
7 * // 1.  Redistributions of source code must retain the above copyright notice, this
8 * // list of conditions and the following disclaimer.
9 * //
10 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
11 * // this list of conditions and the following disclaimer in the documentation
12 * // and/or other materials provided with the distribution.
13 * //
14 * // 3.  Neither the name of the copyright holder nor the names of its
15 * // contributors may be used to endorse or promote products derived from
16 * // this software without specific prior written permission.
17 * //
18 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29#![cfg(feature = "lut")]
30#[cfg(feature = "any_to_any")]
31use crate::conversions::katana::{KatanaFinalStage, KatanaInitialStage};
32use crate::err::{MalformedSize, try_vec};
33use crate::profile::LutDataType;
34use crate::safe_math::{SafeMul, SafePowi};
35use crate::trc::lut_interp_linear_float;
36use crate::*;
37#[cfg(feature = "any_to_any")]
38use num_traits::AsPrimitive;
39
40#[derive(Default)]
41struct Lut3x3 {
42    input: [Vec<f32>; 3],
43    clut: Vec<f32>,
44    grid_size: u8,
45    gamma: [Vec<f32>; 3],
46    interpolation_method: InterpolationMethod,
47    pcs: DataColorSpace,
48}
49
50#[cfg(feature = "any_to_any")]
51#[derive(Default)]
52struct KatanaLut3x3<T: Copy + Default> {
53    input: [Vec<f32>; 3],
54    clut: Vec<f32>,
55    grid_size: u8,
56    gamma: [Vec<f32>; 3],
57    interpolation_method: InterpolationMethod,
58    pcs: DataColorSpace,
59    _phantom: std::marker::PhantomData<T>,
60    bit_depth: usize,
61}
62
63fn make_lut_3x3(
64    lut: &LutDataType,
65    options: TransformOptions,
66    pcs: DataColorSpace,
67) -> Result<Lut3x3, CmsError> {
68    let clut_length: usize = (lut.num_clut_grid_points as usize)
69        .safe_powi(lut.num_input_channels as u32)?
70        .safe_mul(lut.num_output_channels as usize)?;
71
72    let lin_table = lut.input_table.to_clut_f32();
73
74    if lin_table.len() < lut.num_input_table_entries as usize * 3 {
75        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
76            size: lin_table.len(),
77            expected: lut.num_input_table_entries as usize * 3,
78        }));
79    }
80
81    let lin_curve0 = lin_table[..lut.num_input_table_entries as usize].to_vec();
82    let lin_curve1 = lin_table
83        [lut.num_input_table_entries as usize..lut.num_input_table_entries as usize * 2]
84        .to_vec();
85    let lin_curve2 = lin_table
86        [lut.num_input_table_entries as usize * 2..lut.num_input_table_entries as usize * 3]
87        .to_vec();
88
89    let clut_table = lut.clut_table.to_clut_f32();
90    if clut_table.len() != clut_length {
91        return Err(CmsError::MalformedClut(MalformedSize {
92            size: clut_table.len(),
93            expected: clut_length,
94        }));
95    }
96
97    let gamma_curves = lut.output_table.to_clut_f32();
98
99    if gamma_curves.len() < lut.num_output_table_entries as usize * 3 {
100        return Err(CmsError::MalformedCurveLutTable(MalformedSize {
101            size: gamma_curves.len(),
102            expected: lut.num_output_table_entries as usize * 3,
103        }));
104    }
105
106    let gamma_curve0 = gamma_curves[..lut.num_output_table_entries as usize].to_vec();
107    let gamma_curve1 = gamma_curves
108        [lut.num_output_table_entries as usize..lut.num_output_table_entries as usize * 2]
109        .to_vec();
110    let gamma_curve2 = gamma_curves
111        [lut.num_output_table_entries as usize * 2..lut.num_output_table_entries as usize * 3]
112        .to_vec();
113
114    let transform = Lut3x3 {
115        input: [lin_curve0, lin_curve1, lin_curve2],
116        gamma: [gamma_curve0, gamma_curve1, gamma_curve2],
117        interpolation_method: options.interpolation_method,
118        clut: clut_table,
119        grid_size: lut.num_clut_grid_points,
120        pcs,
121    };
122
123    Ok(transform)
124}
125
126fn stage_lut_3x3(
127    lut: &LutDataType,
128    options: TransformOptions,
129    pcs: DataColorSpace,
130) -> Result<Box<dyn Stage>, CmsError> {
131    let lut = make_lut_3x3(lut, options, pcs)?;
132
133    let transform = Lut3x3 {
134        input: lut.input,
135        gamma: lut.gamma,
136        interpolation_method: lut.interpolation_method,
137        clut: lut.clut,
138        grid_size: lut.grid_size,
139        pcs: lut.pcs,
140    };
141
142    Ok(Box::new(transform))
143}
144
145#[cfg(feature = "any_to_any")]
146pub(crate) fn katana_input_stage_lut_3x3<
147    T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
148>(
149    lut: &LutDataType,
150    options: TransformOptions,
151    pcs: DataColorSpace,
152    bit_depth: usize,
153) -> Result<Box<dyn KatanaInitialStage<f32, T> + Send + Sync>, CmsError>
154where
155    f32: AsPrimitive<T>,
156{
157    let lut = make_lut_3x3(lut, options, pcs)?;
158
159    let transform = KatanaLut3x3::<T> {
160        input: lut.input,
161        gamma: lut.gamma,
162        interpolation_method: lut.interpolation_method,
163        clut: lut.clut,
164        grid_size: lut.grid_size,
165        pcs: lut.pcs,
166        _phantom: std::marker::PhantomData,
167        bit_depth,
168    };
169
170    Ok(Box::new(transform))
171}
172
173#[cfg(feature = "any_to_any")]
174pub(crate) fn katana_output_stage_lut_3x3<
175    T: Copy + Default + AsPrimitive<f32> + PointeeSizeExpressible + Send + Sync,
176>(
177    lut: &LutDataType,
178    options: TransformOptions,
179    pcs: DataColorSpace,
180    bit_depth: usize,
181) -> Result<Box<dyn KatanaFinalStage<f32, T> + Send + Sync>, CmsError>
182where
183    f32: AsPrimitive<T>,
184{
185    let lut = make_lut_3x3(lut, options, pcs)?;
186
187    let transform = KatanaLut3x3::<T> {
188        input: lut.input,
189        gamma: lut.gamma,
190        interpolation_method: lut.interpolation_method,
191        clut: lut.clut,
192        grid_size: lut.grid_size,
193        pcs: lut.pcs,
194        _phantom: std::marker::PhantomData,
195        bit_depth,
196    };
197
198    Ok(Box::new(transform))
199}
200
201impl Lut3x3 {
202    fn transform_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
203        &self,
204        src: &[f32],
205        dst: &mut [f32],
206        fetch: Fetch,
207    ) -> Result<(), CmsError> {
208        let linearization_0 = &self.input[0];
209        let linearization_1 = &self.input[1];
210        let linearization_2 = &self.input[2];
211        for (dest, src) in dst.chunks_exact_mut(3).zip(src.chunks_exact(3)) {
212            debug_assert!(self.grid_size as i32 >= 1);
213            let linear_x = lut_interp_linear_float(src[0], linearization_0);
214            let linear_y = lut_interp_linear_float(src[1], linearization_1);
215            let linear_z = lut_interp_linear_float(src[2], linearization_2);
216
217            let clut = fetch(linear_x, linear_y, linear_z);
218
219            let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
220            let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
221            let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
222            dest[0] = pcs_x;
223            dest[1] = pcs_y;
224            dest[2] = pcs_z;
225        }
226        Ok(())
227    }
228}
229
230impl Stage for Lut3x3 {
231    fn transform(&self, src: &[f32], dst: &mut [f32]) -> Result<(), CmsError> {
232        let l_tbl = Cube::new(&self.clut, self.grid_size as usize, 3)?;
233
234        // If PCS is LAB then linear interpolation should be used
235        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
236            return self.transform_impl(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
237        }
238
239        match self.interpolation_method {
240            #[cfg(feature = "options")]
241            InterpolationMethod::Tetrahedral => {
242                self.transform_impl(src, dst, |x, y, z| l_tbl.tetra_vec3(x, y, z))?;
243            }
244            #[cfg(feature = "options")]
245            InterpolationMethod::Pyramid => {
246                self.transform_impl(src, dst, |x, y, z| l_tbl.pyramid_vec3(x, y, z))?;
247            }
248            #[cfg(feature = "options")]
249            InterpolationMethod::Prism => {
250                self.transform_impl(src, dst, |x, y, z| l_tbl.prism_vec3(x, y, z))?;
251            }
252            InterpolationMethod::Linear => {
253                self.transform_impl(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z))?;
254            }
255        }
256        Ok(())
257    }
258}
259
260#[cfg(feature = "any_to_any")]
261impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaLut3x3<T>
262where
263    f32: AsPrimitive<T>,
264{
265    fn to_pcs_impl<Fetch: Fn(f32, f32, f32) -> Vector3f>(
266        &self,
267        input: &[T],
268        fetch: Fetch,
269    ) -> Result<Vec<f32>, CmsError> {
270        if input.len() % 3 != 0 {
271            return Err(CmsError::LaneMultipleOfChannels);
272        }
273        let normalizing_value = if T::FINITE {
274            1.0 / ((1u32 << self.bit_depth) - 1) as f32
275        } else {
276            1.0
277        };
278        let mut dst = try_vec![0.; input.len()];
279        let linearization_0 = &self.input[0];
280        let linearization_1 = &self.input[1];
281        let linearization_2 = &self.input[2];
282        for (dest, src) in dst.chunks_exact_mut(3).zip(input.chunks_exact(3)) {
283            let linear_x =
284                lut_interp_linear_float(src[0].as_() * normalizing_value, linearization_0);
285            let linear_y =
286                lut_interp_linear_float(src[1].as_() * normalizing_value, linearization_1);
287            let linear_z =
288                lut_interp_linear_float(src[2].as_() * normalizing_value, linearization_2);
289
290            let clut = fetch(linear_x, linear_y, linear_z);
291
292            let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
293            let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
294            let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
295            dest[0] = pcs_x;
296            dest[1] = pcs_y;
297            dest[2] = pcs_z;
298        }
299        Ok(dst)
300    }
301
302    fn to_output<Fetch: Fn(f32, f32, f32) -> Vector3f>(
303        &self,
304        src: &[f32],
305        dst: &mut [T],
306        fetch: Fetch,
307    ) -> Result<(), CmsError> {
308        if src.len() % 3 != 0 {
309            return Err(CmsError::LaneMultipleOfChannels);
310        }
311        if dst.len() % 3 != 0 {
312            return Err(CmsError::LaneMultipleOfChannels);
313        }
314        if dst.len() != src.len() {
315            return Err(CmsError::LaneSizeMismatch);
316        }
317        let norm_value = if T::FINITE {
318            ((1u32 << self.bit_depth) - 1) as f32
319        } else {
320            1.0
321        };
322
323        let linearization_0 = &self.input[0];
324        let linearization_1 = &self.input[1];
325        let linearization_2 = &self.input[2];
326        for (dest, src) in dst.chunks_exact_mut(3).zip(src.chunks_exact(3)) {
327            let linear_x = lut_interp_linear_float(src[0], linearization_0);
328            let linear_y = lut_interp_linear_float(src[1], linearization_1);
329            let linear_z = lut_interp_linear_float(src[2], linearization_2);
330
331            let clut = fetch(linear_x, linear_y, linear_z);
332
333            let pcs_x = lut_interp_linear_float(clut.v[0], &self.gamma[0]);
334            let pcs_y = lut_interp_linear_float(clut.v[1], &self.gamma[1]);
335            let pcs_z = lut_interp_linear_float(clut.v[2], &self.gamma[2]);
336
337            if T::FINITE {
338                dest[0] = (pcs_x * norm_value).round().max(0.0).min(norm_value).as_();
339                dest[1] = (pcs_y * norm_value).round().max(0.0).min(norm_value).as_();
340                dest[2] = (pcs_z * norm_value).round().max(0.0).min(norm_value).as_();
341            } else {
342                dest[0] = pcs_x.as_();
343                dest[1] = pcs_y.as_();
344                dest[2] = pcs_z.as_();
345            }
346        }
347        Ok(())
348    }
349}
350
351#[cfg(feature = "any_to_any")]
352impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaInitialStage<f32, T>
353    for KatanaLut3x3<T>
354where
355    f32: AsPrimitive<T>,
356{
357    fn to_pcs(&self, input: &[T]) -> Result<Vec<f32>, CmsError> {
358        let l_tbl = Cube::new(&self.clut, self.grid_size as usize, 3)?;
359
360        // If PCS is LAB then linear interpolation should be used
361        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
362            return self.to_pcs_impl(input, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
363        }
364
365        match self.interpolation_method {
366            #[cfg(feature = "options")]
367            InterpolationMethod::Tetrahedral => {
368                self.to_pcs_impl(input, |x, y, z| l_tbl.tetra_vec3(x, y, z))
369            }
370            #[cfg(feature = "options")]
371            InterpolationMethod::Pyramid => {
372                self.to_pcs_impl(input, |x, y, z| l_tbl.pyramid_vec3(x, y, z))
373            }
374            #[cfg(feature = "options")]
375            InterpolationMethod::Prism => {
376                self.to_pcs_impl(input, |x, y, z| l_tbl.prism_vec3(x, y, z))
377            }
378            InterpolationMethod::Linear => {
379                self.to_pcs_impl(input, |x, y, z| l_tbl.trilinear_vec3(x, y, z))
380            }
381        }
382    }
383}
384
385#[cfg(feature = "any_to_any")]
386impl<T: Copy + Default + PointeeSizeExpressible + AsPrimitive<f32>> KatanaFinalStage<f32, T>
387    for KatanaLut3x3<T>
388where
389    f32: AsPrimitive<T>,
390{
391    fn to_output(&self, src: &mut [f32], dst: &mut [T]) -> Result<(), CmsError> {
392        let l_tbl = Cube::new(&self.clut, self.grid_size as usize, 3)?;
393
394        // If PCS is LAB then linear interpolation should be used
395        if self.pcs == DataColorSpace::Lab || self.pcs == DataColorSpace::Xyz {
396            return self.to_output(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z));
397        }
398
399        match self.interpolation_method {
400            #[cfg(feature = "options")]
401            InterpolationMethod::Tetrahedral => {
402                self.to_output(src, dst, |x, y, z| l_tbl.tetra_vec3(x, y, z))
403            }
404            #[cfg(feature = "options")]
405            InterpolationMethod::Pyramid => {
406                self.to_output(src, dst, |x, y, z| l_tbl.pyramid_vec3(x, y, z))
407            }
408            #[cfg(feature = "options")]
409            InterpolationMethod::Prism => {
410                self.to_output(src, dst, |x, y, z| l_tbl.prism_vec3(x, y, z))
411            }
412            InterpolationMethod::Linear => {
413                self.to_output(src, dst, |x, y, z| l_tbl.trilinear_vec3(x, y, z))
414            }
415        }
416    }
417}
418
419pub(crate) fn create_lut3x3(
420    lut: &LutDataType,
421    src: &[f32],
422    options: TransformOptions,
423    pcs: DataColorSpace,
424) -> Result<Vec<f32>, CmsError> {
425    if lut.num_input_channels != 3 || lut.num_output_channels != 3 {
426        return Err(CmsError::UnsupportedProfileConnection);
427    }
428
429    let mut dest = try_vec![0.; src.len()];
430
431    let lut_stage = stage_lut_3x3(lut, options, pcs)?;
432    lut_stage.transform(src, &mut dest)?;
433    Ok(dest)
434}