av1_grain/diff/solver/
util.rs

1use std::ptr;
2
3use v_frame::plane::Plane;
4
5use crate::diff::BLOCK_SIZE;
6
7/// Solves Ax = b, where x and b are column vectors of size nx1 and A is nxn
8#[allow(clippy::many_single_char_names)]
9pub(super) fn linsolve(
10    n: usize,
11    a: &mut [f64],
12    stride: usize,
13    b: &mut [f64],
14    x: &mut [f64],
15) -> bool {
16    // SAFETY: We need to ensure that `n` doesn't exceed the bounds of these arrays.
17    // But this is a crate-private function, so we control all input.
18    unsafe {
19        // Forward elimination
20        for k in 0..(n - 1) {
21            // Bring the largest magnitude to the diagonal position
22            ((k + 1)..n).rev().for_each(|i| {
23                if a.get_unchecked((i - 1) * stride + k).abs()
24                    < a.get_unchecked(i * stride + k).abs()
25                {
26                    (0..n).for_each(|j| {
27                        swap_unchecked(a, i * stride + j, (i - 1) * stride + j);
28                    });
29                    swap_unchecked(b, i, i - 1);
30                }
31            });
32
33            for i in k..(n - 1) {
34                if a.get_unchecked(k * stride + k).abs() < f64::EPSILON {
35                    return false;
36                }
37                let c = *a.get_unchecked((i + 1) * stride + k) / *a.get_unchecked(k * stride + k);
38                (0..n).for_each(|j| {
39                    let a2_val = *a.get_unchecked(k * stride + j);
40                    let a_val = a.get_unchecked_mut((i + 1) * stride + j);
41                    *a_val = c.mul_add(-a2_val, *a_val);
42                });
43                let b2_val = *b.get_unchecked(k);
44                let b_val = b.get_unchecked_mut(i + 1);
45                *b_val = c.mul_add(-b2_val, *b_val);
46            }
47        }
48
49        // Backward substitution
50        for i in (0..n).rev() {
51            if a.get_unchecked(i * stride + i).abs() < f64::EPSILON {
52                return false;
53            }
54            let mut c = 0.0f64;
55            for j in (i + 1)..n {
56                c = a
57                    .get_unchecked(i * stride + j)
58                    .mul_add(*x.get_unchecked(j), c);
59            }
60            *x.get_unchecked_mut(i) = (*b.get_unchecked(i) - c) / *a.get_unchecked(i * stride + i);
61        }
62    }
63
64    true
65}
66
67// TODO: This is unstable upstream. Once it's stable upstream, use that.
68unsafe fn swap_unchecked<T>(slice: &mut [T], a: usize, b: usize) {
69    let ptr = slice.as_mut_ptr();
70    // SAFETY: caller has to guarantee that `a < self.len()` and `b < self.len()`
71    unsafe {
72        ptr::swap(ptr.add(a), ptr.add(b));
73    }
74}
75
76pub(super) fn multiply_mat(
77    m1: &[f64],
78    m2: &[f64],
79    res: &mut [f64],
80    m1_rows: usize,
81    inner_dim: usize,
82    m2_cols: usize,
83) {
84    assert!(res.len() >= m1_rows * m2_cols);
85    assert!(m1.len() >= m1_rows * inner_dim);
86    assert!(m2.len() >= m2_cols * inner_dim);
87    let mut idx = 0;
88    for row in 0..m1_rows {
89        for col in 0..m2_cols {
90            let mut sum = 0f64;
91            for inner in 0..inner_dim {
92                // SAFETY: We do the bounds checks once at the top to improve performance.
93                unsafe {
94                    sum += m1.get_unchecked(row * inner_dim + inner)
95                        * m2.get_unchecked(inner * m2_cols + col);
96                }
97            }
98            // SAFETY: We do the bounds checks once at the top to improve performance.
99            unsafe {
100                *res.get_unchecked_mut(idx) = sum;
101            }
102            idx += 1;
103        }
104    }
105}
106
107#[must_use]
108pub(super) fn normalized_cross_correlation(a: &[f64], b: &[f64], n: usize) -> f64 {
109    let mut c = 0f64;
110    let mut a_len = 0f64;
111    let mut b_len = 0f64;
112    for (a, b) in a.iter().zip(b.iter()).take(n) {
113        a_len = (*a).mul_add(*a, a_len);
114        b_len = (*b).mul_add(*b, b_len);
115        c = (*a).mul_add(*b, c);
116    }
117    c / (a_len.sqrt() * b_len.sqrt())
118}
119
120#[allow(clippy::too_many_arguments)]
121pub(super) fn extract_ar_row(
122    coords: &[[isize; 2]],
123    num_coords: usize,
124    source_origin: &[u8],
125    denoised_origin: &[u8],
126    stride: usize,
127    dec: (usize, usize),
128    alt_source_origin: Option<&[u8]>,
129    alt_denoised_origin: Option<&[u8]>,
130    alt_stride: usize,
131    x: usize,
132    y: usize,
133    buffer: &mut [f64],
134) -> f64 {
135    debug_assert!(buffer.len() > num_coords);
136    debug_assert!(coords.len() >= num_coords);
137
138    // SAFETY: We know the indexes we provide do not overflow the data bounds
139    unsafe {
140        for i in 0..num_coords {
141            let x_i = x as isize + coords.get_unchecked(i)[0];
142            let y_i = y as isize + coords.get_unchecked(i)[1];
143            debug_assert!(x_i >= 0);
144            debug_assert!(y_i >= 0);
145            let index = y_i as usize * stride + x_i as usize;
146            *buffer.get_unchecked_mut(i) = f64::from(*source_origin.get_unchecked(index))
147                - f64::from(*denoised_origin.get_unchecked(index));
148        }
149        let val = f64::from(*source_origin.get_unchecked(y * stride + x))
150            - f64::from(*denoised_origin.get_unchecked(y * stride + x));
151
152        if let Some(alt_source_origin) = alt_source_origin {
153            if let Some(alt_denoised_origin) = alt_denoised_origin {
154                let mut source_sum = 0u64;
155                let mut denoised_sum = 0u64;
156                let mut num_samples = 0usize;
157
158                for dy_i in 0..(1 << dec.1) {
159                    let y_up = (y << dec.1) + dy_i;
160                    for dx_i in 0..(1 << dec.0) {
161                        let x_up = (x << dec.0) + dx_i;
162                        let index = y_up * alt_stride + x_up;
163                        source_sum += u64::from(*alt_source_origin.get_unchecked(index));
164                        denoised_sum += u64::from(*alt_denoised_origin.get_unchecked(index));
165                        num_samples += 1;
166                    }
167                }
168                *buffer.get_unchecked_mut(num_coords) =
169                    (source_sum as f64 - denoised_sum as f64) / num_samples as f64;
170            }
171        }
172
173        val
174    }
175}
176
177#[must_use]
178pub(super) fn get_block_mean(
179    source: &Plane<u8>,
180    frame_dims: (usize, usize),
181    x_o: usize,
182    y_o: usize,
183) -> f64 {
184    let max_h = (frame_dims.1 - y_o).min(BLOCK_SIZE);
185    let max_w = (frame_dims.0 - x_o).min(BLOCK_SIZE);
186
187    let data_origin = source.data_origin();
188    let mut block_sum = 0u64;
189    for y in 0..max_h {
190        for x in 0..max_w {
191            let index = (y_o + y) * source.cfg.stride + x_o + x;
192            // SAFETY: We know the index cannot exceed the dimensions of the plane data
193            unsafe {
194                block_sum += u64::from(*data_origin.get_unchecked(index));
195            }
196        }
197    }
198
199    block_sum as f64 / (max_w * max_h) as f64
200}
201
202#[must_use]
203pub(super) fn get_noise_var(
204    source: &Plane<u8>,
205    denoised: &Plane<u8>,
206    frame_dims: (usize, usize),
207    x_o: usize,
208    y_o: usize,
209    block_w: usize,
210    block_h: usize,
211) -> f64 {
212    let max_h = (frame_dims.1 - y_o).min(block_h);
213    let max_w = (frame_dims.0 - x_o).min(block_w);
214
215    let source_origin = source.data_origin();
216    let denoised_origin = denoised.data_origin();
217    let mut noise_var_sum = 0u64;
218    let mut noise_sum = 0i64;
219    for y in 0..max_h {
220        for x in 0..max_w {
221            let index = (y_o + y) * source.cfg.stride + x_o + x;
222            // SAFETY: We know the index cannot exceed the dimensions of the plane data
223            unsafe {
224                let noise = i64::from(*source_origin.get_unchecked(index))
225                    - i64::from(*denoised_origin.get_unchecked(index));
226                noise_sum += noise;
227                noise_var_sum += noise.pow(2) as u64;
228            }
229        }
230    }
231
232    let noise_mean = noise_sum as f64 / (max_w * max_h) as f64;
233    noise_mean.mul_add(-noise_mean, noise_var_sum as f64 / (max_w * max_h) as f64)
234}