rav1e/
segmentation.rs

1// Copyright (c) 2018-2023, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10use crate::context::*;
11use crate::header::PRIMARY_REF_NONE;
12use crate::partition::BlockSize;
13use crate::rdo::spatiotemporal_scale;
14use crate::rdo::DistortionScale;
15use crate::tiling::TileStateMut;
16use crate::util::Pixel;
17use crate::FrameInvariants;
18use crate::FrameState;
19
/// Number of segments available to frame segmentation; segment ids and the
/// per-segment `features`/`data` tables are indexed by `0..MAX_SEGMENTS`.
pub const MAX_SEGMENTS: usize = 8;
21
22#[profiling::function]
23pub fn segmentation_optimize<T: Pixel>(
24  fi: &FrameInvariants<T>, fs: &mut FrameState<T>,
25) {
26  assert!(fi.enable_segmentation);
27  fs.segmentation.enabled = true;
28
29  if fs.segmentation.enabled {
30    fs.segmentation.update_map = true;
31
32    // We don't change the values between frames.
33    fs.segmentation.update_data = fi.primary_ref_frame == PRIMARY_REF_NONE;
34
35    // Avoid going into lossless mode by never bringing qidx below 1.
36    // Because base_q_idx changes more frequently than the segmentation
37    // data, it is still possible for a segment to enter lossless, so
38    // enforcement elsewhere is needed.
39    let offset_lower_limit = 1 - fi.base_q_idx as i16;
40
41    if !fs.segmentation.update_data {
42      let mut min_segment = MAX_SEGMENTS;
43      for i in 0..MAX_SEGMENTS {
44        if fs.segmentation.features[i][SegLvl::SEG_LVL_ALT_Q as usize]
45          && fs.segmentation.data[i][SegLvl::SEG_LVL_ALT_Q as usize]
46            >= offset_lower_limit
47        {
48          min_segment = i;
49          break;
50        }
51      }
52      assert_ne!(min_segment, MAX_SEGMENTS);
53      fs.segmentation.min_segment = min_segment as u8;
54      fs.segmentation.update_threshold(fi.base_q_idx, fi.config.bit_depth);
55      return;
56    }
57
58    segmentation_optimize_inner(fi, fs, offset_lower_limit);
59
60    /* Figure out parameters */
61    fs.segmentation.preskip = false;
62    fs.segmentation.last_active_segid = 0;
63    for i in 0..MAX_SEGMENTS {
64      for j in 0..SegLvl::SEG_LVL_MAX as usize {
65        if fs.segmentation.features[i][j] {
66          fs.segmentation.last_active_segid = i as u8;
67          if j >= SegLvl::SEG_LVL_REF_FRAME as usize {
68            fs.segmentation.preskip = true;
69          }
70        }
71      }
72    }
73  }
74}
75
76// Select target quantizers for each segment by fitting to log(scale).
77fn segmentation_optimize_inner<T: Pixel>(
78  fi: &FrameInvariants<T>, fs: &mut FrameState<T>, offset_lower_limit: i16,
79) {
80  use crate::quantize::{ac_q, select_ac_qi};
81  use crate::util::kmeans;
82  use arrayvec::ArrayVec;
83
84  // Minimize the total distance from a small set of values to all scales.
85  // Find k-means of log(spatiotemporal scale), k in 3..=8
86  let c: ([_; 8], [_; 7], [_; 6], [_; 5], [_; 4], [_; 3]) = {
87    let spatiotemporal_scores =
88      &fi.coded_frame_data.as_ref().unwrap().spatiotemporal_scores;
89    let mut log2_scale_q11 = Vec::with_capacity(spatiotemporal_scores.len());
90    log2_scale_q11.extend(spatiotemporal_scores.iter().map(|&s| s.blog16()));
91    log2_scale_q11.sort_unstable();
92    let l = &log2_scale_q11;
93    (kmeans(l), kmeans(l), kmeans(l), kmeans(l), kmeans(l), kmeans(l))
94  };
95
96  // Find variance in spacing between successive log(scale)
97  let var = |c: &[i16]| {
98    let delta = ArrayVec::<_, MAX_SEGMENTS>::from_iter(
99      c.iter().skip(1).zip(c).map(|(&a, &b)| b as i64 - a as i64),
100    );
101    let mean = delta.iter().sum::<i64>() / delta.len() as i64;
102    delta.iter().map(|&d| (d - mean).pow(2)).sum::<i64>() as u64
103  };
104  let variance =
105    [var(&c.0), var(&c.1), var(&c.2), var(&c.3), var(&c.4), var(&c.5)];
106
107  // Choose the k value with minimal variance in spacing
108  let min_variance = *variance.iter().min().unwrap();
109  let position = variance.iter().rposition(|&v| v == min_variance).unwrap();
110
111  // For the selected centroids, derive a target quantizer:
112  //   scale Q'^2 = Q^2
113  // See `distortion_scale_for` for more information.
114  let compute_delta = |centroids: &[i16]| {
115    use crate::util::{bexp64, blog64};
116    let log2_base_ac_q_q57 =
117      blog64(ac_q(fi.base_q_idx, 0, fi.config.bit_depth).get().into());
118    centroids
119      .iter()
120      .rev()
121      // Rewrite in log form and exponentiate:
122      //   scale Q'^2 = Q^2
123      //           Q' = Q / sqrt(scale)
124      //      log(Q') = log(Q) - 0.5 log(scale)
125      .map(|&log2_scale_q11| {
126        bexp64(log2_base_ac_q_q57 - ((log2_scale_q11 as i64) << (57 - 11 - 1)))
127      })
128      // Find the index of the nearest quantizer to the target,
129      // and take the delta from the base quantizer index.
130      .map(|q| {
131        // Avoid going into lossless mode by never bringing qidx below 1.
132        select_ac_qi(q, fi.config.bit_depth).max(1) as i16
133          - fi.base_q_idx as i16
134      })
135      .collect::<ArrayVec<_, MAX_SEGMENTS>>()
136  };
137
138  // Compute segment deltas for best value of k
139  let seg_delta = match position {
140    0 => compute_delta(&c.0),
141    1 => compute_delta(&c.1),
142    2 => compute_delta(&c.2),
143    3 => compute_delta(&c.3),
144    4 => compute_delta(&c.4),
145    _ => compute_delta(&c.5),
146  };
147
148  // Update the segmentation data
149  fs.segmentation.min_segment = 0;
150  fs.segmentation.max_segment = seg_delta.len() as u8 - 1;
151  for (&delta, (features, data)) in seg_delta
152    .iter()
153    .zip(fs.segmentation.features.iter_mut().zip(&mut fs.segmentation.data))
154  {
155    features[SegLvl::SEG_LVL_ALT_Q as usize] = true;
156    data[SegLvl::SEG_LVL_ALT_Q as usize] = delta.max(offset_lower_limit);
157  }
158
159  fs.segmentation.update_threshold(fi.base_q_idx, fi.config.bit_depth);
160}
161
162#[profiling::function]
163pub fn select_segment<T: Pixel>(
164  fi: &FrameInvariants<T>, ts: &TileStateMut<'_, T>, tile_bo: TileBlockOffset,
165  bsize: BlockSize, skip: bool,
166) -> std::ops::RangeInclusive<u8> {
167  // If skip is true or segmentation is turned off, sidx is not coded.
168  if skip || !fi.enable_segmentation {
169    return 0..=0;
170  }
171
172  use crate::api::SegmentationLevel;
173  if fi.config.speed_settings.segmentation == SegmentationLevel::Full {
174    return ts.segmentation.min_segment..=ts.segmentation.max_segment;
175  }
176
177  let frame_bo = ts.to_frame_block_offset(tile_bo);
178  let scale = spatiotemporal_scale(fi, frame_bo, bsize);
179
180  let sidx = segment_idx_from_distortion(&ts.segmentation.threshold, scale);
181
182  // Avoid going into lossless mode by never bringing qidx below 1.
183  let sidx = sidx.max(ts.segmentation.min_segment);
184
185  if fi.config.speed_settings.segmentation == SegmentationLevel::Complex {
186    return sidx..=ts.segmentation.max_segment.min(sidx.saturating_add(1));
187  }
188
189  sidx..=sidx
190}
191
192fn segment_idx_from_distortion(
193  threshold: &[DistortionScale; MAX_SEGMENTS - 1], s: DistortionScale,
194) -> u8 {
195  threshold.partition_point(|&t| s.0 < t.0) as u8
196}