1mod util;
2
3use std::ops::{Add, AddAssign};
4
5use anyhow::anyhow;
6use arrayvec::ArrayVec;
7use v_frame::{frame::Frame, math::clamp, plane::Plane};
8
9use self::util::{extract_ar_row, get_block_mean, get_noise_var, linsolve, multiply_mat};
10use super::{NoiseStatus, BLOCK_SIZE, BLOCK_SIZE_SQUARED};
11use crate::{
12 diff::solver::util::normalized_cross_correlation, GrainTableSegment, DEFAULT_GRAIN_SEED,
13 NUM_UV_COEFFS, NUM_UV_POINTS, NUM_Y_COEFFS, NUM_Y_POINTS,
14};
15
16const LOW_POLY_NUM_PARAMS: usize = 3;
17const NOISE_MODEL_LAG: usize = 3;
18const BLOCK_NORMALIZATION: f64 = 255.0f64;
19
20#[derive(Debug, Clone)]
21pub(super) struct FlatBlockFinder {
22 a: Box<[f64]>,
23 a_t_a_inv: [f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS],
24}
25
26impl FlatBlockFinder {
27 #[must_use]
28 pub fn new() -> Self {
29 let mut eqns = EquationSystem::new(LOW_POLY_NUM_PARAMS);
30 let mut a_t_a_inv = [0.0f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS];
31 let mut a = vec![0.0f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED];
32
33 let bs_half = (BLOCK_SIZE / 2) as f64;
34 (0..BLOCK_SIZE).for_each(|y| {
35 let yd = (y as f64 - bs_half) / bs_half;
36 (0..BLOCK_SIZE).for_each(|x| {
37 let xd = (x as f64 - bs_half) / bs_half;
38 let coords = [yd, xd, 1.0f64];
39 let row = y * BLOCK_SIZE + x;
40 a[LOW_POLY_NUM_PARAMS * row] = yd;
41 a[LOW_POLY_NUM_PARAMS * row + 1] = xd;
42 a[LOW_POLY_NUM_PARAMS * row + 2] = 1.0f64;
43
44 (0..LOW_POLY_NUM_PARAMS).for_each(|i| {
45 (0..LOW_POLY_NUM_PARAMS).for_each(|j| {
46 eqns.a[LOW_POLY_NUM_PARAMS * i + j] += coords[i] * coords[j];
47 });
48 });
49 });
50 });
51
52 (0..LOW_POLY_NUM_PARAMS).for_each(|i| {
54 eqns.b.fill(0.0f64);
55 eqns.b[i] = 1.0f64;
56 eqns.solve();
57
58 (0..LOW_POLY_NUM_PARAMS).for_each(|j| {
59 a_t_a_inv[j * LOW_POLY_NUM_PARAMS + i] = eqns.x[j];
60 });
61 });
62
63 FlatBlockFinder {
64 a: a.into_boxed_slice(),
65 a_t_a_inv,
66 }
67 }
68
69 #[must_use]
75 #[allow(clippy::too_many_lines)]
76 pub fn run(&self, plane: &Plane<u8>) -> (Vec<u8>, usize) {
77 const TRACE_THRESHOLD: f64 = 0.15f64 / BLOCK_SIZE_SQUARED as f64;
78 const RATIO_THRESHOLD: f64 = 1.25f64;
79 const NORM_THRESHOLD: f64 = 0.08f64 / BLOCK_SIZE_SQUARED as f64;
80 const VAR_THRESHOLD: f64 = 0.005f64 / BLOCK_SIZE_SQUARED as f64;
81
82 const VAR_WEIGHT: f64 = -6682f64;
87 const RATIO_WEIGHT: f64 = -0.2056f64;
88 const TRACE_WEIGHT: f64 = 13087f64;
89 const NORM_WEIGHT: f64 = -12434f64;
90 const OFFSET: f64 = 2.5694f64;
91
92 let num_blocks_w = (plane.cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE;
93 let num_blocks_h = (plane.cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE;
94 let num_blocks = num_blocks_w * num_blocks_h;
95 let mut flat_blocks = vec![0u8; num_blocks];
96 let mut num_flat = 0;
97 let mut plane_result = [0.0f64; BLOCK_SIZE_SQUARED];
98 let mut block_result = [0.0f64; BLOCK_SIZE_SQUARED];
99 let mut scores = vec![IndexAndScore::default(); num_blocks];
100
101 for by in 0..num_blocks_h {
102 for bx in 0..num_blocks_w {
103 let mut gxx = 0f64;
105 let mut gxy = 0f64;
106 let mut gyy = 0f64;
107 let mut var = 0f64;
108 let mut mean = 0f64;
109
110 self.extract_block(
111 plane,
112 bx * BLOCK_SIZE,
113 by * BLOCK_SIZE,
114 &mut plane_result,
115 &mut block_result,
116 );
117 for yi in 1..(BLOCK_SIZE - 1) {
118 for xi in 1..(BLOCK_SIZE - 1) {
119 unsafe {
121 let result_ptr = block_result.as_ptr().add(yi * BLOCK_SIZE + xi);
122
123 let gx = (*result_ptr.add(1) - *result_ptr.sub(1)) / 2f64;
124 let gy =
125 (*result_ptr.add(BLOCK_SIZE) - *result_ptr.sub(BLOCK_SIZE)) / 2f64;
126 gxx += gx * gx;
127 gxy += gx * gy;
128 gyy += gy * gy;
129
130 let block_val = *result_ptr;
131 mean += block_val;
132 var += block_val * block_val;
133 }
134 }
135 }
136 let block_size_norm_factor = (BLOCK_SIZE - 2).pow(2) as f64;
137 mean /= block_size_norm_factor;
138
139 gxx /= block_size_norm_factor;
141 gxy /= block_size_norm_factor;
142 gyy /= block_size_norm_factor;
143 var = mean.mul_add(-mean, var / block_size_norm_factor);
144
145 let trace = gxx + gyy;
146 let det = gxx.mul_add(gyy, -gxy.powi(2));
147 let e_sub = (trace.mul_add(trace, -4f64 * det)).max(0.).sqrt();
148 let e1 = (trace + e_sub) / 2.0f64;
149 let e2 = (trace - e_sub) / 2.0f64;
150 let norm = e1;
152 let ratio = e1 / e2.max(1.0e-6_f64);
153 let is_flat = trace < TRACE_THRESHOLD
154 && ratio < RATIO_THRESHOLD
155 && norm < NORM_THRESHOLD
156 && var > VAR_THRESHOLD;
157
158 let sum_weights = NORM_WEIGHT.mul_add(
159 norm,
160 TRACE_WEIGHT.mul_add(
161 trace,
162 VAR_WEIGHT.mul_add(var, RATIO_WEIGHT.mul_add(ratio, OFFSET)),
163 ),
164 );
165 let sum_weights = clamp(sum_weights, -25.0f64, 100.0f64);
167 let score = (1.0f64 / (1.0f64 + (-sum_weights).exp())) as f32;
168 unsafe {
170 let index = by * num_blocks_w + bx;
171 *flat_blocks.get_unchecked_mut(index) = if is_flat { 255 } else { 0 };
172 *scores.get_unchecked_mut(index) = IndexAndScore {
173 score: if var > VAR_THRESHOLD { score } else { 0f32 },
174 index,
175 };
176 }
177 if is_flat {
178 num_flat += 1;
179 }
180 }
181 }
182
183 scores.sort_unstable_by(|a, b| a.score.partial_cmp(&b.score).expect("Shouldn't be NaN"));
184 unsafe {
186 let top_nth_percentile = num_blocks * 90 / 100;
187 let score_threshold = scores.get_unchecked(top_nth_percentile).score;
188 for score in &scores {
189 if score.score >= score_threshold {
190 let block_ref = flat_blocks.get_unchecked_mut(score.index);
191 if *block_ref == 0 {
192 num_flat += 1;
193 }
194 *block_ref |= 1;
195 }
196 }
197 }
198
199 (flat_blocks, num_flat)
200 }
201
202 fn extract_block(
203 &self,
204 plane: &Plane<u8>,
205 offset_x: usize,
206 offset_y: usize,
207 plane_result: &mut [f64; BLOCK_SIZE_SQUARED],
208 block_result: &mut [f64; BLOCK_SIZE_SQUARED],
209 ) {
210 let mut plane_coords = [0f64; LOW_POLY_NUM_PARAMS];
211 let mut a_t_a_inv_b = [0f64; LOW_POLY_NUM_PARAMS];
212 let plane_origin = plane.data_origin();
213
214 for yi in 0..BLOCK_SIZE {
215 let y = clamp(offset_y + yi, 0, plane.cfg.height - 1);
216 for xi in 0..BLOCK_SIZE {
217 let x = clamp(offset_x + xi, 0, plane.cfg.width - 1);
218 unsafe {
221 *block_result.get_unchecked_mut(yi * BLOCK_SIZE + xi) =
222 f64::from(*plane_origin.get_unchecked(y * plane.cfg.stride + x))
223 / BLOCK_NORMALIZATION;
224 }
225 }
226 }
227
228 multiply_mat(
229 block_result,
230 &self.a,
231 &mut a_t_a_inv_b,
232 1,
233 BLOCK_SIZE_SQUARED,
234 LOW_POLY_NUM_PARAMS,
235 );
236 multiply_mat(
237 &self.a_t_a_inv,
238 &a_t_a_inv_b,
239 &mut plane_coords,
240 LOW_POLY_NUM_PARAMS,
241 LOW_POLY_NUM_PARAMS,
242 1,
243 );
244 multiply_mat(
245 &self.a,
246 &plane_coords,
247 plane_result,
248 BLOCK_SIZE_SQUARED,
249 LOW_POLY_NUM_PARAMS,
250 1,
251 );
252
253 for (block_res, plane_res) in block_result.iter_mut().zip(plane_result.iter()) {
254 *block_res -= *plane_res;
255 }
256 }
257}
258
259#[derive(Debug, Clone, Copy, Default)]
260struct IndexAndScore {
261 pub index: usize,
262 pub score: f32,
263}
264
265#[derive(Debug, Clone)]
267struct EquationSystem {
268 a: Vec<f64>,
269 b: Vec<f64>,
270 x: Vec<f64>,
271 n: usize,
272}
273
274impl EquationSystem {
275 #[must_use]
276 pub fn new(n: usize) -> Self {
277 Self {
278 a: vec![0.0f64; n * n],
279 b: vec![0.0f64; n],
280 x: vec![0.0f64; n],
281 n,
282 }
283 }
284
285 pub fn solve(&mut self) -> bool {
286 let n = self.n;
287 let mut a = self.a.clone();
288 let mut b = self.b.clone();
289
290 linsolve(n, &mut a, self.n, &mut b, &mut self.x)
291 }
292
293 pub fn set_chroma_coefficient_fallback_solution(&mut self) {
294 const TOLERANCE: f64 = 1.0e-6f64;
295 let last = self.n - 1;
296 self.x.fill(0f64);
299 if self.a[last * self.n + last].abs() > TOLERANCE {
300 self.x[last] = self.b[last] / self.a[last * self.n + last];
301 }
302 }
303
304 pub fn copy_from(&mut self, other: &Self) {
305 assert_eq!(self.n, other.n);
306 self.a.copy_from_slice(&other.a);
307 self.x.copy_from_slice(&other.x);
308 self.b.copy_from_slice(&other.b);
309 }
310
311 pub fn clear(&mut self) {
312 self.a.fill(0f64);
313 self.b.fill(0f64);
314 self.x.fill(0f64);
315 }
316}
317
318impl Add<&EquationSystem> for EquationSystem {
319 type Output = EquationSystem;
320
321 fn add(self, addend: &EquationSystem) -> Self::Output {
322 let mut dest = self.clone();
323 let n = self.n;
324 for i in 0..n {
325 for j in 0..n {
326 dest.a[i * n + j] += addend.a[i * n + j];
327 }
328 dest.b[i] += addend.b[i];
329 }
330 dest
331 }
332}
333
334impl AddAssign<&EquationSystem> for EquationSystem {
335 fn add_assign(&mut self, rhs: &EquationSystem) {
336 *self = self.clone() + rhs;
337 }
338}
339
340struct NoiseStrengthLut {
344 points: Vec<[f64; 2]>,
345}
346
347impl NoiseStrengthLut {
348 #[must_use]
349 pub fn new(num_bins: usize) -> Self {
350 assert!(num_bins > 0);
351 Self {
352 points: vec![[0f64; 2]; num_bins],
353 }
354 }
355}
356
357#[derive(Debug, Clone)]
358pub(super) struct NoiseModel {
359 combined_state: [NoiseModelState; 3],
360 latest_state: [NoiseModelState; 3],
361 n: usize,
362 coords: Vec<[isize; 2]>,
363}
364
365impl NoiseModel {
366 #[must_use]
367 pub fn new() -> Self {
368 let n = Self::num_coeffs();
369 let combined_state = [
370 NoiseModelState::new(n),
371 NoiseModelState::new(n + 1),
372 NoiseModelState::new(n + 1),
373 ];
374 let latest_state = [
375 NoiseModelState::new(n),
376 NoiseModelState::new(n + 1),
377 NoiseModelState::new(n + 1),
378 ];
379 let mut coords = Vec::new();
380
381 let neg_lag = -(NOISE_MODEL_LAG as isize);
382 for y in neg_lag..=0 {
383 let max_x = if y == 0 {
384 -1isize
385 } else {
386 NOISE_MODEL_LAG as isize
387 };
388 for x in neg_lag..=max_x {
389 coords.push([x, y]);
390 }
391 }
392 assert!(n == coords.len());
393
394 Self {
395 combined_state,
396 latest_state,
397 n,
398 coords,
399 }
400 }
401
402 pub fn update(
403 &mut self,
404 source: &Frame<u8>,
405 denoised: &Frame<u8>,
406 flat_blocks: &[u8],
407 ) -> NoiseStatus {
408 let num_blocks_w = (source.planes[0].cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE;
409 let num_blocks_h = (source.planes[0].cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE;
410 let mut y_model_different = false;
411
412 for i in 0..3 {
414 self.latest_state[i].eqns.clear();
415 self.latest_state[i].num_observations = 0;
416 self.latest_state[i].strength_solver.clear();
417 }
418
419 let num_blocks = flat_blocks.iter().filter(|b| **b > 0).count();
421 if num_blocks <= 1 {
422 return NoiseStatus::Error(anyhow!("Not enough flat blocks to update noise estimate"));
423 }
424
425 let frame_dims = (source.planes[0].cfg.width, source.planes[0].cfg.height);
426 for channel in 0..3 {
427 if source.planes[channel].data.is_empty() {
428 break;
430 }
431 let is_chroma = channel > 0;
432 let alt_source = (channel > 0).then(|| &source.planes[0]);
433 let alt_denoised = (channel > 0).then(|| &denoised.planes[0]);
434 self.add_block_observations(
435 channel,
436 &source.planes[channel],
437 &denoised.planes[channel],
438 alt_source,
439 alt_denoised,
440 frame_dims,
441 flat_blocks,
442 num_blocks_w,
443 num_blocks_h,
444 );
445 if !self.latest_state[channel].ar_equation_system_solve(is_chroma) {
446 if is_chroma {
447 self.latest_state[channel]
448 .eqns
449 .set_chroma_coefficient_fallback_solution();
450 } else {
451 return NoiseStatus::Error(anyhow!(
452 "Solving latest noise equation system failed on plane {}",
453 channel
454 ));
455 }
456 }
457 self.add_noise_std_observations(
458 channel,
459 &source.planes[channel],
460 &denoised.planes[channel],
461 alt_source,
462 frame_dims,
463 flat_blocks,
464 num_blocks_w,
465 num_blocks_h,
466 );
467 if !self.latest_state[channel].strength_solver.solve() {
468 return NoiseStatus::Error(anyhow!(
469 "Failed to solve strength solver for latest state"
470 ));
471 }
472
473 if channel == 0
475 && self.combined_state[channel].strength_solver.num_equations > 0
476 && self.is_different()
477 {
478 y_model_different = true;
479 }
480
481 if y_model_different {
482 continue;
483 }
484
485 self.combined_state[channel].num_observations +=
486 self.latest_state[channel].num_observations;
487 self.combined_state[channel].eqns += &self.latest_state[channel].eqns;
488 if !self.combined_state[channel].ar_equation_system_solve(is_chroma) {
489 if is_chroma {
490 self.combined_state[channel]
491 .eqns
492 .set_chroma_coefficient_fallback_solution();
493 } else {
494 return NoiseStatus::Error(anyhow!(
495 "Solving combined noise equation system failed on plane {}",
496 channel
497 ));
498 }
499 }
500
501 self.combined_state[channel].strength_solver +=
502 &self.latest_state[channel].strength_solver;
503
504 if !self.combined_state[channel].strength_solver.solve() {
505 return NoiseStatus::Error(anyhow!(
506 "Failed to solve strength solver for combined state"
507 ));
508 };
509 }
510
511 if y_model_different {
512 return NoiseStatus::DifferentType;
513 }
514
515 NoiseStatus::Ok
516 }
517
518 #[allow(clippy::too_many_lines)]
519 #[must_use]
520 pub fn get_grain_parameters(&self, start_ts: u64, end_ts: u64) -> GrainTableSegment {
521 let scaling_points_y = self.combined_state[0]
525 .strength_solver
526 .fit_piecewise(NUM_Y_POINTS)
527 .points;
528 let scaling_points_cb = self.combined_state[1]
529 .strength_solver
530 .fit_piecewise(NUM_UV_POINTS)
531 .points;
532 let scaling_points_cr = self.combined_state[2]
533 .strength_solver
534 .fit_piecewise(NUM_UV_POINTS)
535 .points;
536
537 let mut max_scaling_value: f64 = 1.0e-4f64;
538 for p in scaling_points_y
539 .iter()
540 .chain(scaling_points_cb.iter())
541 .chain(scaling_points_cr.iter())
542 .map(|p| p[1])
543 {
544 if p > max_scaling_value {
545 max_scaling_value = p;
546 }
547 }
548
549 let max_scaling_value_log2 =
551 clamp((max_scaling_value.log2() + 1f64).floor() as u8, 2u8, 5u8);
552 let scale_factor = f64::from(1u32 << (8u8 - max_scaling_value_log2));
553 let map_scaling_point = |p: [f64; 2]| {
554 [
555 (p[0] + 0.5f64) as u8,
556 clamp(scale_factor.mul_add(p[1], 0.5f64) as i32, 0i32, 255i32) as u8,
557 ]
558 };
559
560 let scaling_points_y: ArrayVec<_, NUM_Y_POINTS> = scaling_points_y
561 .into_iter()
562 .map(map_scaling_point)
563 .collect();
564 let scaling_points_cb: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cb
565 .into_iter()
566 .map(map_scaling_point)
567 .collect();
568 let scaling_points_cr: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cr
569 .into_iter()
570 .map(map_scaling_point)
571 .collect();
572
573 let n_coeff = self.combined_state[0].eqns.n;
575 let mut max_coeff = 1.0e-4f64;
576 let mut min_coeff = 1.0e-4f64;
577 let mut y_corr = [0f64; 2];
578 let mut avg_luma_strength = 0f64;
579 for c in 0..3 {
580 let eqns = &self.combined_state[c].eqns;
581 for i in 0..n_coeff {
582 if eqns.x[i] > max_coeff {
583 max_coeff = eqns.x[i];
584 }
585 if eqns.x[i] < min_coeff {
586 min_coeff = eqns.x[i];
587 }
588 }
589
590 let solver = &self.combined_state[c].strength_solver;
593 let mut average_strength = 0f64;
595 let mut total_weight = 0f64;
596 for i in 0..solver.eqns.n {
597 let mut w = 0f64;
598 for j in 0..solver.eqns.n {
599 w += solver.eqns.a[i * solver.eqns.n + j];
600 }
601 w = w.sqrt();
602 average_strength += solver.eqns.x[i] * w;
603 total_weight += w;
604 }
605 if total_weight.abs() < f64::EPSILON {
606 average_strength = 1f64;
607 } else {
608 average_strength /= total_weight;
609 }
610 if c == 0 {
611 avg_luma_strength = average_strength;
612 } else {
613 y_corr[c - 1] = avg_luma_strength * eqns.x[n_coeff] / average_strength;
614 max_coeff = max_coeff.max(y_corr[c - 1]);
615 min_coeff = min_coeff.min(y_corr[c - 1]);
616 }
617 }
618
619 let ar_coeff_shift = clamp(
622 7i32 - (1.0f64 + max_coeff.log2().floor()).max((-min_coeff).log2().ceil()) as i32,
623 6i32,
624 9i32,
625 ) as u8;
626 let scale_ar_coeff = f64::from(1u16 << ar_coeff_shift);
627 let ar_coeffs_y = self.get_ar_coeffs_y(n_coeff, scale_ar_coeff);
628 let ar_coeffs_cb = self.get_ar_coeffs_uv(1, n_coeff, scale_ar_coeff, y_corr);
629 let ar_coeffs_cr = self.get_ar_coeffs_uv(2, n_coeff, scale_ar_coeff, y_corr);
630
631 GrainTableSegment {
632 random_seed: if start_ts == 0 { DEFAULT_GRAIN_SEED } else { 0 },
633 start_time: start_ts,
634 end_time: end_ts,
635 ar_coeff_lag: NOISE_MODEL_LAG as u8,
636 scaling_points_y,
637 scaling_points_cb,
638 scaling_points_cr,
639 scaling_shift: 5 + (8 - max_scaling_value_log2),
640 ar_coeff_shift,
641 ar_coeffs_y,
642 ar_coeffs_cb,
643 ar_coeffs_cr,
644 cb_mult: 128,
647 cb_luma_mult: 192,
648 cb_offset: 256,
649 cr_mult: 128,
650 cr_luma_mult: 192,
651 cr_offset: 256,
652 chroma_scaling_from_luma: false,
653 grain_scale_shift: 0,
654 overlap_flag: true,
655 }
656 }
657
658 pub fn save_latest(&mut self) {
659 for c in 0..3 {
660 let latest_state = &self.latest_state[c];
661 let combined_state = &mut self.combined_state[c];
662 combined_state.eqns.copy_from(&latest_state.eqns);
663 combined_state
664 .strength_solver
665 .eqns
666 .copy_from(&latest_state.strength_solver.eqns);
667 combined_state.strength_solver.num_equations =
668 latest_state.strength_solver.num_equations;
669 combined_state.num_observations = latest_state.num_observations;
670 combined_state.ar_gain = latest_state.ar_gain;
671 }
672 }
673
674 #[must_use]
675 const fn num_coeffs() -> usize {
676 let n = 2 * NOISE_MODEL_LAG + 1;
677 (n * n) / 2
678 }
679
680 #[must_use]
681 fn get_ar_coeffs_y(&self, n_coeff: usize, scale_ar_coeff: f64) -> ArrayVec<i8, NUM_Y_COEFFS> {
682 assert!(n_coeff <= NUM_Y_COEFFS);
683 let mut coeffs = ArrayVec::new();
684 let eqns = &self.combined_state[0].eqns;
685 for i in 0..n_coeff {
686 coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8);
687 }
688 coeffs
689 }
690
691 #[must_use]
692 fn get_ar_coeffs_uv(
693 &self,
694 channel: usize,
695 n_coeff: usize,
696 scale_ar_coeff: f64,
697 y_corr: [f64; 2],
698 ) -> ArrayVec<i8, NUM_UV_COEFFS> {
699 assert!(n_coeff <= NUM_Y_COEFFS);
700 let mut coeffs = ArrayVec::new();
701 let eqns = &self.combined_state[channel].eqns;
702 for i in 0..n_coeff {
703 coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8);
704 }
705 coeffs.push(clamp(
706 (scale_ar_coeff * y_corr[channel - 1]).round() as i32,
707 -128i32,
708 127i32,
709 ) as i8);
710 coeffs
711 }
712
713 #[must_use]
718 fn is_different(&self) -> bool {
719 const COEFF_THRESHOLD: f64 = 0.9f64;
720 const STRENGTH_THRESHOLD: f64 = 0.005f64;
721
722 let latest = &self.latest_state[0];
723 let combined = &self.combined_state[0];
724 let corr = normalized_cross_correlation(&latest.eqns.x, &combined.eqns.x, combined.eqns.n);
725 if corr < COEFF_THRESHOLD {
726 return true;
727 }
728
729 let dx = 1.0f64 / latest.strength_solver.num_bins as f64;
730 let latest_eqns = &latest.strength_solver.eqns;
731 let combined_eqns = &combined.strength_solver.eqns;
732 let mut diff = 0.0f64;
733 let mut total_weight = 0.0f64;
734 for j in 0..latest_eqns.n {
735 let mut weight = 0.0f64;
736 for i in 0..latest_eqns.n {
737 weight += latest_eqns.a[i * latest_eqns.n + j];
738 }
739 weight = weight.sqrt();
740 diff += weight * (latest_eqns.x[j] - combined_eqns.x[j]).abs();
741 total_weight += weight;
742 }
743
744 diff * dx / total_weight > STRENGTH_THRESHOLD
745 }
746
747 #[allow(clippy::too_many_arguments)]
748 fn add_block_observations(
749 &mut self,
750 channel: usize,
751 source: &Plane<u8>,
752 denoised: &Plane<u8>,
753 alt_source: Option<&Plane<u8>>,
754 alt_denoised: Option<&Plane<u8>>,
755 frame_dims: (usize, usize),
756 flat_blocks: &[u8],
757 num_blocks_w: usize,
758 num_blocks_h: usize,
759 ) {
760 let num_coords = self.n;
761 let state = &mut self.latest_state[channel];
762 let a = &mut state.eqns.a;
763 let b = &mut state.eqns.b;
764 let mut buffer = vec![0f64; num_coords + 1].into_boxed_slice();
765 let n = state.eqns.n;
766 let block_w = BLOCK_SIZE >> source.cfg.xdec;
767 let block_h = BLOCK_SIZE >> source.cfg.ydec;
768
769 let dec = (source.cfg.xdec, source.cfg.ydec);
770 let stride = source.cfg.stride;
771 let source_origin = source.data_origin();
772 let denoised_origin = denoised.data_origin();
773 let alt_stride = alt_source.map_or(0, |s| s.cfg.stride);
774 let alt_source_origin = alt_source.map(|s| s.data_origin());
775 let alt_denoised_origin = alt_denoised.map(|s| s.data_origin());
776
777 for by in 0..num_blocks_h {
778 let y_o = by * block_h;
779 for bx in 0..num_blocks_w {
780 unsafe {
782 let flat_block_ptr = flat_blocks.as_ptr().add(by * num_blocks_w + bx);
783 let x_o = bx * block_w;
784 if *flat_block_ptr == 0 {
785 continue;
786 }
787 let y_start = if by > 0 && *flat_block_ptr.sub(num_blocks_w) > 0 {
788 0
789 } else {
790 NOISE_MODEL_LAG
791 };
792 let x_start = if bx > 0 && *flat_block_ptr.sub(1) > 0 {
793 0
794 } else {
795 NOISE_MODEL_LAG
796 };
797 let y_end = ((frame_dims.1 >> dec.1) - by * block_h).min(block_h);
798 let x_end = ((frame_dims.0 >> dec.0) - bx * block_w - NOISE_MODEL_LAG).min(
799 if bx + 1 < num_blocks_w && *flat_block_ptr.add(1) > 0 {
800 block_w
801 } else {
802 block_w - NOISE_MODEL_LAG
803 },
804 );
805 for y in y_start..y_end {
806 for x in x_start..x_end {
807 let val = extract_ar_row(
808 &self.coords,
809 num_coords,
810 source_origin,
811 denoised_origin,
812 stride,
813 dec,
814 alt_source_origin,
815 alt_denoised_origin,
816 alt_stride,
817 x + x_o,
818 y + y_o,
819 &mut buffer,
820 );
821 for i in 0..n {
822 for j in 0..n {
823 *a.get_unchecked_mut(i * n + j) += (*buffer.get_unchecked(i)
824 * *buffer.get_unchecked(j))
825 / BLOCK_NORMALIZATION.powi(2);
826 }
827 *b.get_unchecked_mut(i) +=
828 (*buffer.get_unchecked(i) * val) / BLOCK_NORMALIZATION.powi(2);
829 }
830 state.num_observations += 1;
831 }
832 }
833 }
834 }
835 }
836 }
837
838 #[allow(clippy::too_many_arguments)]
839 fn add_noise_std_observations(
840 &mut self,
841 channel: usize,
842 source: &Plane<u8>,
843 denoised: &Plane<u8>,
844 alt_source: Option<&Plane<u8>>,
845 frame_dims: (usize, usize),
846 flat_blocks: &[u8],
847 num_blocks_w: usize,
848 num_blocks_h: usize,
849 ) {
850 let coeffs = &self.latest_state[channel].eqns.x;
851 let num_coords = self.n;
852 let luma_gain = self.latest_state[0].ar_gain;
853 let noise_gain = self.latest_state[channel].ar_gain;
854 let block_w = BLOCK_SIZE >> source.cfg.xdec;
855 let block_h = BLOCK_SIZE >> source.cfg.ydec;
856
857 for by in 0..num_blocks_h {
858 let y_o = by * block_h;
859 for bx in 0..num_blocks_w {
860 let x_o = bx * block_w;
861 if flat_blocks[by * num_blocks_w + bx] == 0 {
862 continue;
863 }
864 let num_samples_h = ((frame_dims.1 >> source.cfg.ydec) - by * block_h).min(block_h);
865 let num_samples_w = ((frame_dims.0 >> source.cfg.xdec) - bx * block_w).min(block_w);
866 if num_samples_w * num_samples_h > BLOCK_SIZE {
869 let block_mean = get_block_mean(
870 alt_source.unwrap_or(source),
871 frame_dims,
872 x_o << source.cfg.xdec,
873 y_o << source.cfg.ydec,
874 );
875 let noise_var = get_noise_var(
876 source,
877 denoised,
878 (
879 frame_dims.0 >> source.cfg.xdec,
880 frame_dims.1 >> source.cfg.ydec,
881 ),
882 x_o,
883 y_o,
884 block_w,
885 block_h,
886 );
887 let luma_strength = if channel > 0 {
891 luma_gain * self.latest_state[0].strength_solver.get_value(block_mean)
892 } else {
893 0f64
894 };
895 let corr = if channel > 0 {
896 coeffs[num_coords]
897 } else {
898 0f64
899 };
900 let uncorr_std = (noise_var / 16f64)
907 .max((corr * luma_strength).mul_add(-(corr * luma_strength), noise_var))
908 .sqrt();
909 let adjusted_strength = uncorr_std / noise_gain;
910 self.latest_state[channel]
911 .strength_solver
912 .add_measurement(block_mean, adjusted_strength);
913 }
914 }
915 }
916 }
917}
918
919#[derive(Debug, Clone)]
920struct NoiseModelState {
921 eqns: EquationSystem,
922 ar_gain: f64,
923 num_observations: usize,
924 strength_solver: StrengthSolver,
925}
926
927impl NoiseModelState {
928 #[must_use]
929 pub fn new(n: usize) -> Self {
930 const NUM_BINS: usize = 20;
931
932 Self {
933 eqns: EquationSystem::new(n),
934 ar_gain: 1.0f64,
935 num_observations: 0usize,
936 strength_solver: StrengthSolver::new(NUM_BINS),
937 }
938 }
939
940 pub fn ar_equation_system_solve(&mut self, is_chroma: bool) -> bool {
941 let ret = self.eqns.solve();
942 self.ar_gain = 1.0f64;
943 if !ret {
944 return ret;
945 }
946
947 let mut var = 0f64;
954 let n_adjusted = self.eqns.n - usize::from(is_chroma);
955 for i in 0..n_adjusted {
956 var += self.eqns.a[i * self.eqns.n + i] / self.num_observations as f64;
957 }
958 var /= n_adjusted as f64;
959
960 let mut sum_covar = 0f64;
966 for i in 0..n_adjusted {
967 let mut bi = self.eqns.b[i];
968 if is_chroma {
969 bi -= self.eqns.a[i * self.eqns.n + n_adjusted] * self.eqns.x[n_adjusted];
970 }
971 sum_covar += (bi * self.eqns.x[i]) / self.num_observations as f64;
972 }
973
974 let noise_var = (var - sum_covar).max(1e-6f64);
977 self.ar_gain = 1f64.max((var / noise_var).max(1e-6f64).sqrt());
978 ret
979 }
980}
981
982#[derive(Debug, Clone)]
983struct StrengthSolver {
984 eqns: EquationSystem,
985 num_bins: usize,
986 num_equations: usize,
987 total: f64,
988}
989
990impl StrengthSolver {
991 #[must_use]
992 pub fn new(num_bins: usize) -> Self {
993 Self {
994 eqns: EquationSystem::new(num_bins),
995 num_bins,
996 num_equations: 0usize,
997 total: 0f64,
998 }
999 }
1000
1001 pub fn add_measurement(&mut self, block_mean: f64, noise_std: f64) {
1002 let bin = self.get_bin_index(block_mean);
1003 let bin_i0 = bin.floor() as usize;
1004 let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
1005 let a = bin - bin_i0 as f64;
1006 let n = self.num_bins;
1007 let eqns = &mut self.eqns;
1008 eqns.a[bin_i0 * n + bin_i0] += (1f64 - a).powi(2);
1009 eqns.a[bin_i1 * n + bin_i0] += a * (1f64 - a);
1010 eqns.a[bin_i1 * n + bin_i1] += a.powi(2);
1011 eqns.a[bin_i0 * n + bin_i1] += (1f64 - a) * a;
1012 eqns.b[bin_i0] += (1f64 - a) * noise_std;
1013 eqns.b[bin_i1] += a * noise_std;
1014 self.total += noise_std;
1015 self.num_equations += 1;
1016 }
1017
1018 pub fn solve(&mut self) -> bool {
1019 let n = self.num_bins;
1021 let alpha = 2f64 * self.num_equations as f64 / n as f64;
1022
1023 let old_a = self.eqns.a.clone();
1025 for i in 0..n {
1026 let i_lo = i.saturating_sub(1);
1027 let i_hi = (n - 1).min(i + 1);
1028 self.eqns.a[i * n + i_lo] -= alpha;
1029 self.eqns.a[i * n + i] += 2f64 * alpha;
1030 self.eqns.a[i * n + i_hi] -= alpha;
1031 }
1032
1033 let mean = self.total / self.num_equations as f64;
1035 for i in 0..n {
1036 self.eqns.a[i * n + i] += 1f64 / 8192f64;
1037 self.eqns.b[i] += mean / 8192f64;
1038 }
1039 let result = self.eqns.solve();
1040 self.eqns.a = old_a;
1041 result
1042 }
1043
1044 #[must_use]
1045 pub fn fit_piecewise(&self, max_output_points: usize) -> NoiseStrengthLut {
1046 const TOLERANCE: f64 = 0.00625f64;
1047
1048 let mut lut = NoiseStrengthLut::new(self.num_bins);
1049 for i in 0..self.num_bins {
1050 lut.points[i][0] = self.get_center(i);
1051 lut.points[i][1] = self.eqns.x[i];
1052 }
1053
1054 let mut residual = vec![0.0f64; self.num_bins];
1055 self.update_piecewise_linear_residual(&lut, &mut residual, 0, self.num_bins);
1056
1057 while lut.points.len() > 2 {
1060 let mut min_index = 1usize;
1061 for j in 1..(lut.points.len() - 1) {
1062 if residual[j] < residual[min_index] {
1063 min_index = j;
1064 }
1065 }
1066 let dx = lut.points[min_index + 1][0] - lut.points[min_index - 1][0];
1067 let avg_residual = residual[min_index] / dx;
1068 if lut.points.len() <= max_output_points && avg_residual > TOLERANCE {
1069 break;
1070 }
1071
1072 lut.points.remove(min_index);
1073 self.update_piecewise_linear_residual(
1074 &lut,
1075 &mut residual,
1076 min_index - 1,
1077 min_index + 1,
1078 );
1079 }
1080
1081 lut
1082 }
1083
1084 #[must_use]
1085 pub fn get_value(&self, x: f64) -> f64 {
1086 let bin = self.get_bin_index(x);
1087 let bin_i0 = bin.floor() as usize;
1088 let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1);
1089 let a = bin - bin_i0 as f64;
1090 (1f64 - a).mul_add(self.eqns.x[bin_i0], a * self.eqns.x[bin_i1])
1091 }
1092
1093 pub fn clear(&mut self) {
1094 self.eqns.clear();
1095 self.num_equations = 0;
1096 self.total = 0f64;
1097 }
1098
1099 #[must_use]
1100 fn get_bin_index(&self, value: f64) -> f64 {
1101 let max = 255f64;
1102 let val = clamp(value, 0f64, max);
1103 (self.num_bins - 1) as f64 * val / max
1104 }
1105
1106 fn update_piecewise_linear_residual(
1107 &self,
1108 lut: &NoiseStrengthLut,
1109 residual: &mut [f64],
1110 start: usize,
1111 end: usize,
1112 ) {
1113 let dx = 255f64 / self.num_bins as f64;
1114 #[allow(clippy::needless_range_loop)]
1115 for i in start.max(1)..end.min(lut.points.len() - 1) {
1116 let lower = 0usize.max(self.get_bin_index(lut.points[i - 1][0]).floor() as usize);
1117 let upper =
1118 (self.num_bins - 1).min(self.get_bin_index(lut.points[i + 1][0]).ceil() as usize);
1119 let mut r = 0f64;
1120 for j in lower..=upper {
1121 let x = self.get_center(j);
1122 if x < lut.points[i - 1][0] || x >= lut.points[i + 1][0] {
1123 continue;
1124 }
1125
1126 let y = self.eqns.x[j];
1127 let a = (x - lut.points[i - 1][0]) / (lut.points[i + 1][0] - lut.points[i - 1][0]);
1128 let estimate_y = lut.points[i - 1][1].mul_add(1f64 - a, lut.points[i + 1][1] * a);
1129 r += (y - estimate_y).abs();
1130 }
1131 residual[i] = r * dx;
1132 }
1133 }
1134
1135 #[must_use]
1136 fn get_center(&self, i: usize) -> f64 {
1137 let range = 255f64;
1138 let n = self.num_bins;
1139 i as f64 / (n - 1) as f64 * range
1140 }
1141}
1142
1143impl Add<&StrengthSolver> for StrengthSolver {
1144 type Output = StrengthSolver;
1145
1146 fn add(self, addend: &StrengthSolver) -> Self::Output {
1147 let mut dest = self;
1148 dest.eqns += &addend.eqns;
1149 dest.num_equations += addend.num_equations;
1150 dest.total += addend.total;
1151 dest
1152 }
1153}
1154
1155impl AddAssign<&StrengthSolver> for StrengthSolver {
1156 fn add_assign(&mut self, rhs: &StrengthSolver) {
1157 *self = self.clone() + rhs;
1158 }
1159}