av1_grain/diff/
solver.rs

1mod util;
2
3use std::ops::{Add, AddAssign};
4
5use anyhow::anyhow;
6use arrayvec::ArrayVec;
7use v_frame::{frame::Frame, math::clamp, plane::Plane};
8
9use self::util::{extract_ar_row, get_block_mean, get_noise_var, linsolve, multiply_mat};
10use super::{NoiseStatus, BLOCK_SIZE, BLOCK_SIZE_SQUARED};
11use crate::{
12    diff::solver::util::normalized_cross_correlation, GrainTableSegment, DEFAULT_GRAIN_SEED,
13    NUM_UV_COEFFS, NUM_UV_POINTS, NUM_Y_COEFFS, NUM_Y_POINTS,
14};
15
16const LOW_POLY_NUM_PARAMS: usize = 3;
17const NOISE_MODEL_LAG: usize = 3;
18const BLOCK_NORMALIZATION: f64 = 255.0f64;
19
20#[derive(Debug, Clone)]
21pub(super) struct FlatBlockFinder {
22    a: Box<[f64]>,
23    a_t_a_inv: [f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS],
24}
25
26impl FlatBlockFinder {
27    #[must_use]
28    pub fn new() -> Self {
29        let mut eqns = EquationSystem::new(LOW_POLY_NUM_PARAMS);
30        let mut a_t_a_inv = [0.0f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS];
31        let mut a = vec![0.0f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED];
32
33        let bs_half = (BLOCK_SIZE / 2) as f64;
34        (0..BLOCK_SIZE).for_each(|y| {
35            let yd = (y as f64 - bs_half) / bs_half;
36            (0..BLOCK_SIZE).for_each(|x| {
37                let xd = (x as f64 - bs_half) / bs_half;
38                let coords = [yd, xd, 1.0f64];
39                let row = y * BLOCK_SIZE + x;
40                a[LOW_POLY_NUM_PARAMS * row] = yd;
41                a[LOW_POLY_NUM_PARAMS * row + 1] = xd;
42                a[LOW_POLY_NUM_PARAMS * row + 2] = 1.0f64;
43
44                (0..LOW_POLY_NUM_PARAMS).for_each(|i| {
45                    (0..LOW_POLY_NUM_PARAMS).for_each(|j| {
46                        eqns.a[LOW_POLY_NUM_PARAMS * i + j] += coords[i] * coords[j];
47                    });
48                });
49            });
50        });
51
52        // Lazy inverse using existing equation solver.
53        (0..LOW_POLY_NUM_PARAMS).for_each(|i| {
54            eqns.b.fill(0.0f64);
55            eqns.b[i] = 1.0f64;
56            eqns.solve();
57
58            (0..LOW_POLY_NUM_PARAMS).for_each(|j| {
59                a_t_a_inv[j * LOW_POLY_NUM_PARAMS + i] = eqns.x[j];
60            });
61        });
62
63        FlatBlockFinder {
64            a: a.into_boxed_slice(),
65            a_t_a_inv,
66        }
67    }
68
69    // The gradient-based features used in this code are based on:
70    //  A. Kokaram, D. Kelly, H. Denman and A. Crawford, "Measuring noise
71    //  correlation for improved video denoising," 2012 19th, ICIP.
72    // The thresholds are more lenient to allow for correct grain modeling
73    // in extreme cases.
74    #[must_use]
75    #[allow(clippy::too_many_lines)]
76    pub fn run(&self, plane: &Plane<u8>) -> (Vec<u8>, usize) {
77        const TRACE_THRESHOLD: f64 = 0.15f64 / BLOCK_SIZE_SQUARED as f64;
78        const RATIO_THRESHOLD: f64 = 1.25f64;
79        const NORM_THRESHOLD: f64 = 0.08f64 / BLOCK_SIZE_SQUARED as f64;
80        const VAR_THRESHOLD: f64 = 0.005f64 / BLOCK_SIZE_SQUARED as f64;
81
82        // The following weights are used to combine the above features to give
83        // a sigmoid score for flatness. If the input was normalized to [0,100]
84        // the magnitude of these values would be close to 1 (e.g., weights
85        // corresponding to variance would be a factor of 10000x smaller).
86        const VAR_WEIGHT: f64 = -6682f64;
87        const RATIO_WEIGHT: f64 = -0.2056f64;
88        const TRACE_WEIGHT: f64 = 13087f64;
89        const NORM_WEIGHT: f64 = -12434f64;
90        const OFFSET: f64 = 2.5694f64;
91
92        let num_blocks_w = (plane.cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE;
93        let num_blocks_h = (plane.cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE;
94        let num_blocks = num_blocks_w * num_blocks_h;
95        let mut flat_blocks = vec![0u8; num_blocks];
96        let mut num_flat = 0;
97        let mut plane_result = [0.0f64; BLOCK_SIZE_SQUARED];
98        let mut block_result = [0.0f64; BLOCK_SIZE_SQUARED];
99        let mut scores = vec![IndexAndScore::default(); num_blocks];
100
101        for by in 0..num_blocks_h {
102            for bx in 0..num_blocks_w {
103                // Compute gradient covariance matrix.
104                let mut gxx = 0f64;
105                let mut gxy = 0f64;
106                let mut gyy = 0f64;
107                let mut var = 0f64;
108                let mut mean = 0f64;
109
110                self.extract_block(
111                    plane,
112                    bx * BLOCK_SIZE,
113                    by * BLOCK_SIZE,
114                    &mut plane_result,
115                    &mut block_result,
116                );
117                for yi in 1..(BLOCK_SIZE - 1) {
118                    for xi in 1..(BLOCK_SIZE - 1) {
119                        // SAFETY: We know the size of `block_result` and that we cannot exceed the bounds of it
120                        unsafe {
121                            let result_ptr = block_result.as_ptr().add(yi * BLOCK_SIZE + xi);
122
123                            let gx = (*result_ptr.add(1) - *result_ptr.sub(1)) / 2f64;
124                            let gy =
125                                (*result_ptr.add(BLOCK_SIZE) - *result_ptr.sub(BLOCK_SIZE)) / 2f64;
126                            gxx += gx * gx;
127                            gxy += gx * gy;
128                            gyy += gy * gy;
129
130                            let block_val = *result_ptr;
131                            mean += block_val;
132                            var += block_val * block_val;
133                        }
134                    }
135                }
136                let block_size_norm_factor = (BLOCK_SIZE - 2).pow(2) as f64;
137                mean /= block_size_norm_factor;
138
139                // Normalize gradients by block_size.
140                gxx /= block_size_norm_factor;
141                gxy /= block_size_norm_factor;
142                gyy /= block_size_norm_factor;
143                var = mean.mul_add(-mean, var / block_size_norm_factor);
144
145                let trace = gxx + gyy;
146                let det = gxx.mul_add(gyy, -gxy.powi(2));
147                let e_sub = (trace.mul_add(trace, -4f64 * det)).max(0.).sqrt();
148                let e1 = (trace + e_sub) / 2.0f64;
149                let e2 = (trace - e_sub) / 2.0f64;
150                // Spectral norm
151                let norm = e1;
152                let ratio = e1 / e2.max(1.0e-6_f64);
153                let is_flat = trace < TRACE_THRESHOLD
154                    && ratio < RATIO_THRESHOLD
155                    && norm < NORM_THRESHOLD
156                    && var > VAR_THRESHOLD;
157
158                let sum_weights = NORM_WEIGHT.mul_add(
159                    norm,
160                    TRACE_WEIGHT.mul_add(
161                        trace,
162                        VAR_WEIGHT.mul_add(var, RATIO_WEIGHT.mul_add(ratio, OFFSET)),
163                    ),
164                );
165                // clamp the value to [-25.0, 100.0] to prevent overflow
166                let sum_weights = clamp(sum_weights, -25.0f64, 100.0f64);
167                let score = (1.0f64 / (1.0f64 + (-sum_weights).exp())) as f32;
168                // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it
169                unsafe {
170                    let index = by * num_blocks_w + bx;
171                    *flat_blocks.get_unchecked_mut(index) = if is_flat { 255 } else { 0 };
172                    *scores.get_unchecked_mut(index) = IndexAndScore {
173                        score: if var > VAR_THRESHOLD { score } else { 0f32 },
174                        index,
175                    };
176                }
177                if is_flat {
178                    num_flat += 1;
179                }
180            }
181        }
182
183        scores.sort_unstable_by(|a, b| a.score.partial_cmp(&b.score).expect("Shouldn't be NaN"));
184        // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it
185        unsafe {
186            let top_nth_percentile = num_blocks * 90 / 100;
187            let score_threshold = scores.get_unchecked(top_nth_percentile).score;
188            for score in &scores {
189                if score.score >= score_threshold {
190                    let block_ref = flat_blocks.get_unchecked_mut(score.index);
191                    if *block_ref == 0 {
192                        num_flat += 1;
193                    }
194                    *block_ref |= 1;
195                }
196            }
197        }
198
199        (flat_blocks, num_flat)
200    }
201
202    fn extract_block(
203        &self,
204        plane: &Plane<u8>,
205        offset_x: usize,
206        offset_y: usize,
207        plane_result: &mut [f64; BLOCK_SIZE_SQUARED],
208        block_result: &mut [f64; BLOCK_SIZE_SQUARED],
209    ) {
210        let mut plane_coords = [0f64; LOW_POLY_NUM_PARAMS];
211        let mut a_t_a_inv_b = [0f64; LOW_POLY_NUM_PARAMS];
212        let plane_origin = plane.data_origin();
213
214        for yi in 0..BLOCK_SIZE {
215            let y = clamp(offset_y + yi, 0, plane.cfg.height - 1);
216            for xi in 0..BLOCK_SIZE {
217                let x = clamp(offset_x + xi, 0, plane.cfg.width - 1);
218                // SAFETY: We know the bounds of the plane data and `block_result`
219                // and do not exceed them.
220                unsafe {
221                    *block_result.get_unchecked_mut(yi * BLOCK_SIZE + xi) =
222                        f64::from(*plane_origin.get_unchecked(y * plane.cfg.stride + x))
223                            / BLOCK_NORMALIZATION;
224                }
225            }
226        }
227
228        multiply_mat(
229            block_result,
230            &self.a,
231            &mut a_t_a_inv_b,
232            1,
233            BLOCK_SIZE_SQUARED,
234            LOW_POLY_NUM_PARAMS,
235        );
236        multiply_mat(
237            &self.a_t_a_inv,
238            &a_t_a_inv_b,
239            &mut plane_coords,
240            LOW_POLY_NUM_PARAMS,
241            LOW_POLY_NUM_PARAMS,
242            1,
243        );
244        multiply_mat(
245            &self.a,
246            &plane_coords,
247            plane_result,
248            BLOCK_SIZE_SQUARED,
249            LOW_POLY_NUM_PARAMS,
250            1,
251        );
252
253        for (block_res, plane_res) in block_result.iter_mut().zip(plane_result.iter()) {
254            *block_res -= *plane_res;
255        }
256    }
257}
258
259#[derive(Debug, Clone, Copy, Default)]
260struct IndexAndScore {
261    pub index: usize,
262    pub score: f32,
263}
264
265/// Wrapper of data required to represent linear system of eqns and soln.
266#[derive(Debug, Clone)]
267struct EquationSystem {
268    a: Vec<f64>,
269    b: Vec<f64>,
270    x: Vec<f64>,
271    n: usize,
272}
273
274impl EquationSystem {
275    #[must_use]
276    pub fn new(n: usize) -> Self {
277        Self {
278            a: vec![0.0f64; n * n],
279            b: vec![0.0f64; n],
280            x: vec![0.0f64; n],
281            n,
282        }
283    }
284
285    pub fn solve(&mut self) -> bool {
286        let n = self.n;
287        let mut a = self.a.clone();
288        let mut b = self.b.clone();
289
290        linsolve(n, &mut a, self.n, &mut b, &mut self.x)
291    }
292
293    pub fn set_chroma_coefficient_fallback_solution(&mut self) {
294        const TOLERANCE: f64 = 1.0e-6f64;
295        let last = self.n - 1;
296        // Set all of the AR coefficients to zero, but try to solve for correlation
297        // with the luma channel
298        self.x.fill(0f64);
299        if self.a[last * self.n + last].abs() > TOLERANCE {
300            self.x[last] = self.b[last] / self.a[last * self.n + last];
301        }
302    }
303
304    pub fn copy_from(&mut self, other: &Self) {
305        assert_eq!(self.n, other.n);
306        self.a.copy_from_slice(&other.a);
307        self.x.copy_from_slice(&other.x);
308        self.b.copy_from_slice(&other.b);
309    }
310
311    pub fn clear(&mut self) {
312        self.a.fill(0f64);
313        self.b.fill(0f64);
314        self.x.fill(0f64);
315    }
316}
317
318impl Add<&EquationSystem> for EquationSystem {
319    type Output = EquationSystem;
320
321    fn add(self, addend: &EquationSystem) -> Self::Output {
322        let mut dest = self.clone();
323        let n = self.n;
324        for i in 0..n {
325            for j in 0..n {
326                dest.a[i * n + j] += addend.a[i * n + j];
327            }
328            dest.b[i] += addend.b[i];
329        }
330        dest
331    }
332}
333
334impl AddAssign<&EquationSystem> for EquationSystem {
335    fn add_assign(&mut self, rhs: &EquationSystem) {
336        *self = self.clone() + rhs;
337    }
338}
339
340/// Representation of a piecewise linear curve
341///
342/// Holds n points as (x, y) pairs, that store the curve.
343struct NoiseStrengthLut {
344    points: Vec<[f64; 2]>,
345}
346
347impl NoiseStrengthLut {
348    #[must_use]
349    pub fn new(num_bins: usize) -> Self {
350        assert!(num_bins > 0);
351        Self {
352            points: vec![[0f64; 2]; num_bins],
353        }
354    }
355}
356
357#[derive(Debug, Clone)]
358pub(super) struct NoiseModel {
359    combined_state: [NoiseModelState; 3],
360    latest_state: [NoiseModelState; 3],
361    n: usize,
362    coords: Vec<[isize; 2]>,
363}
364
365impl NoiseModel {
366    #[must_use]
367    pub fn new() -> Self {
368        let n = Self::num_coeffs();
369        let combined_state = [
370            NoiseModelState::new(n),
371            NoiseModelState::new(n + 1),
372            NoiseModelState::new(n + 1),
373        ];
374        let latest_state = [
375            NoiseModelState::new(n),
376            NoiseModelState::new(n + 1),
377            NoiseModelState::new(n + 1),
378        ];
379        let mut coords = Vec::new();
380
381        let neg_lag = -(NOISE_MODEL_LAG as isize);
382        for y in neg_lag..=0 {
383            let max_x = if y == 0 {
384                -1isize
385            } else {
386                NOISE_MODEL_LAG as isize
387            };
388            for x in neg_lag..=max_x {
389                coords.push([x, y]);
390            }
391        }
392        assert!(n == coords.len());
393
394        Self {
395            combined_state,
396            latest_state,
397            n,
398            coords,
399        }
400    }
401
402    pub fn update(
403        &mut self,
404        source: &Frame<u8>,
405        denoised: &Frame<u8>,
406        flat_blocks: &[u8],
407    ) -> NoiseStatus {
408        let num_blocks_w = (source.planes[0].cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE;
409        let num_blocks_h = (source.planes[0].cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE;
410        let mut y_model_different = false;
411
412        // Clear the latest equation system
413        for i in 0..3 {
414            self.latest_state[i].eqns.clear();
415            self.latest_state[i].num_observations = 0;
416            self.latest_state[i].strength_solver.clear();
417        }
418
419        // Check that we have enough flat blocks
420        let num_blocks = flat_blocks.iter().filter(|b| **b > 0).count();
421        if num_blocks <= 1 {
422            return NoiseStatus::Error(anyhow!("Not enough flat blocks to update noise estimate"));
423        }
424
425        let frame_dims = (source.planes[0].cfg.width, source.planes[0].cfg.height);
426        for channel in 0..3 {
427            if source.planes[channel].data.is_empty() {
428                // Monochrome source
429                break;
430            }
431            let is_chroma = channel > 0;
432            let alt_source = (channel > 0).then(|| &source.planes[0]);
433            let alt_denoised = (channel > 0).then(|| &denoised.planes[0]);
434            self.add_block_observations(
435                channel,
436                &source.planes[channel],
437                &denoised.planes[channel],
438                alt_source,
439                alt_denoised,
440                frame_dims,
441                flat_blocks,
442                num_blocks_w,
443                num_blocks_h,
444            );
445            if !self.latest_state[channel].ar_equation_system_solve(is_chroma) {
446                if is_chroma {
447                    self.latest_state[channel]
448                        .eqns
449                        .set_chroma_coefficient_fallback_solution();
450                } else {
451                    return NoiseStatus::Error(anyhow!(
452                        "Solving latest noise equation system failed on plane {}",
453                        channel
454                    ));
455                }
456            }
457            self.add_noise_std_observations(
458                channel,
459                &source.planes[channel],
460                &denoised.planes[channel],
461                alt_source,
462                frame_dims,
463                flat_blocks,
464                num_blocks_w,
465                num_blocks_h,
466            );
467            if !self.latest_state[channel].strength_solver.solve() {
468                return NoiseStatus::Error(anyhow!(
469                    "Failed to solve strength solver for latest state"
470                ));
471            }
472
473            // Check noise characteristics and return if error
474            if channel == 0
475                && self.combined_state[channel].strength_solver.num_equations > 0
476                && self.is_different()
477            {
478                y_model_different = true;
479            }
480
481            if y_model_different {
482                continue;
483            }
484
485            self.combined_state[channel].num_observations +=
486                self.latest_state[channel].num_observations;
487            self.combined_state[channel].eqns += &self.latest_state[channel].eqns;
488            if !self.combined_state[channel].ar_equation_system_solve(is_chroma) {
489                if is_chroma {
490                    self.combined_state[channel]
491                        .eqns
492                        .set_chroma_coefficient_fallback_solution();
493                } else {
494                    return NoiseStatus::Error(anyhow!(
495                        "Solving combined noise equation system failed on plane {}",
496                        channel
497                    ));
498                }
499            }
500
501            self.combined_state[channel].strength_solver +=
502                &self.latest_state[channel].strength_solver;
503
504            if !self.combined_state[channel].strength_solver.solve() {
505                return NoiseStatus::Error(anyhow!(
506                    "Failed to solve strength solver for combined state"
507                ));
508            };
509        }
510
511        if y_model_different {
512            return NoiseStatus::DifferentType;
513        }
514
515        NoiseStatus::Ok
516    }
517
518    #[allow(clippy::too_many_lines)]
519    #[must_use]
520    pub fn get_grain_parameters(&self, start_ts: u64, end_ts: u64) -> GrainTableSegment {
521        // Both the domain and the range of the scaling functions in the film_grain
522        // are normalized to 8-bit (e.g., they are implicitly scaled during grain
523        // synthesis).
524        let scaling_points_y = self.combined_state[0]
525            .strength_solver
526            .fit_piecewise(NUM_Y_POINTS)
527            .points;
528        let scaling_points_cb = self.combined_state[1]
529            .strength_solver
530            .fit_piecewise(NUM_UV_POINTS)
531            .points;
532        let scaling_points_cr = self.combined_state[2]
533            .strength_solver
534            .fit_piecewise(NUM_UV_POINTS)
535            .points;
536
537        let mut max_scaling_value: f64 = 1.0e-4f64;
538        for p in scaling_points_y
539            .iter()
540            .chain(scaling_points_cb.iter())
541            .chain(scaling_points_cr.iter())
542            .map(|p| p[1])
543        {
544            if p > max_scaling_value {
545                max_scaling_value = p;
546            }
547        }
548
549        // Scaling_shift values are in the range [8,11]
550        let max_scaling_value_log2 =
551            clamp((max_scaling_value.log2() + 1f64).floor() as u8, 2u8, 5u8);
552        let scale_factor = f64::from(1u32 << (8u8 - max_scaling_value_log2));
553        let map_scaling_point = |p: [f64; 2]| {
554            [
555                (p[0] + 0.5f64) as u8,
556                clamp(scale_factor.mul_add(p[1], 0.5f64) as i32, 0i32, 255i32) as u8,
557            ]
558        };
559
560        let scaling_points_y: ArrayVec<_, NUM_Y_POINTS> = scaling_points_y
561            .into_iter()
562            .map(map_scaling_point)
563            .collect();
564        let scaling_points_cb: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cb
565            .into_iter()
566            .map(map_scaling_point)
567            .collect();
568        let scaling_points_cr: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cr
569            .into_iter()
570            .map(map_scaling_point)
571            .collect();
572
573        // Convert the ar_coeffs into 8-bit values
574        let n_coeff = self.combined_state[0].eqns.n;
575        let mut max_coeff = 1.0e-4f64;
576        let mut min_coeff = 1.0e-4f64;
577        let mut y_corr = [0f64; 2];
578        let mut avg_luma_strength = 0f64;
579        for c in 0..3 {
580            let eqns = &self.combined_state[c].eqns;
581            for i in 0..n_coeff {
582                if eqns.x[i] > max_coeff {
583                    max_coeff = eqns.x[i];
584                }
585                if eqns.x[i] < min_coeff {
586                    min_coeff = eqns.x[i];
587                }
588            }
589
590            // Since the correlation between luma/chroma was computed in an already
591            // scaled space, we adjust it in the un-scaled space.
592            let solver = &self.combined_state[c].strength_solver;
593            // Compute a weighted average of the strength for the channel.
594            let mut average_strength = 0f64;
595            let mut total_weight = 0f64;
596            for i in 0..solver.eqns.n {
597                let mut w = 0f64;
598                for j in 0..solver.eqns.n {
599                    w += solver.eqns.a[i * solver.eqns.n + j];
600                }
601                w = w.sqrt();
602                average_strength += solver.eqns.x[i] * w;
603                total_weight += w;
604            }
605            if total_weight.abs() < f64::EPSILON {
606                average_strength = 1f64;
607            } else {
608                average_strength /= total_weight;
609            }
610            if c == 0 {
611                avg_luma_strength = average_strength;
612            } else {
613                y_corr[c - 1] = avg_luma_strength * eqns.x[n_coeff] / average_strength;
614                max_coeff = max_coeff.max(y_corr[c - 1]);
615                min_coeff = min_coeff.min(y_corr[c - 1]);
616            }
617        }
618
619        // Shift value: AR coeffs range (values 6-9)
620        // 6: [-2, 2),  7: [-1, 1), 8: [-0.5, 0.5), 9: [-0.25, 0.25)
621        let ar_coeff_shift = clamp(
622            7i32 - (1.0f64 + max_coeff.log2().floor()).max((-min_coeff).log2().ceil()) as i32,
623            6i32,
624            9i32,
625        ) as u8;
626        let scale_ar_coeff = f64::from(1u16 << ar_coeff_shift);
627        let ar_coeffs_y = self.get_ar_coeffs_y(n_coeff, scale_ar_coeff);
628        let ar_coeffs_cb = self.get_ar_coeffs_uv(1, n_coeff, scale_ar_coeff, y_corr);
629        let ar_coeffs_cr = self.get_ar_coeffs_uv(2, n_coeff, scale_ar_coeff, y_corr);
630
631        GrainTableSegment {
632            random_seed: if start_ts == 0 { DEFAULT_GRAIN_SEED } else { 0 },
633            start_time: start_ts,
634            end_time: end_ts,
635            ar_coeff_lag: NOISE_MODEL_LAG as u8,
636            scaling_points_y,
637            scaling_points_cb,
638            scaling_points_cr,
639            scaling_shift: 5 + (8 - max_scaling_value_log2),
640            ar_coeff_shift,
641            ar_coeffs_y,
642            ar_coeffs_cb,
643            ar_coeffs_cr,
644            // At the moment, the noise modeling code assumes that the chroma scaling
645            // functions are a function of luma.
646            cb_mult: 128,
647            cb_luma_mult: 192,
648            cb_offset: 256,
649            cr_mult: 128,
650            cr_luma_mult: 192,
651            cr_offset: 256,
652            chroma_scaling_from_luma: false,
653            grain_scale_shift: 0,
654            overlap_flag: true,
655        }
656    }
657
658    pub fn save_latest(&mut self) {
659        for c in 0..3 {
660            let latest_state = &self.latest_state[c];
661            let combined_state = &mut self.combined_state[c];
662            combined_state.eqns.copy_from(&latest_state.eqns);
663            combined_state
664                .strength_solver
665                .eqns
666                .copy_from(&latest_state.strength_solver.eqns);
667            combined_state.strength_solver.num_equations =
668                latest_state.strength_solver.num_equations;
669            combined_state.num_observations = latest_state.num_observations;
670            combined_state.ar_gain = latest_state.ar_gain;
671        }
672    }
673
674    #[must_use]
675    const fn num_coeffs() -> usize {
676        let n = 2 * NOISE_MODEL_LAG + 1;
677        (n * n) / 2
678    }
679
680    #[must_use]
681    fn get_ar_coeffs_y(&self, n_coeff: usize, scale_ar_coeff: f64) -> ArrayVec<i8, NUM_Y_COEFFS> {
682        assert!(n_coeff <= NUM_Y_COEFFS);
683        let mut coeffs = ArrayVec::new();
684        let eqns = &self.combined_state[0].eqns;
685        for i in 0..n_coeff {
686            coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8);
687        }
688        coeffs
689    }
690
691    #[must_use]
692    fn get_ar_coeffs_uv(
693        &self,
694        channel: usize,
695        n_coeff: usize,
696        scale_ar_coeff: f64,
697        y_corr: [f64; 2],
698    ) -> ArrayVec<i8, NUM_UV_COEFFS> {
699        assert!(n_coeff <= NUM_Y_COEFFS);
700        let mut coeffs = ArrayVec::new();
701        let eqns = &self.combined_state[channel].eqns;
702        for i in 0..n_coeff {
703            coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8);
704        }
705        coeffs.push(clamp(
706            (scale_ar_coeff * y_corr[channel - 1]).round() as i32,
707            -128i32,
708            127i32,
709        ) as i8);
710        coeffs
711    }
712
713    // Return true if the noise estimate appears to be different from the combined
714    // (multi-frame) estimate. The difference is measured by checking whether the
715    // AR coefficients have diverged (using a threshold on normalized cross
716    // correlation), or whether the noise strength has changed.
717    #[must_use]
718    fn is_different(&self) -> bool {
719        const COEFF_THRESHOLD: f64 = 0.9f64;
720        const STRENGTH_THRESHOLD: f64 = 0.005f64;
721
722        let latest = &self.latest_state[0];
723        let combined = &self.combined_state[0];
724        let corr = normalized_cross_correlation(&latest.eqns.x, &combined.eqns.x, combined.eqns.n);
725        if corr < COEFF_THRESHOLD {
726            return true;
727        }
728
729        let dx = 1.0f64 / latest.strength_solver.num_bins as f64;
730        let latest_eqns = &latest.strength_solver.eqns;
731        let combined_eqns = &combined.strength_solver.eqns;
732        let mut diff = 0.0f64;
733        let mut total_weight = 0.0f64;
734        for j in 0..latest_eqns.n {
735            let mut weight = 0.0f64;
736            for i in 0..latest_eqns.n {
737                weight += latest_eqns.a[i * latest_eqns.n + j];
738            }
739            weight = weight.sqrt();
740            diff += weight * (latest_eqns.x[j] - combined_eqns.x[j]).abs();
741            total_weight += weight;
742        }
743
744        diff * dx / total_weight > STRENGTH_THRESHOLD
745    }
746
747    #[allow(clippy::too_many_arguments)]
748    fn add_block_observations(
749        &mut self,
750        channel: usize,
751        source: &Plane<u8>,
752        denoised: &Plane<u8>,
753        alt_source: Option<&Plane<u8>>,
754        alt_denoised: Option<&Plane<u8>>,
755        frame_dims: (usize, usize),
756        flat_blocks: &[u8],
757        num_blocks_w: usize,
758        num_blocks_h: usize,
759    ) {
760        let num_coords = self.n;
761        let state = &mut self.latest_state[channel];
762        let a = &mut state.eqns.a;
763        let b = &mut state.eqns.b;
764        let mut buffer = vec![0f64; num_coords + 1].into_boxed_slice();
765        let n = state.eqns.n;
766        let block_w = BLOCK_SIZE >> source.cfg.xdec;
767        let block_h = BLOCK_SIZE >> source.cfg.ydec;
768
769        let dec = (source.cfg.xdec, source.cfg.ydec);
770        let stride = source.cfg.stride;
771        let source_origin = source.data_origin();
772        let denoised_origin = denoised.data_origin();
773        let alt_stride = alt_source.map_or(0, |s| s.cfg.stride);
774        let alt_source_origin = alt_source.map(|s| s.data_origin());
775        let alt_denoised_origin = alt_denoised.map(|s| s.data_origin());
776
777        for by in 0..num_blocks_h {
778            let y_o = by * block_h;
779            for bx in 0..num_blocks_w {
780                // SAFETY: We know the indexes we provide do not overflow the data bounds
781                unsafe {
782                    let flat_block_ptr = flat_blocks.as_ptr().add(by * num_blocks_w + bx);
783                    let x_o = bx * block_w;
784                    if *flat_block_ptr == 0 {
785                        continue;
786                    }
787                    let y_start = if by > 0 && *flat_block_ptr.sub(num_blocks_w) > 0 {
788                        0
789                    } else {
790                        NOISE_MODEL_LAG
791                    };
792                    let x_start = if bx > 0 && *flat_block_ptr.sub(1) > 0 {
793                        0
794                    } else {
795                        NOISE_MODEL_LAG
796                    };
797                    let y_end = ((frame_dims.1 >> dec.1) - by * block_h).min(block_h);
798                    let x_end = ((frame_dims.0 >> dec.0) - bx * block_w - NOISE_MODEL_LAG).min(
799                        if bx + 1 < num_blocks_w && *flat_block_ptr.add(1) > 0 {
800                            block_w
801                        } else {
802                            block_w - NOISE_MODEL_LAG
803                        },
804                    );
805                    for y in y_start..y_end {
806                        for x in x_start..x_end {
807                            let val = extract_ar_row(
808                                &self.coords,
809                                num_coords,
810                                source_origin,
811                                denoised_origin,
812                                stride,
813                                dec,
814                                alt_source_origin,
815                                alt_denoised_origin,
816                                alt_stride,
817                                x + x_o,
818                                y + y_o,
819                                &mut buffer,
820                            );
821                            for i in 0..n {
822                                for j in 0..n {
823                                    *a.get_unchecked_mut(i * n + j) += (*buffer.get_unchecked(i)
824                                        * *buffer.get_unchecked(j))
825                                        / BLOCK_NORMALIZATION.powi(2);
826                                }
827                                *b.get_unchecked_mut(i) +=
828                                    (*buffer.get_unchecked(i) * val) / BLOCK_NORMALIZATION.powi(2);
829                            }
830                            state.num_observations += 1;
831                        }
832                    }
833                }
834            }
835        }
836    }
837
838    #[allow(clippy::too_many_arguments)]
839    fn add_noise_std_observations(
840        &mut self,
841        channel: usize,
842        source: &Plane<u8>,
843        denoised: &Plane<u8>,
844        alt_source: Option<&Plane<u8>>,
845        frame_dims: (usize, usize),
846        flat_blocks: &[u8],
847        num_blocks_w: usize,
848        num_blocks_h: usize,
849    ) {
850        let coeffs = &self.latest_state[channel].eqns.x;
851        let num_coords = self.n;
852        let luma_gain = self.latest_state[0].ar_gain;
853        let noise_gain = self.latest_state[channel].ar_gain;
854        let block_w = BLOCK_SIZE >> source.cfg.xdec;
855        let block_h = BLOCK_SIZE >> source.cfg.ydec;
856
857        for by in 0..num_blocks_h {
858            let y_o = by * block_h;
859            for bx in 0..num_blocks_w {
860                let x_o = bx * block_w;
861                if flat_blocks[by * num_blocks_w + bx] == 0 {
862                    continue;
863                }
864                let num_samples_h = ((frame_dims.1 >> source.cfg.ydec) - by * block_h).min(block_h);
865                let num_samples_w = ((frame_dims.0 >> source.cfg.xdec) - bx * block_w).min(block_w);
866                // Make sure that we have a reasonable amount of samples to consider the
867                // block
868                if num_samples_w * num_samples_h > BLOCK_SIZE {
869                    let block_mean = get_block_mean(
870                        alt_source.unwrap_or(source),
871                        frame_dims,
872                        x_o << source.cfg.xdec,
873                        y_o << source.cfg.ydec,
874                    );
875                    let noise_var = get_noise_var(
876                        source,
877                        denoised,
878                        (
879                            frame_dims.0 >> source.cfg.xdec,
880                            frame_dims.1 >> source.cfg.ydec,
881                        ),
882                        x_o,
883                        y_o,
884                        block_w,
885                        block_h,
886                    );
887                    // We want to remove the part of the noise that came from being
888                    // correlated with luma. Note that the noise solver for luma must
889                    // have already been run.
890                    let luma_strength = if channel > 0 {
891                        luma_gain * self.latest_state[0].strength_solver.get_value(block_mean)
892                    } else {
893                        0f64
894                    };
895                    let corr = if channel > 0 {
896                        coeffs[num_coords]
897                    } else {
898                        0f64
899                    };
900                    // Chroma noise:
901                    //    N(0, noise_var) = N(0, uncorr_var) + corr * N(0, luma_strength^2)
902                    // The uncorrelated component:
903                    //   uncorr_var = noise_var - (corr * luma_strength)^2
904                    // But don't allow fully correlated noise (hence the max), since the
905                    // synthesis cannot model it.
906                    let uncorr_std = (noise_var / 16f64)
907                        .max((corr * luma_strength).mul_add(-(corr * luma_strength), noise_var))
908                        .sqrt();
909                    let adjusted_strength = uncorr_std / noise_gain;
910                    self.latest_state[channel]
911                        .strength_solver
912                        .add_measurement(block_mean, adjusted_strength);
913                }
914            }
915        }
916    }
917}
918
919#[derive(Debug, Clone)]
920struct NoiseModelState {
921    eqns: EquationSystem,
922    ar_gain: f64,
923    num_observations: usize,
924    strength_solver: StrengthSolver,
925}
926
927impl NoiseModelState {
928    #[must_use]
929    pub fn new(n: usize) -> Self {
930        const NUM_BINS: usize = 20;
931
932        Self {
933            eqns: EquationSystem::new(n),
934            ar_gain: 1.0f64,
935            num_observations: 0usize,
936            strength_solver: StrengthSolver::new(NUM_BINS),
937        }
938    }
939
940    pub fn ar_equation_system_solve(&mut self, is_chroma: bool) -> bool {
941        let ret = self.eqns.solve();
942        self.ar_gain = 1.0f64;
943        if !ret {
944            return ret;
945        }
946
947        // Update the AR gain from the equation system as it will be used to fit
948        // the noise strength as a function of intensity.  In the Yule-Walker
949        // equations, the diagonal should be the variance of the correlated noise.
950        // In the case of the least squares estimate, there will be some variability
951        // in the diagonal. So use the mean of the diagonal as the estimate of
952        // overall variance (this works for least squares or Yule-Walker formulation).
953        let mut var = 0f64;
954        let n_adjusted = self.eqns.n - usize::from(is_chroma);
955        for i in 0..n_adjusted {
956            var += self.eqns.a[i * self.eqns.n + i] / self.num_observations as f64;
957        }
958        var /= n_adjusted as f64;
959
960        // Keep track of E(Y^2) = <b, x> + E(X^2)
961        // In the case that we are using chroma and have an estimate of correlation
962        // with luma we adjust that estimate slightly to remove the correlated bits by
963        // subtracting out the last column of a scaled by our correlation estimate
964        // from b. E(y^2) = <b - A(:, end)*x(end), x>
965        let mut sum_covar = 0f64;
966        for i in 0..n_adjusted {
967            let mut bi = self.eqns.b[i];
968            if is_chroma {
969                bi -= self.eqns.a[i * self.eqns.n + n_adjusted] * self.eqns.x[n_adjusted];
970            }
971            sum_covar += (bi * self.eqns.x[i]) / self.num_observations as f64;
972        }
973
974        // Now, get an estimate of the variance of uncorrelated noise signal and use
975        // it to determine the gain of the AR filter.
976        let noise_var = (var - sum_covar).max(1e-6f64);
977        self.ar_gain = 1f64.max((var / noise_var).max(1e-6f64).sqrt());
978        ret
979    }
980}
981
982#[derive(Debug, Clone)]
983struct StrengthSolver {
984    eqns: EquationSystem,
985    num_bins: usize,
986    num_equations: usize,
987    total: f64,
988}
989
990impl StrengthSolver {
991    #[must_use]
992    pub fn new(num_bins: usize) -> Self {
993        Self {
994            eqns: EquationSystem::new(num_bins),
995            num_bins,
996            num_equations: 0usize,
997            total: 0f64,
998        }
999    }
1000
1001    pub fn add_measurement(&mut self, block_mean: f64, noise_std: f64) {
1002        let bin = self.get_bin_index(block_mean);
1003        let bin_i0 = bin.floor() as usize;
1004        let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
1005        let a = bin - bin_i0 as f64;
1006        let n = self.num_bins;
1007        let eqns = &mut self.eqns;
1008        eqns.a[bin_i0 * n + bin_i0] += (1f64 - a).powi(2);
1009        eqns.a[bin_i1 * n + bin_i0] += a * (1f64 - a);
1010        eqns.a[bin_i1 * n + bin_i1] += a.powi(2);
1011        eqns.a[bin_i0 * n + bin_i1] += (1f64 - a) * a;
1012        eqns.b[bin_i0] += (1f64 - a) * noise_std;
1013        eqns.b[bin_i1] += a * noise_std;
1014        self.total += noise_std;
1015        self.num_equations += 1;
1016    }
1017
1018    pub fn solve(&mut self) -> bool {
1019        // Add regularization proportional to the number of constraints
1020        let n = self.num_bins;
1021        let alpha = 2f64 * self.num_equations as f64 / n as f64;
1022
1023        // Do this in a non-destructive manner so it is not confusing to the caller
1024        let old_a = self.eqns.a.clone();
1025        for i in 0..n {
1026            let i_lo = i.saturating_sub(1);
1027            let i_hi = (n - 1).min(i + 1);
1028            self.eqns.a[i * n + i_lo] -= alpha;
1029            self.eqns.a[i * n + i] += 2f64 * alpha;
1030            self.eqns.a[i * n + i_hi] -= alpha;
1031        }
1032
1033        // Small regularization to give average noise strength
1034        let mean = self.total / self.num_equations as f64;
1035        for i in 0..n {
1036            self.eqns.a[i * n + i] += 1f64 / 8192f64;
1037            self.eqns.b[i] += mean / 8192f64;
1038        }
1039        let result = self.eqns.solve();
1040        self.eqns.a = old_a;
1041        result
1042    }
1043
1044    #[must_use]
1045    pub fn fit_piecewise(&self, max_output_points: usize) -> NoiseStrengthLut {
1046        const TOLERANCE: f64 = 0.00625f64;
1047
1048        let mut lut = NoiseStrengthLut::new(self.num_bins);
1049        for i in 0..self.num_bins {
1050            lut.points[i][0] = self.get_center(i);
1051            lut.points[i][1] = self.eqns.x[i];
1052        }
1053
1054        let mut residual = vec![0.0f64; self.num_bins];
1055        self.update_piecewise_linear_residual(&lut, &mut residual, 0, self.num_bins);
1056
1057        // Greedily remove points if there are too many or if it doesn't hurt local
1058        // approximation (never remove the end points)
1059        while lut.points.len() > 2 {
1060            let mut min_index = 1usize;
1061            for j in 1..(lut.points.len() - 1) {
1062                if residual[j] < residual[min_index] {
1063                    min_index = j;
1064                }
1065            }
1066            let dx = lut.points[min_index + 1][0] - lut.points[min_index - 1][0];
1067            let avg_residual = residual[min_index] / dx;
1068            if lut.points.len() <= max_output_points && avg_residual > TOLERANCE {
1069                break;
1070            }
1071
1072            lut.points.remove(min_index);
1073            self.update_piecewise_linear_residual(
1074                &lut,
1075                &mut residual,
1076                min_index - 1,
1077                min_index + 1,
1078            );
1079        }
1080
1081        lut
1082    }
1083
1084    #[must_use]
1085    pub fn get_value(&self, x: f64) -> f64 {
1086        let bin = self.get_bin_index(x);
1087        let bin_i0 = bin.floor() as usize;
1088        let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
1089        let a = bin - bin_i0 as f64;
1090        (1f64 - a).mul_add(self.eqns.x[bin_i0], a * self.eqns.x[bin_i1])
1091    }
1092
1093    pub fn clear(&mut self) {
1094        self.eqns.clear();
1095        self.num_equations = 0;
1096        self.total = 0f64;
1097    }
1098
1099    #[must_use]
1100    fn get_bin_index(&self, value: f64) -> f64 {
1101        let max = 255f64;
1102        let val = clamp(value, 0f64, max);
1103        (self.num_bins - 1) as f64 * val / max
1104    }
1105
1106    fn update_piecewise_linear_residual(
1107        &self,
1108        lut: &NoiseStrengthLut,
1109        residual: &mut [f64],
1110        start: usize,
1111        end: usize,
1112    ) {
1113        let dx = 255f64 / self.num_bins as f64;
1114        #[allow(clippy::needless_range_loop)]
1115        for i in start.max(1)..end.min(lut.points.len() - 1) {
1116            let lower = 0usize.max(self.get_bin_index(lut.points[i - 1][0]).floor() as usize);
1117            let upper =
1118                (self.num_bins - 1).min(self.get_bin_index(lut.points[i + 1][0]).ceil() as usize);
1119            let mut r = 0f64;
1120            for j in lower..=upper {
1121                let x = self.get_center(j);
1122                if x < lut.points[i - 1][0] || x >= lut.points[i + 1][0] {
1123                    continue;
1124                }
1125
1126                let y = self.eqns.x[j];
1127                let a = (x - lut.points[i - 1][0]) / (lut.points[i + 1][0] - lut.points[i - 1][0]);
1128                let estimate_y = lut.points[i - 1][1].mul_add(1f64 - a, lut.points[i + 1][1] * a);
1129                r += (y - estimate_y).abs();
1130            }
1131            residual[i] = r * dx;
1132        }
1133    }
1134
1135    #[must_use]
1136    fn get_center(&self, i: usize) -> f64 {
1137        let range = 255f64;
1138        let n = self.num_bins;
1139        i as f64 / (n - 1) as f64 * range
1140    }
1141}
1142
1143impl Add<&StrengthSolver> for StrengthSolver {
1144    type Output = StrengthSolver;
1145
1146    fn add(self, addend: &StrengthSolver) -> Self::Output {
1147        let mut dest = self;
1148        dest.eqns += &addend.eqns;
1149        dest.num_equations += addend.num_equations;
1150        dest.total += addend.total;
1151        dest
1152    }
1153}
1154
1155impl AddAssign<&StrengthSolver> for StrengthSolver {
1156    fn add_assign(&mut self, rhs: &StrengthSolver) {
1157        *self = self.clone() + rhs;
1158    }
1159}