rav1e/transform/
mod.rs

1// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10#![allow(non_camel_case_types)]
11#![allow(dead_code)]
12
13#[macro_use]
14pub mod forward_shared;
15
16pub use self::forward::forward_transform;
17pub use self::inverse::inverse_transform_add;
18
19use crate::context::MI_SIZE_LOG2;
20use crate::partition::{BlockSize, BlockSize::*};
21use crate::util::*;
22
23use TxSize::*;
24
25pub mod forward;
26pub mod inverse;
27
/// The subset of [`TxType`] values enumerated by this table.
///
/// Only the uncommented entries are part of the slice: the FLIPADST-based
/// variants are commented out pending the TODO below, and `V_ADST`, `H_ADST`
/// and `WHT_WHT` do not appear at all.
pub static RAV1E_TX_TYPES: &[TxType] = &[
  TxType::DCT_DCT,
  TxType::ADST_DCT,
  TxType::DCT_ADST,
  TxType::ADST_ADST,
  // TODO: Add a speed setting for FLIPADST
  // TxType::FLIPADST_DCT,
  // TxType::DCT_FLIPADST,
  // TxType::FLIPADST_FLIPADST,
  // TxType::ADST_FLIPADST,
  // TxType::FLIPADST_ADST,
  TxType::IDTX,
  TxType::V_DCT,
  TxType::H_DCT,
  //TxType::V_FLIPADST,
  //TxType::H_FLIPADST,
];
45
/// Fixed-point approximations of sqrt(2) and 1/sqrt(2).
///
/// Both values are scaled by `2^SQRT2_BITS` (i.e. 4096) and rounded to the
/// nearest integer.
pub mod consts {
  /// Number of fractional bits used by `SQRT2` and `INV_SQRT2`.
  pub static SQRT2_BITS: usize = 12;
  /// sqrt(2) in fixed point.
  pub static SQRT2: i32 = 5793; // 2^12 * sqrt(2)
  /// 1/sqrt(2) in fixed point.
  pub static INV_SQRT2: i32 = 2896; // 2^12 / sqrt(2)
}
51
/// Number of [`TxType`] variants with discriminants `0..=15`, i.e. every
/// variant except `WHT_WHT`.
pub const TX_TYPES: usize = 16;
/// Number of [`TxType`] variants including `WHT_WHT` (discriminant 16).
// NOTE(review): "LL" presumably stands for lossless — confirm.
pub const TX_TYPES_PLUS_LL: usize = 17;
54
/// Two-dimensional transform type.
///
/// For the compound names the first component is the vertical transform and
/// the second the horizontal one (see `get_1d_tx_types`). The explicit
/// discriminants are used as indices into the per-type tables in this module
/// (`VTX_TAB`, `HTX_TAB`), so they must stay contiguous and in this order.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord)]
pub enum TxType {
  DCT_DCT = 0,   // DCT  in both horizontal and vertical
  ADST_DCT = 1,  // ADST in vertical, DCT in horizontal
  DCT_ADST = 2,  // DCT  in vertical, ADST in horizontal
  ADST_ADST = 3, // ADST in both directions
  FLIPADST_DCT = 4,
  DCT_FLIPADST = 5,
  FLIPADST_FLIPADST = 6,
  ADST_FLIPADST = 7,
  FLIPADST_ADST = 8,
  // Identity in both directions
  IDTX = 9,
  // V_*: the named transform vertically, identity horizontally.
  // H_*: identity vertically, the named transform horizontally.
  V_DCT = 10,
  H_DCT = 11,
  V_ADST = 12,
  H_ADST = 13,
  V_FLIPADST = 14,
  H_FLIPADST = 15,
  // WHT in both directions; the only variant beyond `TX_TYPES`.
  WHT_WHT = 16,
}
75
76impl TxType {
77  /// Compute transform type for inter chroma.
78  ///
79  /// <https://aomediacodec.github.io/av1-spec/#compute-transform-type-function>
80  #[inline]
81  pub fn uv_inter(self, uv_tx_size: TxSize) -> Self {
82    use TxType::*;
83    if uv_tx_size.sqr_up() == TX_32X32 {
84      match self {
85        IDTX => IDTX,
86        _ => DCT_DCT,
87      }
88    } else if uv_tx_size.sqr() == TX_16X16 {
89      match self {
90        V_ADST | H_ADST | V_FLIPADST | H_FLIPADST => DCT_DCT,
91        _ => self,
92      }
93    } else {
94      self
95    }
96  }
97}
98
/// Transform Size
///
/// Variants are named `TX_<width>X<height>`. The derived `PartialOrd`/`Ord`
/// follow declaration order, so the five square sizes compare in increasing
/// size and every rectangular size compares greater than every square one.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Eq, Ord)]
pub enum TxSize {
  // Square sizes
  TX_4X4,
  TX_8X8,
  TX_16X16,
  TX_32X32,
  TX_64X64,

  // Rectangular sizes with a 1:2 or 2:1 aspect ratio
  TX_4X8,
  TX_8X4,
  TX_8X16,
  TX_16X8,
  TX_16X32,
  TX_32X16,
  TX_32X64,
  TX_64X32,

  // Rectangular sizes with a 1:4 or 4:1 aspect ratio
  TX_4X16,
  TX_16X4,
  TX_8X32,
  TX_32X8,
  TX_16X64,
  TX_64X16,
}
124
125impl TxSize {
126  /// Number of square transform sizes
127  pub const TX_SIZES: usize = 5;
128
129  /// Number of transform sizes (including non-square sizes)
130  pub const TX_SIZES_ALL: usize = 14 + 5;
131
132  #[inline]
133  pub const fn width(self) -> usize {
134    1 << self.width_log2()
135  }
136
137  #[inline]
138  pub const fn width_log2(self) -> usize {
139    match self {
140      TX_4X4 | TX_4X8 | TX_4X16 => 2,
141      TX_8X8 | TX_8X4 | TX_8X16 | TX_8X32 => 3,
142      TX_16X16 | TX_16X8 | TX_16X32 | TX_16X4 | TX_16X64 => 4,
143      TX_32X32 | TX_32X16 | TX_32X64 | TX_32X8 => 5,
144      TX_64X64 | TX_64X32 | TX_64X16 => 6,
145    }
146  }
147
148  #[inline]
149  pub const fn width_index(self) -> usize {
150    self.width_log2() - TX_4X4.width_log2()
151  }
152
153  #[inline]
154  pub const fn height(self) -> usize {
155    1 << self.height_log2()
156  }
157
158  #[inline]
159  pub const fn height_log2(self) -> usize {
160    match self {
161      TX_4X4 | TX_8X4 | TX_16X4 => 2,
162      TX_8X8 | TX_4X8 | TX_16X8 | TX_32X8 => 3,
163      TX_16X16 | TX_8X16 | TX_32X16 | TX_4X16 | TX_64X16 => 4,
164      TX_32X32 | TX_16X32 | TX_64X32 | TX_8X32 => 5,
165      TX_64X64 | TX_32X64 | TX_16X64 => 6,
166    }
167  }
168
169  #[inline]
170  pub const fn height_index(self) -> usize {
171    self.height_log2() - TX_4X4.height_log2()
172  }
173
174  #[inline]
175  pub const fn width_mi(self) -> usize {
176    self.width() >> MI_SIZE_LOG2
177  }
178
179  #[inline]
180  pub const fn area(self) -> usize {
181    1 << self.area_log2()
182  }
183
184  #[inline]
185  pub const fn area_log2(self) -> usize {
186    self.width_log2() + self.height_log2()
187  }
188
189  #[inline]
190  pub const fn height_mi(self) -> usize {
191    self.height() >> MI_SIZE_LOG2
192  }
193
194  #[inline]
195  pub const fn block_size(self) -> BlockSize {
196    match self {
197      TX_4X4 => BLOCK_4X4,
198      TX_8X8 => BLOCK_8X8,
199      TX_16X16 => BLOCK_16X16,
200      TX_32X32 => BLOCK_32X32,
201      TX_64X64 => BLOCK_64X64,
202      TX_4X8 => BLOCK_4X8,
203      TX_8X4 => BLOCK_8X4,
204      TX_8X16 => BLOCK_8X16,
205      TX_16X8 => BLOCK_16X8,
206      TX_16X32 => BLOCK_16X32,
207      TX_32X16 => BLOCK_32X16,
208      TX_32X64 => BLOCK_32X64,
209      TX_64X32 => BLOCK_64X32,
210      TX_4X16 => BLOCK_4X16,
211      TX_16X4 => BLOCK_16X4,
212      TX_8X32 => BLOCK_8X32,
213      TX_32X8 => BLOCK_32X8,
214      TX_16X64 => BLOCK_16X64,
215      TX_64X16 => BLOCK_64X16,
216    }
217  }
218
219  #[inline]
220  pub const fn sqr(self) -> TxSize {
221    match self {
222      TX_4X4 | TX_4X8 | TX_8X4 | TX_4X16 | TX_16X4 => TX_4X4,
223      TX_8X8 | TX_8X16 | TX_16X8 | TX_8X32 | TX_32X8 => TX_8X8,
224      TX_16X16 | TX_16X32 | TX_32X16 | TX_16X64 | TX_64X16 => TX_16X16,
225      TX_32X32 | TX_32X64 | TX_64X32 => TX_32X32,
226      TX_64X64 => TX_64X64,
227    }
228  }
229
230  #[inline]
231  pub const fn sqr_up(self) -> TxSize {
232    match self {
233      TX_4X4 => TX_4X4,
234      TX_8X8 | TX_4X8 | TX_8X4 => TX_8X8,
235      TX_16X16 | TX_8X16 | TX_16X8 | TX_4X16 | TX_16X4 => TX_16X16,
236      TX_32X32 | TX_16X32 | TX_32X16 | TX_8X32 | TX_32X8 => TX_32X32,
237      TX_64X64 | TX_32X64 | TX_64X32 | TX_16X64 | TX_64X16 => TX_64X64,
238    }
239  }
240
241  #[inline]
242  pub fn by_dims(w: usize, h: usize) -> TxSize {
243    match (w, h) {
244      (4, 4) => TX_4X4,
245      (8, 8) => TX_8X8,
246      (16, 16) => TX_16X16,
247      (32, 32) => TX_32X32,
248      (64, 64) => TX_64X64,
249      (4, 8) => TX_4X8,
250      (8, 4) => TX_8X4,
251      (8, 16) => TX_8X16,
252      (16, 8) => TX_16X8,
253      (16, 32) => TX_16X32,
254      (32, 16) => TX_32X16,
255      (32, 64) => TX_32X64,
256      (64, 32) => TX_64X32,
257      (4, 16) => TX_4X16,
258      (16, 4) => TX_16X4,
259      (8, 32) => TX_8X32,
260      (32, 8) => TX_32X8,
261      (16, 64) => TX_16X64,
262      (64, 16) => TX_64X16,
263      _ => unreachable!(),
264    }
265  }
266
267  #[inline]
268  pub const fn is_rect(self) -> bool {
269    self.width_log2() != self.height_log2()
270  }
271}
272
/// Identifiers for sets of transform types.
///
/// The comment above each variant describes the transform types the set
/// contains; the trailing comment gives an alternative descriptive name.
/// The derived `PartialOrd` follows declaration order.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd)]
pub enum TxSet {
  // DCT only
  TX_SET_DCTONLY,
  // DCT + Identity only
  TX_SET_INTER_3, // TX_SET_DCT_IDTX
  // Discrete Trig transforms w/o flip (4) + Identity (1)
  TX_SET_INTRA_2, // TX_SET_DTT4_IDTX
  // Discrete Trig transforms w/o flip (4) + Identity (1) + 1D Hor/vert DCT (2)
  TX_SET_INTRA_1, // TX_SET_DTT4_IDTX_1DDCT
  // Discrete Trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver DCT (2)
  TX_SET_INTER_2, // TX_SET_DTT9_IDTX_1DDCT
  // Discrete Trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver (6)
  TX_SET_INTER_1, // TX_SET_ALL16
}
288
289/// Utility function that returns the log of the ratio of the col and row sizes.
290#[inline]
291pub fn get_rect_tx_log_ratio(col: usize, row: usize) -> i8 {
292  debug_assert!(col > 0 && row > 0);
293  ILog::ilog(col) as i8 - ILog::ilog(row) as i8
294}
295
// Half of a butterfly: the weighted sum `w0 * in0 + w1 * in1`, rounded and
// shifted right by `bit` bits.
#[inline]
const fn half_btf(w0: i32, in0: i32, w1: i32, in1: i32, bit: usize) -> i32 {
  let lhs = w0 * in0;
  let rhs = w1 * in1;
  // The sum wraps so that behaviour stays defined when it is negative and
  // overflows on its own but no longer does once the rounding term is added.
  let sum = lhs.wrapping_add(rhs);
  match bit {
    // Nothing to round or shift.
    0 => sum,
    // Wrapping flavour of a rounding right shift.
    shift => {
      let rounding = 1i32 << (shift - 1);
      sum.wrapping_add(rounding) >> shift
    }
  }
}
309
310// clamps value to a signed integer type of bit bits
311#[inline]
312fn clamp_value(value: i32, bit: usize) -> i32 {
313  let max_value: i32 = ((1i64 << (bit - 1)) - 1) as i32;
314  let min_value: i32 = (-(1i64 << (bit - 1))) as i32;
315  clamp(value, min_value, max_value)
316}
317
318pub fn av1_round_shift_array(arr: &mut [i32], size: usize, bit: i8) {
319  if bit == 0 {
320    return;
321  }
322  if bit > 0 {
323    let bit = bit as usize;
324    arr.iter_mut().take(size).for_each(|i| {
325      *i = round_shift(*i, bit);
326    })
327  } else {
328    arr.iter_mut().take(size).for_each(|i| {
329      *i <<= -bit;
330    })
331  }
332}
333
/// One-dimensional transform kind.
///
/// A two-dimensional [`TxType`] decomposes into a pair of these, one per
/// direction (see `get_1d_tx_types`).
#[derive(Debug, Clone, Copy)]
enum TxType1D {
  DCT,
  ADST,
  FLIPADST,
  IDTX,
  WHT,
}
342
343const fn get_1d_tx_types(tx_type: TxType) -> (TxType1D, TxType1D) {
344  match tx_type {
345    TxType::DCT_DCT => (TxType1D::DCT, TxType1D::DCT),
346    TxType::ADST_DCT => (TxType1D::ADST, TxType1D::DCT),
347    TxType::DCT_ADST => (TxType1D::DCT, TxType1D::ADST),
348    TxType::ADST_ADST => (TxType1D::ADST, TxType1D::ADST),
349    TxType::FLIPADST_DCT => (TxType1D::FLIPADST, TxType1D::DCT),
350    TxType::DCT_FLIPADST => (TxType1D::DCT, TxType1D::FLIPADST),
351    TxType::FLIPADST_FLIPADST => (TxType1D::FLIPADST, TxType1D::FLIPADST),
352    TxType::ADST_FLIPADST => (TxType1D::ADST, TxType1D::FLIPADST),
353    TxType::FLIPADST_ADST => (TxType1D::FLIPADST, TxType1D::ADST),
354    TxType::IDTX => (TxType1D::IDTX, TxType1D::IDTX),
355    TxType::V_DCT => (TxType1D::DCT, TxType1D::IDTX),
356    TxType::H_DCT => (TxType1D::IDTX, TxType1D::DCT),
357    TxType::V_ADST => (TxType1D::ADST, TxType1D::IDTX),
358    TxType::H_ADST => (TxType1D::IDTX, TxType1D::ADST),
359    TxType::V_FLIPADST => (TxType1D::FLIPADST, TxType1D::IDTX),
360    TxType::H_FLIPADST => (TxType1D::IDTX, TxType1D::FLIPADST),
361    TxType::WHT_WHT => (TxType1D::WHT, TxType1D::WHT),
362  }
363}
364
/// Vertical one-dimensional transform for each [`TxType`].
///
/// Entries are in `TxType` discriminant order (one per line, `DCT_DCT`
/// first, `WHT_WHT` last) and equal the first element returned by
/// `get_1d_tx_types` for the same type.
const VTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [
  TxType1D::DCT,
  TxType1D::ADST,
  TxType1D::DCT,
  TxType1D::ADST,
  TxType1D::FLIPADST,
  TxType1D::DCT,
  TxType1D::FLIPADST,
  TxType1D::ADST,
  TxType1D::FLIPADST,
  TxType1D::IDTX,
  TxType1D::DCT,
  TxType1D::IDTX,
  TxType1D::ADST,
  TxType1D::IDTX,
  TxType1D::FLIPADST,
  TxType1D::IDTX,
  TxType1D::WHT,
];
384
/// Horizontal one-dimensional transform for each [`TxType`].
///
/// Entries are in `TxType` discriminant order (one per line, `DCT_DCT`
/// first, `WHT_WHT` last) and equal the second element returned by
/// `get_1d_tx_types` for the same type.
const HTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [
  TxType1D::DCT,
  TxType1D::DCT,
  TxType1D::ADST,
  TxType1D::ADST,
  TxType1D::DCT,
  TxType1D::FLIPADST,
  TxType1D::FLIPADST,
  TxType1D::FLIPADST,
  TxType1D::ADST,
  TxType1D::IDTX,
  TxType1D::IDTX,
  TxType1D::DCT,
  TxType1D::IDTX,
  TxType1D::ADST,
  TxType1D::IDTX,
  TxType1D::FLIPADST,
  TxType1D::WHT,
];
404
405#[inline]
406pub const fn valid_av1_transform(tx_size: TxSize, tx_type: TxType) -> bool {
407  let size_sq = tx_size.sqr_up();
408  use TxSize::*;
409  use TxType::*;
410  match (size_sq, tx_type) {
411    (TX_64X64, DCT_DCT) => true,
412    (TX_64X64, _) => false,
413    (TX_32X32, DCT_DCT) => true,
414    (TX_32X32, IDTX) => true,
415    (TX_32X32, _) => false,
416    (_, _) => true,
417  }
418}
419
/// Lists the transform types returned for the size class of `tx_size`
/// (`tx_size.sqr_up()`).
///
/// The 64x64 class yields only `DCT_DCT`, the 32x32 class `DCT_DCT` and
/// `IDTX`, the 4x4 class all seventeen types, and the remaining classes all
/// types except `WHT_WHT`.
#[cfg(any(test, feature = "bench"))]
pub fn get_valid_txfm_types(tx_size: TxSize) -> &'static [TxType] {
  use TxType::*;

  // Every transform type in discriminant order; WHT_WHT comes last so that
  // the list without it is a simple prefix.
  static ALL_TYPES: [TxType; 17] = [
    DCT_DCT,
    ADST_DCT,
    DCT_ADST,
    ADST_ADST,
    FLIPADST_DCT,
    DCT_FLIPADST,
    FLIPADST_FLIPADST,
    ADST_FLIPADST,
    FLIPADST_ADST,
    IDTX,
    V_DCT,
    H_DCT,
    V_ADST,
    H_ADST,
    V_FLIPADST,
    H_FLIPADST,
    WHT_WHT,
  ];

  match tx_size.sqr_up() {
    TxSize::TX_64X64 => &[DCT_DCT],
    TxSize::TX_32X32 => &[DCT_DCT, IDTX],
    TxSize::TX_4X4 => &ALL_TYPES,
    _ => &ALL_TYPES[..ALL_TYPES.len() - 1],
  }
}
469
// Unit tests: width/height log-ratio checks and forward/inverse transform
// round trips on random data.
#[cfg(test)]
mod test {
  use super::TxType::*;
  use super::*;
  use crate::context::av1_get_coded_tx_size;
  use crate::cpu_features::CpuFeatureLevel;
  use crate::frame::*;
  use rand::random;
  use std::mem::MaybeUninit;

  /// Runs one forward + inverse transform round trip for `tx_size` and
  /// `tx_type` and asserts that every reconstructed value is within
  /// `tolerance` of the corresponding source value.
  fn test_roundtrip<T: Pixel>(
    tx_size: TxSize, tx_type: TxType, tolerance: i16,
  ) {
    let cpu = CpuFeatureLevel::default();

    let coeff_area: usize = av1_get_coded_tx_size(tx_size).area();
    // Buffers are allocated for the largest transform (64x64) and then
    // sliced down to the area of the size under test.
    let mut src_storage = [T::cast_from(0); 64 * 64];
    let src = &mut src_storage[..tx_size.area()];
    let mut dst = Plane::from_slice(
      &[T::zero(); 64 * 64][..tx_size.area()],
      tx_size.width(),
    );
    let mut res_storage = [0i16; 64 * 64];
    let res = &mut res_storage[..tx_size.area()];
    let mut freq_storage = [MaybeUninit::uninit(); 64 * 64];
    let freq = &mut freq_storage[..tx_size.area()];
    // Fill source and destination with random 8-bit values; the residual is
    // their difference.
    for ((r, s), d) in
      res.iter_mut().zip(src.iter_mut()).zip(dst.data.iter_mut())
    {
      *s = T::cast_from(random::<u8>());
      *d = T::cast_from(random::<u8>());
      *r = i16::cast_from(*s) - i16::cast_from(*d);
    }
    forward_transform(res, freq, tx_size.width(), tx_size, tx_type, 8, cpu);
    // SAFETY: forward_transform initialized freq
    let freq = unsafe { slice_assume_init_mut(freq) };
    inverse_transform_add(
      freq,
      &mut dst.as_region_mut(),
      coeff_area.try_into().unwrap(),
      tx_size,
      tx_type,
      8,
      cpu,
    );

    // After adding the inverse transform back onto dst, it should match src
    // up to the allowed tolerance.
    for (s, d) in src.iter().zip(dst.data.iter()) {
      assert!(i16::abs(i16::cast_from(*s) - i16::cast_from(*d)) <= tolerance);
    }
  }

  /// Checks `get_rect_tx_log_ratio` against the expected width/height log
  /// ratio of every transform size.
  #[test]
  fn log_tx_ratios() {
    // (transform size, expected log ratio of width to height)
    let combinations = [
      (TxSize::TX_4X4, 0),
      (TxSize::TX_8X8, 0),
      (TxSize::TX_16X16, 0),
      (TxSize::TX_32X32, 0),
      (TxSize::TX_64X64, 0),
      (TxSize::TX_4X8, -1),
      (TxSize::TX_8X4, 1),
      (TxSize::TX_8X16, -1),
      (TxSize::TX_16X8, 1),
      (TxSize::TX_16X32, -1),
      (TxSize::TX_32X16, 1),
      (TxSize::TX_32X64, -1),
      (TxSize::TX_64X32, 1),
      (TxSize::TX_4X16, -2),
      (TxSize::TX_16X4, 2),
      (TxSize::TX_8X32, -2),
      (TxSize::TX_32X8, 2),
      (TxSize::TX_16X64, -2),
      (TxSize::TX_64X16, 2),
    ];

    for &(tx_size, expected) in combinations.iter() {
      println!(
        "Testing combination {:?}, {:?}",
        tx_size.width(),
        tx_size.height()
      );
      assert!(
        get_rect_tx_log_ratio(tx_size.width(), tx_size.height()) == expected
      );
    }
  }

  /// Runs `test_roundtrip` for a fixed list of size/type combinations with
  /// pixel type `T`.
  fn roundtrips<T: Pixel>() {
    // (transform size, transform type, allowed absolute error)
    let combinations = [
      (TX_4X4, WHT_WHT, 0),
      (TX_4X4, DCT_DCT, 0),
      (TX_4X4, ADST_DCT, 0),
      (TX_4X4, DCT_ADST, 0),
      (TX_4X4, ADST_ADST, 0),
      (TX_4X4, FLIPADST_DCT, 0),
      (TX_4X4, DCT_FLIPADST, 0),
      (TX_4X4, IDTX, 0),
      (TX_4X4, V_DCT, 0),
      (TX_4X4, H_DCT, 0),
      (TX_4X4, V_ADST, 0),
      (TX_4X4, H_ADST, 0),
      (TX_8X8, DCT_DCT, 1),
      (TX_8X8, ADST_DCT, 1),
      (TX_8X8, DCT_ADST, 1),
      (TX_8X8, ADST_ADST, 1),
      (TX_8X8, FLIPADST_DCT, 1),
      (TX_8X8, DCT_FLIPADST, 1),
      (TX_8X8, IDTX, 0),
      (TX_8X8, V_DCT, 0),
      (TX_8X8, H_DCT, 0),
      (TX_8X8, V_ADST, 0),
      (TX_8X8, H_ADST, 1),
      (TX_16X16, DCT_DCT, 1),
      (TX_16X16, ADST_DCT, 1),
      (TX_16X16, DCT_ADST, 1),
      (TX_16X16, ADST_ADST, 1),
      (TX_16X16, FLIPADST_DCT, 1),
      (TX_16X16, DCT_FLIPADST, 1),
      (TX_16X16, IDTX, 0),
      (TX_16X16, V_DCT, 1),
      (TX_16X16, H_DCT, 1),
      // 32x transforms only use DCT_DCT and IDTX
      (TX_32X32, DCT_DCT, 2),
      (TX_32X32, IDTX, 0),
      // 64x transforms only use DCT_DCT (currently not exercised here)
      //(TX_64X64, DCT_DCT, 0),
      (TX_4X8, DCT_DCT, 1),
      (TX_8X4, DCT_DCT, 1),
      (TX_4X16, DCT_DCT, 1),
      (TX_16X4, DCT_DCT, 1),
      (TX_8X16, DCT_DCT, 1),
      (TX_16X8, DCT_DCT, 1),
      (TX_8X32, DCT_DCT, 2),
      (TX_32X8, DCT_DCT, 2),
      (TX_16X32, DCT_DCT, 2),
      (TX_32X16, DCT_DCT, 2),
    ];
    for &(tx_size, tx_type, tolerance) in combinations.iter() {
      println!("Testing combination {:?}, {:?}", tx_size, tx_type);
      test_roundtrip::<T>(tx_size, tx_type, tolerance);
    }
  }

  /// Round trips with 8-bit pixel storage.
  #[test]
  fn roundtrips_u8() {
    roundtrips::<u8>();
  }

  /// Round trips with 16-bit pixel storage.
  #[test]
  fn roundtrips_u16() {
    roundtrips::<u16>();
  }
}