rav1e/mc.rs

// Copyright (c) 2019-2022, The rav1e contributors. All rights reserved
//
// This source code is subject to the terms of the BSD 2 Clause License and
// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
// was not distributed with this source code in the LICENSE file, you can
// obtain it at www.aomedia.org/license/software. If the Alliance for Open
// Media Patent License 1.0 was not distributed with this source code in the
// PATENTS file, you can obtain it at www.aomedia.org/license/patent.

cfg_if::cfg_if! {
  if #[cfg(nasm_x86_64)] {
    pub use crate::asm::x86::mc::*;
  } else if #[cfg(asm_neon)] {
    pub use crate::asm::aarch64::mc::*;
  } else {
    pub use self::rust::*;
  }
}

use crate::cpu_features::CpuFeatureLevel;
use crate::frame::*;
use crate::tiling::*;
use crate::util::*;

use simd_helpers::cold_for_target_arch;
use std::ops;

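/// A motion vector with `row` and `col` components in eighth-pel (1/8 sample)
/// units; `quantize_to_fullpel` rounds both components toward zero to the
/// nearest whole sample.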
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct MotionVector {
  pub row: i16,
  pub col: i16,
}

impl MotionVector {
  #[inline]
  pub const fn quantize_to_fullpel(self) -> Self {
    Self { row: (self.row / 8) * 8, col: (self.col / 8) * 8 }
  }

  #[inline]
  pub const fn is_zero(self) -> bool {
    self.row == 0 && self.col == 0
  }

  #[inline]
  pub const fn is_valid(self) -> bool {
    use crate::context::{MV_LOW, MV_UPP};
    ((MV_LOW as i16) < self.row && self.row < (MV_UPP as i16))
      && ((MV_LOW as i16) < self.col && self.col < (MV_UPP as i16))
  }
}

impl ops::Mul<i16> for MotionVector {
  type Output = MotionVector;

  #[inline]
  fn mul(self, rhs: i16) -> MotionVector {
    MotionVector { row: self.row * rhs, col: self.col * rhs }
  }
}

impl ops::Mul<u16> for MotionVector {
  type Output = MotionVector;

  #[inline]
  fn mul(self, rhs: u16) -> MotionVector {
    MotionVector { row: self.row * rhs as i16, col: self.col * rhs as i16 }
  }
}

impl ops::Shr<u8> for MotionVector {
  type Output = MotionVector;

  #[inline]
  fn shr(self, rhs: u8) -> MotionVector {
    MotionVector { row: self.row >> rhs, col: self.col >> rhs }
  }
}

impl ops::Shl<u8> for MotionVector {
  type Output = MotionVector;

  #[inline]
  fn shl(self, rhs: u8) -> MotionVector {
    MotionVector { row: self.row << rhs, col: self.col << rhs }
  }
}

impl ops::Add<MotionVector> for MotionVector {
  type Output = MotionVector;

  #[inline]
  fn add(self, rhs: MotionVector) -> MotionVector {
    MotionVector { row: self.row + rhs.row, col: self.col + rhs.col }
  }
}

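/// Interpolation filter families for sub-pixel motion compensation. The first
/// four correspond to kernel sets in `SUBPEL_FILTERS` (see `get_filter`);
/// `SWITCHABLE` indicates the filter choice is signalled per block.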
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd)]
#[allow(unused)]
pub enum FilterMode {
  REGULAR = 0,
  SMOOTH = 1,
  SHARP = 2,
  BILINEAR = 3,
  SWITCHABLE = 4,
}

pub const SUBPEL_FILTER_SIZE: usize = 8;

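/// Sub-pixel interpolation kernels indexed by [filter set][sub-pel phase
/// 0-15]. Sets 0-3 are the 8-tap REGULAR, SMOOTH, SHARP and BILINEAR filters;
/// sets 4 and 5 are 4-tap variants used for small dimensions (see
/// `get_filter`). Each kernel's taps sum to 128, i.e. 7-bit fixed point.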
const SUBPEL_FILTERS: [[[i32; SUBPEL_FILTER_SIZE]; 16]; 6] = [
  [
    [0, 0, 0, 128, 0, 0, 0, 0],
    [0, 2, -6, 126, 8, -2, 0, 0],
    [0, 2, -10, 122, 18, -4, 0, 0],
    [0, 2, -12, 116, 28, -8, 2, 0],
    [0, 2, -14, 110, 38, -10, 2, 0],
    [0, 2, -14, 102, 48, -12, 2, 0],
    [0, 2, -16, 94, 58, -12, 2, 0],
    [0, 2, -14, 84, 66, -12, 2, 0],
    [0, 2, -14, 76, 76, -14, 2, 0],
    [0, 2, -12, 66, 84, -14, 2, 0],
    [0, 2, -12, 58, 94, -16, 2, 0],
    [0, 2, -12, 48, 102, -14, 2, 0],
    [0, 2, -10, 38, 110, -14, 2, 0],
    [0, 2, -8, 28, 116, -12, 2, 0],
    [0, 0, -4, 18, 122, -10, 2, 0],
    [0, 0, -2, 8, 126, -6, 2, 0],
  ],
  [
    [0, 0, 0, 128, 0, 0, 0, 0],
    [0, 2, 28, 62, 34, 2, 0, 0],
    [0, 0, 26, 62, 36, 4, 0, 0],
    [0, 0, 22, 62, 40, 4, 0, 0],
    [0, 0, 20, 60, 42, 6, 0, 0],
    [0, 0, 18, 58, 44, 8, 0, 0],
    [0, 0, 16, 56, 46, 10, 0, 0],
    [0, -2, 16, 54, 48, 12, 0, 0],
    [0, -2, 14, 52, 52, 14, -2, 0],
    [0, 0, 12, 48, 54, 16, -2, 0],
    [0, 0, 10, 46, 56, 16, 0, 0],
    [0, 0, 8, 44, 58, 18, 0, 0],
    [0, 0, 6, 42, 60, 20, 0, 0],
    [0, 0, 4, 40, 62, 22, 0, 0],
    [0, 0, 4, 36, 62, 26, 0, 0],
    [0, 0, 2, 34, 62, 28, 2, 0],
  ],
  [
    [0, 0, 0, 128, 0, 0, 0, 0],
    [-2, 2, -6, 126, 8, -2, 2, 0],
    [-2, 6, -12, 124, 16, -6, 4, -2],
    [-2, 8, -18, 120, 26, -10, 6, -2],
    [-4, 10, -22, 116, 38, -14, 6, -2],
    [-4, 10, -22, 108, 48, -18, 8, -2],
    [-4, 10, -24, 100, 60, -20, 8, -2],
    [-4, 10, -24, 90, 70, -22, 10, -2],
    [-4, 12, -24, 80, 80, -24, 12, -4],
    [-2, 10, -22, 70, 90, -24, 10, -4],
    [-2, 8, -20, 60, 100, -24, 10, -4],
    [-2, 8, -18, 48, 108, -22, 10, -4],
    [-2, 6, -14, 38, 116, -22, 10, -4],
    [-2, 6, -10, 26, 120, -18, 8, -2],
    [-2, 4, -6, 16, 124, -12, 6, -2],
    [0, 2, -2, 8, 126, -6, 2, -2],
  ],
  [
    [0, 0, 0, 128, 0, 0, 0, 0],
    [0, 0, 0, 120, 8, 0, 0, 0],
    [0, 0, 0, 112, 16, 0, 0, 0],
    [0, 0, 0, 104, 24, 0, 0, 0],
    [0, 0, 0, 96, 32, 0, 0, 0],
    [0, 0, 0, 88, 40, 0, 0, 0],
    [0, 0, 0, 80, 48, 0, 0, 0],
    [0, 0, 0, 72, 56, 0, 0, 0],
    [0, 0, 0, 64, 64, 0, 0, 0],
    [0, 0, 0, 56, 72, 0, 0, 0],
    [0, 0, 0, 48, 80, 0, 0, 0],
    [0, 0, 0, 40, 88, 0, 0, 0],
    [0, 0, 0, 32, 96, 0, 0, 0],
    [0, 0, 0, 24, 104, 0, 0, 0],
    [0, 0, 0, 16, 112, 0, 0, 0],
    [0, 0, 0, 8, 120, 0, 0, 0],
  ],
  [
    [0, 0, 0, 128, 0, 0, 0, 0],
    [0, 0, -4, 126, 8, -2, 0, 0],
    [0, 0, -8, 122, 18, -4, 0, 0],
    [0, 0, -10, 116, 28, -6, 0, 0],
    [0, 0, -12, 110, 38, -8, 0, 0],
    [0, 0, -12, 102, 48, -10, 0, 0],
    [0, 0, -14, 94, 58, -10, 0, 0],
    [0, 0, -12, 84, 66, -10, 0, 0],
    [0, 0, -12, 76, 76, -12, 0, 0],
    [0, 0, -10, 66, 84, -12, 0, 0],
    [0, 0, -10, 58, 94, -14, 0, 0],
    [0, 0, -10, 48, 102, -12, 0, 0],
    [0, 0, -8, 38, 110, -12, 0, 0],
    [0, 0, -6, 28, 116, -10, 0, 0],
    [0, 0, -4, 18, 122, -8, 0, 0],
    [0, 0, -2, 8, 126, -4, 0, 0],
  ],
  [
    [0, 0, 0, 128, 0, 0, 0, 0],
    [0, 0, 30, 62, 34, 2, 0, 0],
    [0, 0, 26, 62, 36, 4, 0, 0],
    [0, 0, 22, 62, 40, 4, 0, 0],
    [0, 0, 20, 60, 42, 6, 0, 0],
    [0, 0, 18, 58, 44, 8, 0, 0],
    [0, 0, 16, 56, 46, 10, 0, 0],
    [0, 0, 14, 54, 48, 12, 0, 0],
    [0, 0, 12, 52, 52, 12, 0, 0],
    [0, 0, 12, 48, 54, 14, 0, 0],
    [0, 0, 10, 46, 56, 16, 0, 0],
    [0, 0, 8, 44, 58, 18, 0, 0],
    [0, 0, 6, 42, 60, 20, 0, 0],
    [0, 0, 4, 40, 62, 22, 0, 0],
    [0, 0, 4, 36, 62, 26, 0, 0],
    [0, 0, 2, 34, 62, 30, 0, 0],
  ],
];

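/// Pure Rust fallbacks for the motion-compensation kernels, used when neither
/// the x86_64 nor the aarch64 assembly path is selected by the `cfg_if` above.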
pub(crate) mod rust {
  use super::*;
  use num_traits::*;

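  /// Applies an 8-tap `filter` to samples read from `src` at multiples of
  /// `stride`, returning the unrounded 32-bit accumulator.
  ///
  /// # Safety
  ///
  /// `src` must be valid for reads of at least
  /// `(filter.len() - 1) * stride + 1` elements.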
  unsafe fn run_filter<T: AsPrimitive<i32>>(
    src: *const T, stride: usize, filter: [i32; 8],
  ) -> i32 {
    filter
      .iter()
      .enumerate()
      .map(|(i, f)| {
        let p = src.add(i * stride);
        f * (*p).as_()
      })
      .sum::<i32>()
  }

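  /// Looks up the interpolation kernel for `mode` at sub-pel phase `frac`.
  /// The 8-tap sets are used for BILINEAR or when the filtered dimension
  /// `length` exceeds 4; otherwise the 4-tap sets are selected (REGULAR keeps
  /// its own 4-tap kernel, while SMOOTH and SHARP share the 4-tap SMOOTH one).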
  fn get_filter(
    mode: FilterMode, frac: i32, length: usize,
  ) -> [i32; SUBPEL_FILTER_SIZE] {
    let filter_idx = if mode == FilterMode::BILINEAR || length > 4 {
      mode as usize
    } else {
      (mode as usize).min(1) + 4
    };
    SUBPEL_FILTERS[filter_idx][frac as usize]
  }

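  /// Motion compensation with sub-pel interpolation, writing final pixels to
  /// `dst`. `col_frac` and `row_frac` select the horizontal and vertical
  /// sub-pel phases: when both are zero the source block is copied, when only
  /// one is non-zero a single filter pass is applied, and otherwise the block
  /// is filtered horizontally into an intermediate buffer and then vertically.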
  #[cold_for_target_arch("x86_64")]
  pub fn put_8tap<T: Pixel>(
    dst: &mut PlaneRegionMut<'_, T>, src: PlaneSlice<'_, T>, width: usize,
    height: usize, col_frac: i32, row_frac: i32, mode_x: FilterMode,
    mode_y: FilterMode, bit_depth: usize, _cpu: CpuFeatureLevel,
  ) {
    // The assembly only supports even heights and valid uncropped widths
    assert_eq!(height & 1, 0);
    assert!(width.is_power_of_two() && (2..=128).contains(&width));

    let ref_stride = src.plane.cfg.stride;
    let y_filter = get_filter(mode_y, row_frac, height);
    let x_filter = get_filter(mode_x, col_frac, width);
    let max_sample_val = (1 << bit_depth) - 1;
    let intermediate_bits = 4 - if bit_depth == 12 { 2 } else { 0 };
    match (col_frac, row_frac) {
      (0, 0) => {
        for r in 0..height {
          let src_slice = &src[r];
          let dst_slice = &mut dst[r];
          dst_slice[..width].copy_from_slice(&src_slice[..width]);
        }
      }
      (0, _) => {
        let offset_slice = src.go_up(3);
        for r in 0..height {
          let src_slice = &offset_slice[r];
          let dst_slice = &mut dst[r];
          for c in 0..width {
            dst_slice[c] = T::cast_from(
              round_shift(
                // SAFETY: We pass this a raw pointer, but it's created from a
                // checked slice, so we are safe.
                unsafe {
                  run_filter(src_slice[c..].as_ptr(), ref_stride, y_filter)
                },
                7,
              )
              .clamp(0, max_sample_val),
            );
          }
        }
      }
      (_, 0) => {
        let offset_slice = src.go_left(3);
        for r in 0..height {
          let src_slice = &offset_slice[r];
          let dst_slice = &mut dst[r];
          for c in 0..width {
            dst_slice[c] = T::cast_from(
              round_shift(
                round_shift(
                  // SAFETY: We pass this a raw pointer, but it's created from a
                  // checked slice, so we are safe.
                  unsafe { run_filter(src_slice[c..].as_ptr(), 1, x_filter) },
                  7 - intermediate_bits,
                ),
                intermediate_bits,
              )
              .clamp(0, max_sample_val),
            );
          }
        }
      }
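      // Both fractions are non-zero: separable two-pass filtering. For each
      // group of up to 8 columns, filter horizontally (7 extra rows for the
      // vertical taps) into `intermediate`, then filter that column group
      // vertically into `dst`.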
      (_, _) => {
        let mut intermediate: [i16; 8 * (128 + 7)] = [0; 8 * (128 + 7)];

        let offset_slice = src.go_left(3).go_up(3);
        for cg in (0..width).step_by(8) {
          for r in 0..height + 7 {
            let src_slice = &offset_slice[r];
            for c in cg..(cg + 8).min(width) {
              intermediate[8 * r + (c - cg)] = round_shift(
                // SAFETY: We pass this a raw pointer, but it's created from a
                // checked slice, so we are safe.
                unsafe { run_filter(src_slice[c..].as_ptr(), 1, x_filter) },
                7 - intermediate_bits,
              ) as i16;
            }
          }

          for r in 0..height {
            let dst_slice = &mut dst[r];
            for c in cg..(cg + 8).min(width) {
              dst_slice[c] = T::cast_from(
                round_shift(
                  // SAFETY: We pass this a raw pointer, but it's created from a
                  // checked slice, so we are safe.
                  unsafe {
                    run_filter(
                      intermediate[8 * r + c - cg..].as_ptr(),
                      8,
                      y_filter,
                    )
                  },
                  7 + intermediate_bits,
                )
                .clamp(0, max_sample_val),
              );
            }
          }
        }
      }
    }
  }

  // HBD output interval is [-20588, 36956] (10-bit), [-20602, 36983] (12-bit)
  // Subtract PREP_BIAS to ensure result fits in i16 and matches dav1d assembly
  const PREP_BIAS: i32 = 8192;

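  /// Like `put_8tap`, but writes unclamped intermediate predictions to `tmp`
  /// as `i16` (keeping `intermediate_bits` of extra precision, biased by
  /// `PREP_BIAS` for high bit depths) so two predictions can later be blended
  /// by `mc_avg`.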
  #[cold_for_target_arch("x86_64")]
  pub fn prep_8tap<T: Pixel>(
    tmp: &mut [i16], src: PlaneSlice<'_, T>, width: usize, height: usize,
    col_frac: i32, row_frac: i32, mode_x: FilterMode, mode_y: FilterMode,
    bit_depth: usize, _cpu: CpuFeatureLevel,
  ) {
    // The assembly only supports even heights and valid uncropped widths
    assert_eq!(height & 1, 0);
    assert!(width.is_power_of_two() && (2..=128).contains(&width));

    let ref_stride = src.plane.cfg.stride;
    let y_filter = get_filter(mode_y, row_frac, height);
    let x_filter = get_filter(mode_x, col_frac, width);
    let intermediate_bits = 4 - if bit_depth == 12 { 2 } else { 0 };
    let prep_bias = if bit_depth == 8 { 0 } else { PREP_BIAS };
    match (col_frac, row_frac) {
      (0, 0) => {
        for r in 0..height {
          let src_slice = &src[r];
          for c in 0..width {
            tmp[r * width + c] = (i16::cast_from(src_slice[c])
              << intermediate_bits)
              - prep_bias as i16;
          }
        }
      }
      (0, _) => {
        let offset_slice = src.go_up(3);
        for r in 0..height {
          let src_slice = &offset_slice[r];
          for c in 0..width {
            tmp[r * width + c] = (round_shift(
              // SAFETY: We pass this a raw pointer, but it's created from a
              // checked slice, so we are safe.
              unsafe {
                run_filter(src_slice[c..].as_ptr(), ref_stride, y_filter)
              },
              7 - intermediate_bits,
            ) - prep_bias) as i16;
          }
        }
      }
      (_, 0) => {
        let offset_slice = src.go_left(3);
        for r in 0..height {
          let src_slice = &offset_slice[r];
          for c in 0..width {
            tmp[r * width + c] = (round_shift(
              // SAFETY: We pass this a raw pointer, but it's created from a
              // checked slice, so we are safe.
              unsafe { run_filter(src_slice[c..].as_ptr(), 1, x_filter) },
              7 - intermediate_bits,
            ) - prep_bias) as i16;
          }
        }
      }
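      // Both fractions non-zero: same two-pass scheme as in `put_8tap`, except
      // the vertical pass keeps the extra intermediate precision in `tmp`.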
      (_, _) => {
        let mut intermediate: [i16; 8 * (128 + 7)] = [0; 8 * (128 + 7)];

        let offset_slice = src.go_left(3).go_up(3);
        for cg in (0..width).step_by(8) {
          for r in 0..height + 7 {
            let src_slice = &offset_slice[r];
            for c in cg..(cg + 8).min(width) {
              intermediate[8 * r + (c - cg)] = round_shift(
                // SAFETY: We pass this a raw pointer, but it's created from a
                // checked slice, so we are safe.
                unsafe { run_filter(src_slice[c..].as_ptr(), 1, x_filter) },
                7 - intermediate_bits,
              ) as i16;
            }
          }

          for r in 0..height {
            for c in cg..(cg + 8).min(width) {
              tmp[r * width + c] = (round_shift(
                // SAFETY: We pass this a raw pointer, but it's created from a
                // checked slice, so we are safe.
                unsafe {
                  run_filter(
                    intermediate[8 * r + c - cg..].as_ptr(),
                    8,
                    y_filter,
                  )
                },
                7,
              ) - prep_bias) as i16;
            }
          }
        }
      }
    }
  }

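  /// Averages two intermediate predictions produced by `prep_8tap`, undoing
  /// the `PREP_BIAS` offset and the extra `intermediate_bits` precision, and
  /// writes the clamped result to `dst`.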
  #[cold_for_target_arch("x86_64")]
  pub fn mc_avg<T: Pixel>(
    dst: &mut PlaneRegionMut<'_, T>, tmp1: &[i16], tmp2: &[i16], width: usize,
    height: usize, bit_depth: usize, _cpu: CpuFeatureLevel,
  ) {
    // The assembly only supports even heights and valid uncropped widths
    assert_eq!(height & 1, 0);
    assert!(width.is_power_of_two() && (2..=128).contains(&width));

    let max_sample_val = (1 << bit_depth) - 1;
    let intermediate_bits = 4 - if bit_depth == 12 { 2 } else { 0 };
    let prep_bias = if bit_depth == 8 { 0 } else { PREP_BIAS * 2 };
    for r in 0..height {
      let dst_slice = &mut dst[r];
      for c in 0..width {
        dst_slice[c] = T::cast_from(
          round_shift(
            tmp1[r * width + c] as i32
              + tmp2[r * width + c] as i32
              + prep_bias,
            intermediate_bits + 1,
          )
          .clamp(0, max_sample_val),
        );
      }
    }
  }
}