// rav1e/context/cdf_context.rs

// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
//
// This source code is subject to the terms of the BSD 2 Clause License and
// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
// was not distributed with this source code in the LICENSE file, you can
// obtain it at www.aomedia.org/license/software. If the Alliance for Open
// Media Patent License 1.0 was not distributed with this source code in the
// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
10use super::*;
11use std::marker::PhantomData;
12
/// Upper bound on the length (in `u16` entries) of a CDF array that the
/// rollback log can record; the large log partition stores rows of
/// `CDF_LEN_MAX + 1` entries.
pub const CDF_LEN_MAX: usize = 16;
14
/// Snapshot of the rollback-log lengths, produced by
/// `CDFContextLog::checkpoint` and consumed by `CDFContextLog::rollback`.
#[derive(Clone, Debug)]
pub struct CDFContextCheckpoint {
  /// Number of rows in the small-CDF log when the snapshot was taken.
  small: usize,
  /// Number of rows in the large-CDF log when the snapshot was taken.
  large: usize,
}
20
21#[derive(Clone, Copy)]
22#[repr(C)]
23pub struct CDFContext {
24  pub comp_bwd_ref_cdf: [[[u16; 2]; BWD_REFS - 1]; REF_CONTEXTS],
25  pub comp_mode_cdf: [[u16; 2]; COMP_INTER_CONTEXTS],
26  pub comp_ref_cdf: [[[u16; 2]; FWD_REFS - 1]; REF_CONTEXTS],
27  pub comp_ref_type_cdf: [[u16; 2]; COMP_REF_TYPE_CONTEXTS],
28  pub dc_sign_cdf: [[[u16; 2]; DC_SIGN_CONTEXTS]; PLANE_TYPES],
29  pub drl_cdfs: [[u16; 2]; DRL_MODE_CONTEXTS],
30  pub eob_extra_cdf:
31    [[[[u16; 2]; EOB_COEF_CONTEXTS]; PLANE_TYPES]; TxSize::TX_SIZES],
32  pub filter_intra_cdfs: [[u16; 2]; BlockSize::BLOCK_SIZES_ALL],
33  pub intra_inter_cdfs: [[u16; 2]; INTRA_INTER_CONTEXTS],
34  pub lrf_sgrproj_cdf: [u16; 2],
35  pub lrf_wiener_cdf: [u16; 2],
36  pub newmv_cdf: [[u16; 2]; NEWMV_MODE_CONTEXTS],
37  pub palette_uv_mode_cdfs: [[u16; 2]; PALETTE_UV_MODE_CONTEXTS],
38  pub palette_y_mode_cdfs:
39    [[[u16; 2]; PALETTE_Y_MODE_CONTEXTS]; PALETTE_BSIZE_CTXS],
40  pub refmv_cdf: [[u16; 2]; REFMV_MODE_CONTEXTS],
41  pub single_ref_cdfs: [[[u16; 2]; SINGLE_REFS - 1]; REF_CONTEXTS],
42  pub skip_cdfs: [[u16; 2]; SKIP_CONTEXTS],
43  pub txb_skip_cdf: [[[u16; 2]; TXB_SKIP_CONTEXTS]; TxSize::TX_SIZES],
44  pub txfm_partition_cdf: [[u16; 2]; TXFM_PARTITION_CONTEXTS],
45  pub zeromv_cdf: [[u16; 2]; GLOBALMV_MODE_CONTEXTS],
46  pub tx_size_8x8_cdf: [[u16; MAX_TX_DEPTH]; TX_SIZE_CONTEXTS],
47  pub inter_tx_3_cdf: [[u16; 2]; TX_SIZE_SQR_CONTEXTS],
48
49  pub coeff_base_eob_cdf:
50    [[[[u16; 3]; SIG_COEF_CONTEXTS_EOB]; PLANE_TYPES]; TxSize::TX_SIZES],
51  pub lrf_switchable_cdf: [u16; 3],
52  pub tx_size_cdf: [[[u16; MAX_TX_DEPTH + 1]; TX_SIZE_CONTEXTS]; BIG_TX_CATS],
53
54  pub coeff_base_cdf:
55    [[[[u16; 4]; SIG_COEF_CONTEXTS]; PLANE_TYPES]; TxSize::TX_SIZES],
56  pub coeff_br_cdf:
57    [[[[u16; BR_CDF_SIZE]; LEVEL_CONTEXTS]; PLANE_TYPES]; TxSize::TX_SIZES],
58  pub deblock_delta_cdf: [u16; DELTA_LF_PROBS + 1],
59  pub deblock_delta_multi_cdf: [[u16; DELTA_LF_PROBS + 1]; FRAME_LF_COUNT],
60  pub partition_w8_cdf: [[u16; 4]; PARTITION_TYPES],
61
62  pub eob_flag_cdf16: [[[u16; 5]; 2]; PLANE_TYPES],
63  pub intra_tx_2_cdf: [[[u16; 5]; INTRA_MODES]; TX_SIZE_SQR_CONTEXTS],
64
65  pub eob_flag_cdf32: [[[u16; 6]; 2]; PLANE_TYPES],
66
67  pub angle_delta_cdf: [[u16; 2 * MAX_ANGLE_DELTA + 1]; DIRECTIONAL_MODES],
68  pub eob_flag_cdf64: [[[u16; 7]; 2]; PLANE_TYPES],
69  pub intra_tx_1_cdf: [[[u16; 7]; INTRA_MODES]; TX_SIZE_SQR_CONTEXTS],
70
71  pub cfl_sign_cdf: [u16; CFL_JOINT_SIGNS],
72  pub compound_mode_cdf: [[u16; INTER_COMPOUND_MODES]; INTER_MODE_CONTEXTS],
73  pub eob_flag_cdf128: [[[u16; 8]; 2]; PLANE_TYPES],
74  pub spatial_segmentation_cdfs: [[u16; 8]; 3],
75  pub partition_w128_cdf: [[u16; 8]; PARTITION_TYPES],
76
77  pub eob_flag_cdf256: [[[u16; 9]; 2]; PLANE_TYPES],
78
79  pub eob_flag_cdf512: [[[u16; 10]; 2]; PLANE_TYPES],
80  pub partition_cdf: [[u16; EXT_PARTITION_TYPES]; 3 * PARTITION_TYPES],
81
82  pub eob_flag_cdf1024: [[[u16; 11]; 2]; PLANE_TYPES],
83
84  pub inter_tx_2_cdf: [[u16; 12]; TX_SIZE_SQR_CONTEXTS],
85
86  pub kf_y_cdf: [[[u16; INTRA_MODES]; KF_MODE_CONTEXTS]; KF_MODE_CONTEXTS],
87  pub y_mode_cdf: [[u16; INTRA_MODES]; BLOCK_SIZE_GROUPS],
88  pub uv_mode_cdf: [[u16; INTRA_MODES]; INTRA_MODES],
89
90  pub uv_mode_cfl_cdf: [[u16; UV_INTRA_MODES]; INTRA_MODES],
91
92  pub cfl_alpha_cdf: [[u16; CFL_ALPHABET_SIZE]; CFL_ALPHA_CONTEXTS],
93  pub inter_tx_1_cdf: [[u16; TX_TYPES]; TX_SIZE_SQR_CONTEXTS],
94
95  pub nmv_context: NMVContext,
96}
97
/// Byte offset of a `[u16; CDF_LEN]` array measured from the start of a
/// `CDFContext`, as computed by `CDFContext::offset`.
///
/// The array length is carried in the type so that the rollback log can turn
/// the offset back into a `&mut [u16; CDF_LEN]`.
pub struct CDFOffset<const CDF_LEN: usize> {
  /// Distance in bytes from the start of the owning `CDFContext`.
  offset: usize,
  /// Zero-sized marker tying the offset to the array type it refers to.
  phantom: PhantomData<[u16; CDF_LEN]>,
}
102
103impl CDFContext {
104  pub fn new(quantizer: u8) -> CDFContext {
105    let qctx = match quantizer {
106      0..=20 => 0,
107      21..=60 => 1,
108      61..=120 => 2,
109      _ => 3,
110    };
111    CDFContext {
112      partition_w8_cdf: default_partition_w8_cdf,
113      partition_w128_cdf: default_partition_w128_cdf,
114      partition_cdf: default_partition_cdf,
115      kf_y_cdf: default_kf_y_mode_cdf,
116      y_mode_cdf: default_if_y_mode_cdf,
117      uv_mode_cdf: default_uv_mode_cdf,
118      uv_mode_cfl_cdf: default_uv_mode_cfl_cdf,
119      cfl_sign_cdf: default_cfl_sign_cdf,
120      cfl_alpha_cdf: default_cfl_alpha_cdf,
121      newmv_cdf: default_newmv_cdf,
122      zeromv_cdf: default_zeromv_cdf,
123      refmv_cdf: default_refmv_cdf,
124      intra_tx_2_cdf: default_intra_tx_2_cdf,
125      intra_tx_1_cdf: default_intra_tx_1_cdf,
126      inter_tx_3_cdf: default_inter_tx_3_cdf,
127      inter_tx_2_cdf: default_inter_tx_2_cdf,
128      inter_tx_1_cdf: default_inter_tx_1_cdf,
129      tx_size_8x8_cdf: default_tx_size_8x8_cdf,
130      tx_size_cdf: default_tx_size_cdf,
131      txfm_partition_cdf: default_txfm_partition_cdf,
132      skip_cdfs: default_skip_cdfs,
133      intra_inter_cdfs: default_intra_inter_cdf,
134      angle_delta_cdf: default_angle_delta_cdf,
135      filter_intra_cdfs: default_filter_intra_cdfs,
136      palette_y_mode_cdfs: default_palette_y_mode_cdfs,
137      palette_uv_mode_cdfs: default_palette_uv_mode_cdfs,
138      comp_mode_cdf: default_comp_mode_cdf,
139      comp_ref_type_cdf: default_comp_ref_type_cdf,
140      comp_ref_cdf: default_comp_ref_cdf,
141      comp_bwd_ref_cdf: default_comp_bwdref_cdf,
142      single_ref_cdfs: default_single_ref_cdf,
143      drl_cdfs: default_drl_cdf,
144      compound_mode_cdf: default_compound_mode_cdf,
145      nmv_context: default_nmv_context,
146      deblock_delta_multi_cdf: default_delta_lf_multi_cdf,
147      deblock_delta_cdf: default_delta_lf_cdf,
148      spatial_segmentation_cdfs: default_spatial_pred_seg_tree_cdf,
149      lrf_switchable_cdf: default_switchable_restore_cdf,
150      lrf_sgrproj_cdf: default_sgrproj_restore_cdf,
151      lrf_wiener_cdf: default_wiener_restore_cdf,
152
153      // lv_map
154      txb_skip_cdf: av1_default_txb_skip_cdfs[qctx],
155      dc_sign_cdf: av1_default_dc_sign_cdfs[qctx],
156      eob_extra_cdf: av1_default_eob_extra_cdfs[qctx],
157
158      eob_flag_cdf16: av1_default_eob_multi16_cdfs[qctx],
159      eob_flag_cdf32: av1_default_eob_multi32_cdfs[qctx],
160      eob_flag_cdf64: av1_default_eob_multi64_cdfs[qctx],
161      eob_flag_cdf128: av1_default_eob_multi128_cdfs[qctx],
162      eob_flag_cdf256: av1_default_eob_multi256_cdfs[qctx],
163      eob_flag_cdf512: av1_default_eob_multi512_cdfs[qctx],
164      eob_flag_cdf1024: av1_default_eob_multi1024_cdfs[qctx],
165
166      coeff_base_eob_cdf: av1_default_coeff_base_eob_multi_cdfs[qctx],
167      coeff_base_cdf: av1_default_coeff_base_multi_cdfs[qctx],
168      coeff_br_cdf: av1_default_coeff_lps_multi_cdfs[qctx],
169    }
170  }
171
172  pub fn reset_counts(&mut self) {
173    macro_rules! reset_1d {
174      ($field:expr) => {
175        let r = $field.last_mut().unwrap();
176        *r = 0;
177      };
178    }
179    macro_rules! reset_2d {
180      ($field:expr) => {
181        for x in $field.iter_mut() {
182          reset_1d!(x);
183        }
184      };
185    }
186    macro_rules! reset_3d {
187      ($field:expr) => {
188        for x in $field.iter_mut() {
189          reset_2d!(x);
190        }
191      };
192    }
193    macro_rules! reset_4d {
194      ($field:expr) => {
195        for x in $field.iter_mut() {
196          reset_3d!(x);
197        }
198      };
199    }
200
201    reset_2d!(self.partition_w8_cdf);
202    reset_2d!(self.partition_w128_cdf);
203    reset_2d!(self.partition_cdf);
204
205    reset_3d!(self.kf_y_cdf);
206    reset_2d!(self.y_mode_cdf);
207
208    reset_2d!(self.uv_mode_cdf);
209    reset_2d!(self.uv_mode_cfl_cdf);
210    reset_1d!(self.cfl_sign_cdf);
211    reset_2d!(self.cfl_alpha_cdf);
212    reset_2d!(self.newmv_cdf);
213    reset_2d!(self.zeromv_cdf);
214    reset_2d!(self.refmv_cdf);
215
216    reset_3d!(self.intra_tx_2_cdf);
217    reset_3d!(self.intra_tx_1_cdf);
218
219    reset_2d!(self.inter_tx_3_cdf);
220    reset_2d!(self.inter_tx_2_cdf);
221    reset_2d!(self.inter_tx_1_cdf);
222
223    reset_2d!(self.tx_size_8x8_cdf);
224    reset_3d!(self.tx_size_cdf);
225
226    reset_2d!(self.txfm_partition_cdf);
227
228    reset_2d!(self.skip_cdfs);
229    reset_2d!(self.intra_inter_cdfs);
230    reset_2d!(self.angle_delta_cdf);
231    reset_2d!(self.filter_intra_cdfs);
232    reset_3d!(self.palette_y_mode_cdfs);
233    reset_2d!(self.palette_uv_mode_cdfs);
234    reset_2d!(self.comp_mode_cdf);
235    reset_2d!(self.comp_ref_type_cdf);
236    reset_3d!(self.comp_ref_cdf);
237    reset_3d!(self.comp_bwd_ref_cdf);
238    reset_3d!(self.single_ref_cdfs);
239    reset_2d!(self.drl_cdfs);
240    reset_2d!(self.compound_mode_cdf);
241    reset_2d!(self.deblock_delta_multi_cdf);
242    reset_1d!(self.deblock_delta_cdf);
243    reset_2d!(self.spatial_segmentation_cdfs);
244    reset_1d!(self.lrf_switchable_cdf);
245    reset_1d!(self.lrf_sgrproj_cdf);
246    reset_1d!(self.lrf_wiener_cdf);
247
248    reset_1d!(self.nmv_context.joints_cdf);
249    for i in 0..2 {
250      reset_1d!(self.nmv_context.comps[i].classes_cdf);
251      reset_2d!(self.nmv_context.comps[i].class0_fp_cdf);
252      reset_1d!(self.nmv_context.comps[i].fp_cdf);
253      reset_1d!(self.nmv_context.comps[i].sign_cdf);
254      reset_1d!(self.nmv_context.comps[i].class0_hp_cdf);
255      reset_1d!(self.nmv_context.comps[i].hp_cdf);
256      reset_1d!(self.nmv_context.comps[i].class0_cdf);
257      reset_2d!(self.nmv_context.comps[i].bits_cdf);
258    }
259
260    // lv_map
261    reset_3d!(self.txb_skip_cdf);
262    reset_3d!(self.dc_sign_cdf);
263    reset_4d!(self.eob_extra_cdf);
264
265    reset_3d!(self.eob_flag_cdf16);
266    reset_3d!(self.eob_flag_cdf32);
267    reset_3d!(self.eob_flag_cdf64);
268    reset_3d!(self.eob_flag_cdf128);
269    reset_3d!(self.eob_flag_cdf256);
270    reset_3d!(self.eob_flag_cdf512);
271    reset_3d!(self.eob_flag_cdf1024);
272
273    reset_4d!(self.coeff_base_eob_cdf);
274    reset_4d!(self.coeff_base_cdf);
275    reset_4d!(self.coeff_br_cdf);
276  }
277
278  /// # Panics
279  ///
280  /// - If any of the CDF arrays are uninitialized.
281  ///   This should never happen and indicates a development error.
282  pub fn build_map(&self) -> Vec<(&'static str, usize, usize)> {
283    use std::mem::size_of_val;
284
285    let partition_w8_cdf_start =
286      self.partition_w8_cdf.first().unwrap().as_ptr() as usize;
287    let partition_w8_cdf_end =
288      partition_w8_cdf_start + size_of_val(&self.partition_w8_cdf);
289    let partition_w128_cdf_start =
290      self.partition_w128_cdf.first().unwrap().as_ptr() as usize;
291    let partition_w128_cdf_end =
292      partition_w128_cdf_start + size_of_val(&self.partition_w128_cdf);
293    let partition_cdf_start =
294      self.partition_cdf.first().unwrap().as_ptr() as usize;
295    let partition_cdf_end =
296      partition_cdf_start + size_of_val(&self.partition_cdf);
297    let kf_y_cdf_start = self.kf_y_cdf.first().unwrap().as_ptr() as usize;
298    let kf_y_cdf_end = kf_y_cdf_start + size_of_val(&self.kf_y_cdf);
299    let y_mode_cdf_start = self.y_mode_cdf.first().unwrap().as_ptr() as usize;
300    let y_mode_cdf_end = y_mode_cdf_start + size_of_val(&self.y_mode_cdf);
301    let uv_mode_cdf_start =
302      self.uv_mode_cdf.first().unwrap().as_ptr() as usize;
303    let uv_mode_cdf_end = uv_mode_cdf_start + size_of_val(&self.uv_mode_cdf);
304    let uv_mode_cfl_cdf_start =
305      self.uv_mode_cfl_cdf.first().unwrap().as_ptr() as usize;
306    let uv_mode_cfl_cdf_end =
307      uv_mode_cfl_cdf_start + size_of_val(&self.uv_mode_cfl_cdf);
308    let cfl_sign_cdf_start = self.cfl_sign_cdf.as_ptr() as usize;
309    let cfl_sign_cdf_end =
310      cfl_sign_cdf_start + size_of_val(&self.cfl_sign_cdf);
311    let cfl_alpha_cdf_start =
312      self.cfl_alpha_cdf.first().unwrap().as_ptr() as usize;
313    let cfl_alpha_cdf_end =
314      cfl_alpha_cdf_start + size_of_val(&self.cfl_alpha_cdf);
315    let newmv_cdf_start = self.newmv_cdf.first().unwrap().as_ptr() as usize;
316    let newmv_cdf_end = newmv_cdf_start + size_of_val(&self.newmv_cdf);
317    let zeromv_cdf_start = self.zeromv_cdf.first().unwrap().as_ptr() as usize;
318    let zeromv_cdf_end = zeromv_cdf_start + size_of_val(&self.zeromv_cdf);
319    let refmv_cdf_start = self.refmv_cdf.first().unwrap().as_ptr() as usize;
320    let refmv_cdf_end = refmv_cdf_start + size_of_val(&self.refmv_cdf);
321    let intra_tx_2_cdf_start =
322      self.intra_tx_2_cdf.first().unwrap().as_ptr() as usize;
323    let intra_tx_2_cdf_end =
324      intra_tx_2_cdf_start + size_of_val(&self.intra_tx_2_cdf);
325    let intra_tx_1_cdf_start =
326      self.intra_tx_1_cdf.first().unwrap().as_ptr() as usize;
327    let intra_tx_1_cdf_end =
328      intra_tx_1_cdf_start + size_of_val(&self.intra_tx_1_cdf);
329    let inter_tx_3_cdf_start =
330      self.inter_tx_3_cdf.first().unwrap().as_ptr() as usize;
331    let inter_tx_3_cdf_end =
332      inter_tx_3_cdf_start + size_of_val(&self.inter_tx_3_cdf);
333    let inter_tx_2_cdf_start =
334      self.inter_tx_2_cdf.first().unwrap().as_ptr() as usize;
335    let inter_tx_2_cdf_end =
336      inter_tx_2_cdf_start + size_of_val(&self.inter_tx_2_cdf);
337    let inter_tx_1_cdf_start =
338      self.inter_tx_1_cdf.first().unwrap().as_ptr() as usize;
339    let inter_tx_1_cdf_end =
340      inter_tx_1_cdf_start + size_of_val(&self.inter_tx_1_cdf);
341    let tx_size_8x8_cdf_start =
342      self.tx_size_8x8_cdf.first().unwrap().as_ptr() as usize;
343    let tx_size_8x8_cdf_end =
344      tx_size_8x8_cdf_start + size_of_val(&self.tx_size_8x8_cdf);
345    let tx_size_cdf_start =
346      self.tx_size_cdf.first().unwrap().as_ptr() as usize;
347    let tx_size_cdf_end = tx_size_cdf_start + size_of_val(&self.tx_size_cdf);
348    let txfm_partition_cdf_start =
349      self.txfm_partition_cdf.first().unwrap().as_ptr() as usize;
350    let txfm_partition_cdf_end =
351      txfm_partition_cdf_start + size_of_val(&self.txfm_partition_cdf);
352    let skip_cdfs_start = self.skip_cdfs.first().unwrap().as_ptr() as usize;
353    let skip_cdfs_end = skip_cdfs_start + size_of_val(&self.skip_cdfs);
354    let intra_inter_cdfs_start =
355      self.intra_inter_cdfs.first().unwrap().as_ptr() as usize;
356    let intra_inter_cdfs_end =
357      intra_inter_cdfs_start + size_of_val(&self.intra_inter_cdfs);
358    let angle_delta_cdf_start =
359      self.angle_delta_cdf.first().unwrap().as_ptr() as usize;
360    let angle_delta_cdf_end =
361      angle_delta_cdf_start + size_of_val(&self.angle_delta_cdf);
362    let filter_intra_cdfs_start =
363      self.filter_intra_cdfs.first().unwrap().as_ptr() as usize;
364    let filter_intra_cdfs_end =
365      filter_intra_cdfs_start + size_of_val(&self.filter_intra_cdfs);
366    let palette_y_mode_cdfs_start =
367      self.palette_y_mode_cdfs.first().unwrap().as_ptr() as usize;
368    let palette_y_mode_cdfs_end =
369      palette_y_mode_cdfs_start + size_of_val(&self.palette_y_mode_cdfs);
370    let palette_uv_mode_cdfs_start =
371      self.palette_uv_mode_cdfs.first().unwrap().as_ptr() as usize;
372    let palette_uv_mode_cdfs_end =
373      palette_uv_mode_cdfs_start + size_of_val(&self.palette_uv_mode_cdfs);
374    let comp_mode_cdf_start =
375      self.comp_mode_cdf.first().unwrap().as_ptr() as usize;
376    let comp_mode_cdf_end =
377      comp_mode_cdf_start + size_of_val(&self.comp_mode_cdf);
378    let comp_ref_type_cdf_start =
379      self.comp_ref_type_cdf.first().unwrap().as_ptr() as usize;
380    let comp_ref_type_cdf_end =
381      comp_ref_type_cdf_start + size_of_val(&self.comp_ref_type_cdf);
382    let comp_ref_cdf_start =
383      self.comp_ref_cdf.first().unwrap().as_ptr() as usize;
384    let comp_ref_cdf_end =
385      comp_ref_cdf_start + size_of_val(&self.comp_ref_cdf);
386    let comp_bwd_ref_cdf_start =
387      self.comp_bwd_ref_cdf.first().unwrap().as_ptr() as usize;
388    let comp_bwd_ref_cdf_end =
389      comp_bwd_ref_cdf_start + size_of_val(&self.comp_bwd_ref_cdf);
390    let single_ref_cdfs_start =
391      self.single_ref_cdfs.first().unwrap().as_ptr() as usize;
392    let single_ref_cdfs_end =
393      single_ref_cdfs_start + size_of_val(&self.single_ref_cdfs);
394    let drl_cdfs_start = self.drl_cdfs.first().unwrap().as_ptr() as usize;
395    let drl_cdfs_end = drl_cdfs_start + size_of_val(&self.drl_cdfs);
396    let compound_mode_cdf_start =
397      self.compound_mode_cdf.first().unwrap().as_ptr() as usize;
398    let compound_mode_cdf_end =
399      compound_mode_cdf_start + size_of_val(&self.compound_mode_cdf);
400    let nmv_context_start = &self.nmv_context as *const NMVContext as usize;
401    let nmv_context_end = nmv_context_start + size_of_val(&self.nmv_context);
402    let deblock_delta_multi_cdf_start =
403      self.deblock_delta_multi_cdf.first().unwrap().as_ptr() as usize;
404    let deblock_delta_multi_cdf_end = deblock_delta_multi_cdf_start
405      + size_of_val(&self.deblock_delta_multi_cdf);
406    let deblock_delta_cdf_start = self.deblock_delta_cdf.as_ptr() as usize;
407    let deblock_delta_cdf_end =
408      deblock_delta_cdf_start + size_of_val(&self.deblock_delta_cdf);
409    let spatial_segmentation_cdfs_start =
410      self.spatial_segmentation_cdfs.first().unwrap().as_ptr() as usize;
411    let spatial_segmentation_cdfs_end = spatial_segmentation_cdfs_start
412      + size_of_val(&self.spatial_segmentation_cdfs);
413    let lrf_switchable_cdf_start = self.lrf_switchable_cdf.as_ptr() as usize;
414    let lrf_switchable_cdf_end =
415      lrf_switchable_cdf_start + size_of_val(&self.lrf_switchable_cdf);
416    let lrf_sgrproj_cdf_start = self.lrf_sgrproj_cdf.as_ptr() as usize;
417    let lrf_sgrproj_cdf_end =
418      lrf_sgrproj_cdf_start + size_of_val(&self.lrf_sgrproj_cdf);
419    let lrf_wiener_cdf_start = self.lrf_wiener_cdf.as_ptr() as usize;
420    let lrf_wiener_cdf_end =
421      lrf_wiener_cdf_start + size_of_val(&self.lrf_wiener_cdf);
422
423    let txb_skip_cdf_start =
424      self.txb_skip_cdf.first().unwrap().as_ptr() as usize;
425    let txb_skip_cdf_end =
426      txb_skip_cdf_start + size_of_val(&self.txb_skip_cdf);
427    let dc_sign_cdf_start =
428      self.dc_sign_cdf.first().unwrap().as_ptr() as usize;
429    let dc_sign_cdf_end = dc_sign_cdf_start + size_of_val(&self.dc_sign_cdf);
430    let eob_extra_cdf_start =
431      self.eob_extra_cdf.first().unwrap().as_ptr() as usize;
432    let eob_extra_cdf_end =
433      eob_extra_cdf_start + size_of_val(&self.eob_extra_cdf);
434    let eob_flag_cdf16_start =
435      self.eob_flag_cdf16.first().unwrap().as_ptr() as usize;
436    let eob_flag_cdf16_end =
437      eob_flag_cdf16_start + size_of_val(&self.eob_flag_cdf16);
438    let eob_flag_cdf32_start =
439      self.eob_flag_cdf32.first().unwrap().as_ptr() as usize;
440    let eob_flag_cdf32_end =
441      eob_flag_cdf32_start + size_of_val(&self.eob_flag_cdf32);
442    let eob_flag_cdf64_start =
443      self.eob_flag_cdf64.first().unwrap().as_ptr() as usize;
444    let eob_flag_cdf64_end =
445      eob_flag_cdf64_start + size_of_val(&self.eob_flag_cdf64);
446    let eob_flag_cdf128_start =
447      self.eob_flag_cdf128.first().unwrap().as_ptr() as usize;
448    let eob_flag_cdf128_end =
449      eob_flag_cdf128_start + size_of_val(&self.eob_flag_cdf128);
450    let eob_flag_cdf256_start =
451      self.eob_flag_cdf256.first().unwrap().as_ptr() as usize;
452    let eob_flag_cdf256_end =
453      eob_flag_cdf256_start + size_of_val(&self.eob_flag_cdf256);
454    let eob_flag_cdf512_start =
455      self.eob_flag_cdf512.first().unwrap().as_ptr() as usize;
456    let eob_flag_cdf512_end =
457      eob_flag_cdf512_start + size_of_val(&self.eob_flag_cdf512);
458    let eob_flag_cdf1024_start =
459      self.eob_flag_cdf1024.first().unwrap().as_ptr() as usize;
460    let eob_flag_cdf1024_end =
461      eob_flag_cdf1024_start + size_of_val(&self.eob_flag_cdf1024);
462    let coeff_base_eob_cdf_start =
463      self.coeff_base_eob_cdf.first().unwrap().as_ptr() as usize;
464    let coeff_base_eob_cdf_end =
465      coeff_base_eob_cdf_start + size_of_val(&self.coeff_base_eob_cdf);
466    let coeff_base_cdf_start =
467      self.coeff_base_cdf.first().unwrap().as_ptr() as usize;
468    let coeff_base_cdf_end =
469      coeff_base_cdf_start + size_of_val(&self.coeff_base_cdf);
470    let coeff_br_cdf_start =
471      self.coeff_br_cdf.first().unwrap().as_ptr() as usize;
472    let coeff_br_cdf_end =
473      coeff_br_cdf_start + size_of_val(&self.coeff_br_cdf);
474
475    vec![
476      ("partition_w8_cdf", partition_w8_cdf_start, partition_w8_cdf_end),
477      ("partition_w128_cdf", partition_w128_cdf_start, partition_w128_cdf_end),
478      ("partition_cdf", partition_cdf_start, partition_cdf_end),
479      ("kf_y_cdf", kf_y_cdf_start, kf_y_cdf_end),
480      ("y_mode_cdf", y_mode_cdf_start, y_mode_cdf_end),
481      ("uv_mode_cdf", uv_mode_cdf_start, uv_mode_cdf_end),
482      ("uv_mode_cfl_cdf", uv_mode_cfl_cdf_start, uv_mode_cfl_cdf_end),
483      ("cfl_sign_cdf", cfl_sign_cdf_start, cfl_sign_cdf_end),
484      ("cfl_alpha_cdf", cfl_alpha_cdf_start, cfl_alpha_cdf_end),
485      ("newmv_cdf", newmv_cdf_start, newmv_cdf_end),
486      ("zeromv_cdf", zeromv_cdf_start, zeromv_cdf_end),
487      ("refmv_cdf", refmv_cdf_start, refmv_cdf_end),
488      ("intra_tx_2_cdf", intra_tx_2_cdf_start, intra_tx_2_cdf_end),
489      ("intra_tx_1_cdf", intra_tx_1_cdf_start, intra_tx_1_cdf_end),
490      ("inter_tx_3_cdf", inter_tx_3_cdf_start, inter_tx_3_cdf_end),
491      ("inter_tx_2_cdf", inter_tx_2_cdf_start, inter_tx_2_cdf_end),
492      ("inter_tx_1_cdf", inter_tx_1_cdf_start, inter_tx_1_cdf_end),
493      ("tx_size_8x8_cdf", tx_size_8x8_cdf_start, tx_size_8x8_cdf_end),
494      ("tx_size_cdf", tx_size_cdf_start, tx_size_cdf_end),
495      ("txfm_partition_cdf", txfm_partition_cdf_start, txfm_partition_cdf_end),
496      ("skip_cdfs", skip_cdfs_start, skip_cdfs_end),
497      ("intra_inter_cdfs", intra_inter_cdfs_start, intra_inter_cdfs_end),
498      ("angle_delta_cdf", angle_delta_cdf_start, angle_delta_cdf_end),
499      ("filter_intra_cdfs", filter_intra_cdfs_start, filter_intra_cdfs_end),
500      (
501        "palette_y_mode_cdfs",
502        palette_y_mode_cdfs_start,
503        palette_y_mode_cdfs_end,
504      ),
505      (
506        "palette_uv_mode_cdfs",
507        palette_uv_mode_cdfs_start,
508        palette_uv_mode_cdfs_end,
509      ),
510      ("comp_mode_cdf", comp_mode_cdf_start, comp_mode_cdf_end),
511      ("comp_ref_type_cdf", comp_ref_type_cdf_start, comp_ref_type_cdf_end),
512      ("comp_ref_cdf", comp_ref_cdf_start, comp_ref_cdf_end),
513      ("comp_bwd_ref_cdf", comp_bwd_ref_cdf_start, comp_bwd_ref_cdf_end),
514      ("single_ref_cdfs", single_ref_cdfs_start, single_ref_cdfs_end),
515      ("drl_cdfs", drl_cdfs_start, drl_cdfs_end),
516      ("compound_mode_cdf", compound_mode_cdf_start, compound_mode_cdf_end),
517      ("nmv_context", nmv_context_start, nmv_context_end),
518      (
519        "deblock_delta_multi_cdf",
520        deblock_delta_multi_cdf_start,
521        deblock_delta_multi_cdf_end,
522      ),
523      ("deblock_delta_cdf", deblock_delta_cdf_start, deblock_delta_cdf_end),
524      (
525        "spatial_segmentation_cdfs",
526        spatial_segmentation_cdfs_start,
527        spatial_segmentation_cdfs_end,
528      ),
529      ("lrf_switchable_cdf", lrf_switchable_cdf_start, lrf_switchable_cdf_end),
530      ("lrf_sgrproj_cdf", lrf_sgrproj_cdf_start, lrf_sgrproj_cdf_end),
531      ("lrf_wiener_cdf", lrf_wiener_cdf_start, lrf_wiener_cdf_end),
532      ("txb_skip_cdf", txb_skip_cdf_start, txb_skip_cdf_end),
533      ("dc_sign_cdf", dc_sign_cdf_start, dc_sign_cdf_end),
534      ("eob_extra_cdf", eob_extra_cdf_start, eob_extra_cdf_end),
535      ("eob_flag_cdf16", eob_flag_cdf16_start, eob_flag_cdf16_end),
536      ("eob_flag_cdf32", eob_flag_cdf32_start, eob_flag_cdf32_end),
537      ("eob_flag_cdf64", eob_flag_cdf64_start, eob_flag_cdf64_end),
538      ("eob_flag_cdf128", eob_flag_cdf128_start, eob_flag_cdf128_end),
539      ("eob_flag_cdf256", eob_flag_cdf256_start, eob_flag_cdf256_end),
540      ("eob_flag_cdf512", eob_flag_cdf512_start, eob_flag_cdf512_end),
541      ("eob_flag_cdf1024", eob_flag_cdf1024_start, eob_flag_cdf1024_end),
542      ("coeff_base_eob_cdf", coeff_base_eob_cdf_start, coeff_base_eob_cdf_end),
543      ("coeff_base_cdf", coeff_base_cdf_start, coeff_base_cdf_end),
544      ("coeff_br_cdf", coeff_br_cdf_start, coeff_br_cdf_end),
545    ]
546  }
547
548  pub fn offset<const CDF_LEN: usize>(
549    &self, cdf: *const [u16; CDF_LEN],
550  ) -> CDFOffset<CDF_LEN> {
551    CDFOffset {
552      offset: cdf as usize - self as *const _ as usize,
553      phantom: PhantomData,
554    }
555  }
556}
557
558impl fmt::Debug for CDFContext {
559  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
560    write!(f, "CDFContext contains too many numbers to print :-(")
561  }
562}
563
/// Codes a symbol against a CDF that lives inside `$self.fc`.
///
/// * `($self, $w, $s, $cdf)` - converts `$cdf` into an offset with
///   `$self.fc.offset`, forwards `$s` and that offset to
///   `$w.symbol_with_update` together with `$self.fc_log` and `$self.fc`,
///   then expands the two-argument form below.
/// * `($self, $cdf)` - with the `desync_finder` feature enabled, looks up the
///   address of `$cdf` in `$self.fc_map` when a map is present; without the
///   feature it expands to nothing.
macro_rules! symbol_with_update {
  ($self:ident, $w:ident, $s:expr, $cdf:expr) => {
    let cdf = $self.fc.offset($cdf);
    $w.symbol_with_update($s, cdf, &mut $self.fc_log, &mut $self.fc);
    symbol_with_update!($self, $cdf);
  };
  ($self:ident, $cdf:expr) => {
    #[cfg(feature = "desync_finder")]
    {
      let cdf: &[_] = $cdf;
      if let Some(map) = $self.fc_map.as_ref() {
        map.lookup(cdf.as_ptr() as usize);
      }
    }
  };
}
580
581#[derive(Clone)]
582pub struct ContextWriterCheckpoint {
583  pub fc: CDFContextCheckpoint,
584  pub bc: BlockContextCheckpoint,
585}
586
587struct CDFContextLogPartition<const CDF_LEN_MAX_PLUS_1: usize> {
588  pub data: Vec<[u16; CDF_LEN_MAX_PLUS_1]>,
589}
590
591impl<const CDF_LEN_MAX_PLUS_1: usize>
592  CDFContextLogPartition<CDF_LEN_MAX_PLUS_1>
593{
594  fn new(capacity: usize) -> Self {
595    Self { data: Vec::with_capacity(capacity) }
596  }
597  #[inline(always)]
598  fn push<const CDF_LEN: usize>(
599    &mut self, fc: &mut CDFContext, cdf: CDFOffset<CDF_LEN>,
600  ) -> &mut [u16; CDF_LEN] {
601    debug_assert!(CDF_LEN < CDF_LEN_MAX_PLUS_1);
602    debug_assert!(cdf.offset <= u16::MAX.into());
603    // SAFETY: Maintain an invariant of non-zero spare capacity, so that
604    // branching may be deferred until writes are issued. Benchmarks indicate
605    // this is faster than first testing capacity and possibly reallocating.
606    unsafe {
607      let len = self.data.len();
608      let new_len = len + 1;
609      let capacity = self.data.capacity();
610      debug_assert!(new_len <= capacity);
611      let dst = self.data.as_mut_ptr().add(len) as *mut u16;
612      let base = fc as *mut _ as *mut u8;
613      let src = base.add(cdf.offset) as *const u16;
614      dst.copy_from_nonoverlapping(src, CDF_LEN_MAX_PLUS_1 - 1);
615      *dst.add(CDF_LEN_MAX_PLUS_1 - 1) = cdf.offset as u16;
616      self.data.set_len(new_len);
617      if CDF_LEN_MAX_PLUS_1 > capacity.wrapping_sub(new_len) {
618        self.data.reserve(CDF_LEN_MAX_PLUS_1);
619      }
620      let cdf = base.add(cdf.offset) as *mut [u16; CDF_LEN];
621      &mut *cdf
622    }
623  }
624  #[inline(always)]
625  fn rollback(&mut self, fc: &mut CDFContext, checkpoint: usize) {
626    let base = fc as *mut _ as *mut u8;
627    let mut len = self.data.len();
628    // SAFETY: We use unchecked pointers here for performance.
629    // Since we know the length, we can ensure not to go OOB.
630    unsafe {
631      let mut src = self.data.as_mut_ptr().add(len);
632      while len > checkpoint {
633        len -= 1;
634        src = src.sub(1);
635        let src = src as *mut u16;
636        let offset = *src.add(CDF_LEN_MAX_PLUS_1 - 1) as usize;
637        let dst = base.add(offset) as *mut u16;
638        dst.copy_from_nonoverlapping(src, CDF_LEN_MAX_PLUS_1 - 1);
639      }
640      self.data.set_len(len);
641    }
642  }
643}
644
/// CDFs of at most this many entries are recorded in the small log
/// partition; longer ones go to the large partition.
const CDF_LEN_SMALL: usize = 4;
646
647pub struct CDFContextLog {
648  small: CDFContextLogPartition<{ CDF_LEN_SMALL + 1 }>,
649  large: CDFContextLogPartition<{ CDF_LEN_MAX + 1 }>,
650}
651
652impl Default for CDFContextLog {
653  fn default() -> Self {
654    Self {
655      small: CDFContextLogPartition::new(1 << 16),
656      large: CDFContextLogPartition::new(1 << 9),
657    }
658  }
659}
660
661impl CDFContextLog {
662  fn checkpoint(&self) -> CDFContextCheckpoint {
663    CDFContextCheckpoint {
664      small: self.small.data.len(),
665      large: self.large.data.len(),
666    }
667  }
668  #[inline(always)]
669  pub fn push<const CDF_LEN: usize>(
670    &mut self, fc: &mut CDFContext, cdf: CDFOffset<CDF_LEN>,
671  ) -> &mut [u16; CDF_LEN] {
672    if CDF_LEN <= CDF_LEN_SMALL {
673      self.small.push(fc, cdf)
674    } else {
675      self.large.push(fc, cdf)
676    }
677  }
678  #[inline(always)]
679  pub fn rollback(
680    &mut self, fc: &mut CDFContext, checkpoint: &CDFContextCheckpoint,
681  ) {
682    self.small.rollback(fc, checkpoint.small);
683    self.large.rollback(fc, checkpoint.large);
684  }
685  pub fn clear(&mut self) {
686    self.small.data.clear();
687    self.large.data.clear();
688  }
689}
690
691pub struct ContextWriter<'a> {
692  pub bc: BlockContext<'a>,
693  pub fc: &'a mut CDFContext,
694  pub fc_log: CDFContextLog,
695  #[cfg(feature = "desync_finder")]
696  pub fc_map: Option<FieldMap>, // For debugging purposes
697}
698
699impl<'a> ContextWriter<'a> {
700  #[allow(clippy::let_and_return)]
701  pub fn new(fc: &'a mut CDFContext, bc: BlockContext<'a>) -> Self {
702    let fc_log = CDFContextLog::default();
703    #[allow(unused_mut)]
704    let mut cw = ContextWriter {
705      bc,
706      fc,
707      fc_log,
708      #[cfg(feature = "desync_finder")]
709      fc_map: Default::default(),
710    };
711    #[cfg(feature = "desync_finder")]
712    {
713      if std::env::var_os("RAV1E_DEBUG").is_some() {
714        cw.fc_map = Some(FieldMap { map: cw.fc.build_map() });
715      }
716    }
717
718    cw
719  }
720
721  pub const fn cdf_element_prob(cdf: &[u16], element: usize) -> u16 {
722    (if element > 0 { cdf[element - 1] } else { 32768 })
723      - (if element + 1 < cdf.len() { cdf[element] } else { 0 })
724  }
725
726  pub fn checkpoint(
727    &self, tile_bo: &TileBlockOffset, chroma_sampling: ChromaSampling,
728  ) -> ContextWriterCheckpoint {
729    ContextWriterCheckpoint {
730      fc: self.fc_log.checkpoint(),
731      bc: self.bc.checkpoint(tile_bo, chroma_sampling),
732    }
733  }
734
735  pub fn rollback(&mut self, checkpoint: &ContextWriterCheckpoint) {
736    self.fc_log.rollback(self.fc, &checkpoint.fc);
737    self.bc.rollback(&checkpoint.bc);
738    #[cfg(feature = "desync_finder")]
739    {
740      if self.fc_map.is_some() {
741        self.fc_map = Some(FieldMap { map: self.fc.build_map() });
742      }
743    }
744  }
745}