rav1e/
ec.rs

1// Copyright (c) 2001-2016, Alliance for Open Media. All rights reserved
2// Copyright (c) 2017-2022, The rav1e contributors. All rights reserved
3//
4// This source code is subject to the terms of the BSD 2 Clause License and
5// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6// was not distributed with this source code in the LICENSE file, you can
7// obtain it at www.aomedia.org/license/software. If the Alliance for Open
8// Media Patent License 1.0 was not distributed with this source code in the
9// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10
11#![allow(non_camel_case_types)]
12
13cfg_if::cfg_if! {
14  if #[cfg(nasm_x86_64)] {
15    pub use crate::asm::x86::ec::*;
16  } else {
17    pub use self::rust::*;
18  }
19}
20
21use crate::context::{CDFContext, CDFContextLog, CDFOffset};
22use bitstream_io::{BigEndian, BitWrite, BitWriter};
23use std::io;
24
/// Precision, in fractional bits, of the values returned by `tell_frac()`,
/// `symbol_bits()` and the `count_*` methods: they are scaled by
/// `1 << OD_BITRES`.
pub const OD_BITRES: u8 = 3;
/// Number of low CDF bits discarded before a probability is combined with
/// the current range.
const EC_PROB_SHIFT: u32 = 6;
/// Minimum slice of the range every symbol is guaranteed to receive.
const EC_MIN_PROB: u32 = 4;
/// Integer type holding the encoder's `low` value.
type ec_window = u32;
29
30/// Public trait interface to a bitstream `Writer`: a `Counter` can be
31/// used to count bits for cost analysis without actually storing
32/// anything (using a new `WriterCounter` as a `Writer`), to record
33/// tokens for later writing (using a new `WriterRecorder` as a
34/// `Writer`) to write actual final bits out using a range encoder
35/// (using a new `WriterEncoder` as a `Writer`).  A `WriterRecorder`'s
36/// contents can be replayed into a `WriterEncoder`.
37pub trait Writer {
38  /// Write a symbol `s`, using the passed in cdf reference; leaves `cdf` unchanged
39  fn symbol<const CDF_LEN: usize>(&mut self, s: u32, cdf: &[u16; CDF_LEN]);
40  /// return approximate number of fractional bits in `OD_BITRES`
41  /// precision to write a symbol `s` using the passed in cdf reference;
42  /// leaves `cdf` unchanged
43  fn symbol_bits(&self, s: u32, cdf: &[u16]) -> u32;
44  /// Write a symbol `s`, using the passed in cdf reference; updates the referenced cdf.
45  fn symbol_with_update<const CDF_LEN: usize>(
46    &mut self, s: u32, cdf: CDFOffset<CDF_LEN>, log: &mut CDFContextLog,
47    fc: &mut CDFContext,
48  );
49  /// Write a bool using passed in probability
50  fn bool(&mut self, val: bool, f: u16);
51  /// Write a single bit with flat probability
52  fn bit(&mut self, bit: u16);
53  /// Write literal `bits` with flat probability
54  fn literal(&mut self, bits: u8, s: u32);
55  /// Write passed `level` as a golomb code
56  fn write_golomb(&mut self, level: u32);
57  /// Write a value `v` in `[0, n-1]` quasi-uniformly
58  fn write_quniform(&mut self, n: u32, v: u32);
59  /// Return fractional bits needed to write a value `v` in `[0, n-1]`
60  /// quasi-uniformly
61  fn count_quniform(&self, n: u32, v: u32) -> u32;
62  /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite subexponential
63  fn write_subexp(&mut self, n: u32, k: u8, v: u32);
64  /// Return fractional bits needed to write symbol v in `[0, n-1]` with
65  /// parameter k as finite subexponential
66  fn count_subexp(&self, n: u32, k: u8, v: u32) -> u32;
67  /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite
68  /// subexponential based on a reference `r` also in `[0, n-1]`.
69  fn write_unsigned_subexp_with_ref(&mut self, v: u32, mx: u32, k: u8, r: u32);
70  /// Return fractional bits needed to write symbol `v` in `[0, n-1]` with
71  /// parameter `k` as finite subexponential based on a reference `r`
72  /// also in `[0, n-1]`.
73  fn count_unsigned_subexp_with_ref(
74    &self, v: u32, mx: u32, k: u8, r: u32,
75  ) -> u32;
76  /// Write symbol v in `[-(n-1), n-1]` with parameter k as finite
77  /// subexponential based on a reference ref also in `[-(n-1), n-1]`.
78  fn write_signed_subexp_with_ref(
79    &mut self, v: i32, low: i32, high: i32, k: u8, r: i32,
80  );
81  /// Return fractional bits needed to write symbol `v` in `[-(n-1), n-1]`
82  /// with parameter `k` as finite subexponential based on a reference
83  /// `r` also in `[-(n-1), n-1]`.
84  fn count_signed_subexp_with_ref(
85    &self, v: i32, low: i32, high: i32, k: u8, r: i32,
86  ) -> u32;
87  /// Return current length of range-coded bitstream in integer bits
88  fn tell(&mut self) -> u32;
89  /// Return current length of range-coded bitstream in fractional
90  /// bits with `OD_BITRES` decimal precision
91  fn tell_frac(&mut self) -> u32;
92  /// Save current point in coding/recording to a checkpoint
93  fn checkpoint(&mut self) -> WriterCheckpoint;
94  /// Restore saved position in coding/recording from a checkpoint
95  fn rollback(&mut self, _: &WriterCheckpoint);
96  /// Add additional bits from rate estimators without coding a real symbol
97  fn add_bits_frac(&mut self, bits_frac: u32);
98}
99
100/// `StorageBackend` is an internal trait used to tie a specific `Writer`
101/// implementation's storage to the generic `Writer`.  It would be
102/// private, but Rust is deprecating 'private trait in a public
103/// interface' support.
104pub trait StorageBackend {
105  /// Store partially-computed range code into given storage backend
106  fn store(&mut self, fl: u16, fh: u16, nms: u16);
107  /// Return bit-length of encoded stream to date
108  fn stream_bits(&mut self) -> usize;
109  /// Backend implementation of checkpoint to pass through Writer interface
110  fn checkpoint(&mut self) -> WriterCheckpoint;
111  /// Backend implementation of rollback to pass through Writer interface
112  fn rollback(&mut self, _: &WriterCheckpoint);
113}
114
/// Range-coder state shared by every writer in this module. `S` is the
/// backend-specific storage (`WriterCounter`, `WriterRecorder` or
/// `WriterEncoder`).
#[derive(Debug, Clone)]
pub struct WriterBase<S> {
  /// The number of values in the current range (kept at or above 32768
  /// between symbols by renormalization).
  rng: u16,
  /// The number of bits of data in the current value. Starts at -9; the
  /// `tell()` family adds the offset back.
  cnt: i16,
  #[cfg(feature = "desync_finder")]
  /// Debug enable flag (set when `RAV1E_DEBUG` is present in the
  /// environment).
  debug: bool,
  /// Extra offset added to tell() and tell_frac() to approximate costs
  /// of actually coding a symbol. NOTE(review): presumably in `OD_BITRES`
  /// fractional precision, since tell_frac() adds it unscaled — confirm.
  fake_bits_frac: u32,
  /// Use-specific storage
  s: S,
}
130
/// Storage for a counting writer: nothing is stored, only the bit total.
#[derive(Debug, Clone)]
pub struct WriterCounter {
  /// Bits that would be shifted out to date
  bits: usize,
}
136
/// Storage for a recording writer: keeps every token so it can be replayed.
#[derive(Debug, Clone)]
pub struct WriterRecorder {
  /// Storage for tokens, as the `(fl, fh, nms)` triples passed to `store`
  storage: Vec<(u16, u16, u16)>,
  /// Bits that would be shifted out to date
  bits: usize,
}
144
/// Storage for the real range encoder.
#[derive(Debug, Clone)]
pub struct WriterEncoder {
  /// A buffer for output bytes with their associated carry flags.
  /// Entries are `u16` so that a carry above the low byte can be kept until
  /// `done()` propagates it.
  precarry: Vec<u16>,
  /// The low end of the current range.
  low: ec_window,
}
152
/// Snapshot of a writer's state, produced by `checkpoint()` and consumed by
/// `rollback()`. Derives `Debug` like the other public types in this module
/// so it can appear in diagnostics.
#[derive(Debug, Clone)]
pub struct WriterCheckpoint {
  /// Stream length coded/recorded to date, in the unit used by the Writer,
  /// which may be bytes or bits. This depends on the assumption
  /// that a Writer will only ever restore its own Checkpoint.
  stream_size: usize,
  /// To be defined by backend (e.g. token count for a Recorder, `low` for an
  /// Encoder).
  backend_var: usize,
  /// Saved number of values in the current range.
  rng: u16,
  /// Saved number of bits of data in the current value.
  cnt: i16,
}
166
167/// Constructor for a counting Writer
168impl WriterCounter {
169  #[inline]
170  pub const fn new() -> WriterBase<WriterCounter> {
171    WriterBase::new(WriterCounter { bits: 0 })
172  }
173}
174
175/// Constructor for a recording Writer
176impl WriterRecorder {
177  #[inline]
178  pub const fn new() -> WriterBase<WriterRecorder> {
179    WriterBase::new(WriterRecorder { storage: Vec::new(), bits: 0 })
180  }
181}
182
183/// Constructor for a encoding Writer
184impl WriterEncoder {
185  #[inline]
186  pub const fn new() -> WriterBase<WriterEncoder> {
187    WriterBase::new(WriterEncoder { precarry: Vec::new(), low: 0 })
188  }
189}
190
191/// The Counter stores nothing we write to it, it merely counts the
192/// bit usage like in an Encoder for cost analysis.
193impl StorageBackend for WriterBase<WriterCounter> {
194  #[inline]
195  fn store(&mut self, fl: u16, fh: u16, nms: u16) {
196    let (_l, r) = self.lr_compute(fl, fh, nms);
197    let d = r.leading_zeros() as usize;
198
199    self.s.bits += d;
200    self.rng = r << d;
201  }
202  #[inline]
203  fn stream_bits(&mut self) -> usize {
204    self.s.bits
205  }
206  #[inline]
207  fn checkpoint(&mut self) -> WriterCheckpoint {
208    WriterCheckpoint {
209      stream_size: self.s.bits,
210      backend_var: 0,
211      rng: self.rng,
212      // We do not use `cnt` within Counter, but setting it here allows the compiler
213      // to do a 32-bit merged load/store.
214      cnt: self.cnt,
215    }
216  }
217  #[inline]
218  fn rollback(&mut self, checkpoint: &WriterCheckpoint) {
219    self.rng = checkpoint.rng;
220    self.s.bits = checkpoint.stream_size;
221  }
222}
223
224/// The Recorder does not produce a range-coded bitstream, but it
225/// still tracks the range coding progress like in an Encoder, as it
226/// neds to be able to report bit costs for RDO decisions.  It stores a
227/// pair of mostly-computed range coding values per token recorded.
228impl StorageBackend for WriterBase<WriterRecorder> {
229  #[inline]
230  fn store(&mut self, fl: u16, fh: u16, nms: u16) {
231    let (_l, r) = self.lr_compute(fl, fh, nms);
232    let d = r.leading_zeros() as usize;
233
234    self.s.bits += d;
235    self.rng = r << d;
236    self.s.storage.push((fl, fh, nms));
237  }
238  #[inline]
239  fn stream_bits(&mut self) -> usize {
240    self.s.bits
241  }
242  #[inline]
243  fn checkpoint(&mut self) -> WriterCheckpoint {
244    WriterCheckpoint {
245      stream_size: self.s.bits,
246      backend_var: self.s.storage.len(),
247      rng: self.rng,
248      cnt: self.cnt,
249    }
250  }
251  #[inline]
252  fn rollback(&mut self, checkpoint: &WriterCheckpoint) {
253    self.rng = checkpoint.rng;
254    self.cnt = checkpoint.cnt;
255    self.s.bits = checkpoint.stream_size;
256    self.s.storage.truncate(checkpoint.backend_var);
257  }
258}
259
/// An Encoder produces an actual range-coded bitstream from passed in
/// tokens.  It does not retain any information about the coded
/// tokens, only the resulting bitstream, and so it cannot be replayed
/// (only checkpointed and rolled back).
impl StorageBackend for WriterBase<WriterEncoder> {
  /// Narrow the range to the token's sub-range, add the token's offset to
  /// `low`, then renormalize, pushing every completed byte of `low` into
  /// `precarry`.
  fn store(&mut self, fl: u16, fh: u16, nms: u16) {
    let (l, r) = self.lr_compute(fl, fh, nms);
    let mut low = l + self.s.low;
    let mut c = self.cnt;
    // Shift needed to bring the new range back up to 16 significant bits.
    let d = r.leading_zeros() as usize;
    let mut s = c + (d as i16);

    // A non-negative `s` means at least one byte of `low` is now complete.
    if s >= 0 {
      c += 16;
      let mut m = (1 << c) - 1;
      // Two bytes are complete: emit the older one first.
      if s >= 8 {
        self.s.precarry.push((low >> c) as u16);
        low &= m;
        c -= 8;
        m >>= 8;
      }
      // Stored as u16 so a carry above the byte survives until `done()`.
      self.s.precarry.push((low >> c) as u16);
      s = c + (d as i16) - 24;
      low &= m;
    }
    self.s.low = low << d;
    self.rng = r << d;
    self.cnt = s;
  }
  /// Whole bytes emitted so far, expressed in bits; bits still pending in
  /// `low` are not counted here (the `tell()` family adds `cnt` for those).
  #[inline]
  fn stream_bits(&mut self) -> usize {
    self.s.precarry.len() * 8
  }
  /// Saves the byte count in `stream_size` and `low` in `backend_var`.
  #[inline]
  fn checkpoint(&mut self) -> WriterCheckpoint {
    WriterCheckpoint {
      stream_size: self.s.precarry.len(),
      backend_var: self.s.low as usize,
      rng: self.rng,
      cnt: self.cnt,
    }
  }
  /// Restores range state and `low`, and drops any bytes emitted after the
  /// checkpoint.
  fn rollback(&mut self, checkpoint: &WriterCheckpoint) {
    self.rng = checkpoint.rng;
    self.cnt = checkpoint.cnt;
    self.s.low = checkpoint.backend_var as ec_window;
    self.s.precarry.truncate(checkpoint.stream_size);
  }
}
309
/// A few local helper functions needed by the Writer that are not
/// part of the public interface.
impl<S> WriterBase<S> {
  /// Internal constructor called by the subtypes that implement the
  /// actual encoder and Recorder.
  ///
  /// The range starts at `0x8000` and `cnt` at `-9`; `tell()` and
  /// `symbol_bits()` add that offset back when reporting bits.
  #[inline]
  #[cfg(not(feature = "desync_finder"))]
  const fn new(storage: S) -> Self {
    WriterBase { rng: 0x8000, cnt: -9, fake_bits_frac: 0, s: storage }
  }

  /// Internal constructor used with the `desync_finder` feature. It has the
  /// same initial state, plus `debug` set when the `RAV1E_DEBUG`
  /// environment variable is present. It is not a `const fn` because it
  /// reads the environment.
  #[inline]
  #[cfg(feature = "desync_finder")]
  fn new(storage: S) -> Self {
    WriterBase {
      rng: 0x8000,
      cnt: -9,
      debug: std::env::var_os("RAV1E_DEBUG").is_some(),
      fake_bits_frac: 0,
      s: storage,
    }
  }

  /// Compute low and range values from token cdf values and local state.
  ///
  /// Returns `(l, r)`:
  /// - `l` is the amount the Encoder adds to its `low`;
  /// - `r` is the new range, not yet renormalized.
  ///
  /// `fl >= 32768` (what `symbol()` passes for symbol 0) selects the top of
  /// the current range, i.e. `l == 0`.
  const fn lr_compute(&self, fl: u16, fh: u16, nms: u16) -> (ec_window, u16) {
    let r = self.rng as u32;
    debug_assert!(32768 <= r);
    // Scale the shifted probability by the top 8 bits of the range. Adding
    // EC_MIN_PROB per remaining symbol (nms) keeps u - v >= EC_MIN_PROB, so
    // no symbol ever gets an empty interval.
    let mut u = (((r >> 8) * (fl as u32 >> EC_PROB_SHIFT))
      >> (7 - EC_PROB_SHIFT))
      + EC_MIN_PROB * nms as u32;
    if fl >= 32768 {
      u = r;
    }
    let v = (((r >> 8) * (fh as u32 >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT))
      + EC_MIN_PROB * (nms - 1) as u32;
    (r - u, (u - v) as u16)
  }

  /// Given the current total integer number of bits used and the current value of
  /// rng, computes the fraction number of bits used to `OD_BITRES` precision.
  /// This is used by `od_ec_enc_tell_frac()` and `od_ec_dec_tell_frac()`.
  /// `nbits_total`: The number of whole bits currently used, i.e., the value
  ///                returned by `od_ec_enc_tell()` or `od_ec_dec_tell()`.
  /// `rng`: The current value of rng from either the encoder or decoder state.
  /// Return: The number of bits scaled by `2**OD_BITRES`.
  ///         This will always be slightly larger than the exact value (e.g., all
  ///         rounding error is in the positive direction).
  fn frac_compute(nbits_total: u32, mut rng: u32) -> u32 {
    // To handle the non-integral number of bits still left in the encoder/decoder
    //  state, we compute the worst-case number of bits of val that must be
    //  encoded to ensure that the value is inside the range for any possible
    //  subsequent bits.
    // The computation here is independent of val itself (the decoder does not
    //  even track that value), even though the real number of bits used after
    //  od_ec_enc_done() may be 1 smaller if rng is a power of two and the
    //  corresponding trailing bits of val are all zeros.
    // If we did try to track that special case, then coding a value with a
    //  probability of 1/(1 << n) might sometimes appear to use more than n bits.
    // This may help explain the surprising result that a newly initialized
    //  encoder or decoder claims to have used 1 bit.
    let nbits = nbits_total << OD_BITRES;
    let mut l = 0;
    // Extract OD_BITRES fractional bits of log2(rng / 32768), one per
    // iteration: square the Q15 value, and if it reached 2.0 (bit 16 set),
    // record a 1 bit and halve it.
    for _ in 0..OD_BITRES {
      rng = (rng * rng) >> 15;
      let b = rng >> 16;
      l = (l << 1) | b;
      rng >>= b;
    }
    nbits - l
  }

  /// Maps `v` to a non-negative index ordered by distance from `r`.
  ///
  /// - Values up to `2 * r` alternate around `r` (`r` -> 0, `r + 1` -> 2,
  ///   `r - 1` -> 1, ...).
  /// - Values above `2 * r` are returned unchanged.
  const fn recenter(r: u32, v: u32) -> u32 {
    if v > (r << 1) {
      v
    } else if v >= r {
      (v - r) << 1
    } else {
      ((r - v) << 1) - 1
    }
  }

  /// Prints the symbol `s` together with the name of a function found a few
  /// frames up the backtrace (`depth` starts at 3). NOTE(review): presumably
  /// meant to name the caller of `symbol_with_update` — confirm the frame
  /// count.
  #[cfg(feature = "desync_finder")]
  fn print_backtrace(&self, s: u32) {
    let mut depth = 3;
    backtrace::trace(|frame| {
      let ip = frame.ip();

      depth -= 1;

      if depth == 0 {
        backtrace::resolve(ip, |symbol| {
          if let Some(name) = symbol.name() {
            println!("Writing symbol {} from {}", s, name);
          }
        });
        false
      } else {
        true
      }
    });
  }
}
412
413/// Replay implementation specific to the Recorder
414impl WriterBase<WriterRecorder> {
415  /// Replays the partially-computed range tokens out of the Recorder's
416  /// storage and into the passed in Writer, which may be an Encoder
417  /// or another Recorder.  Clears the Recorder after replay.
418  pub fn replay(&mut self, dest: &mut dyn StorageBackend) {
419    for &(fl, fh, nms) in &self.s.storage {
420      dest.store(fl, fh, nms);
421    }
422    self.rng = 0x8000;
423    self.cnt = -9;
424    self.s.storage.truncate(0);
425    self.s.bits = 0;
426  }
427}
428
/// Done implementation specific to the Encoder
impl WriterBase<WriterEncoder> {
  /// Indicates that there are no more symbols to encode.  Flushes
  /// remaining state into coding and returns a vector containing the
  /// final bitstream.
  ///
  /// Note: the flushed bytes are appended to the internal `precarry`
  /// buffer, and the range state is not reset by this call.
  pub fn done(&mut self) -> Vec<u8> {
    // We output the minimum number of bits that ensures that the symbols encoded
    // thus far will be decoded correctly regardless of the bits that follow.
    let l = self.s.low;
    let mut c = self.cnt;
    let mut s = 10;
    let m = 0x3FFF;
    // Round `low` up to a multiple of `m + 1`, then set bit `m + 1`.
    let mut e = ((l + m) & !m) | (m + 1);

    s += c;

    // Emit the remaining significant bytes of `e`, most significant first.
    if s > 0 {
      let mut n = (1 << (c + 16)) - 1;

      loop {
        self.s.precarry.push((e >> (c + 16)) as u16);
        e &= n;
        s -= 8;
        c -= 8;
        n >>= 8;

        if s <= 0 {
          break;
        }
      }
    }

    // Resolve carries: walk the pre-carry buffer from the end, adding each
    // entry's overflow above the low byte into the preceding byte.
    let mut c = 0;
    let mut offs = self.s.precarry.len();
    // dynamic allocation: grows during encode
    let mut out = vec![0_u8; offs];
    while offs > 0 {
      offs -= 1;
      c += self.s.precarry[offs];
      out[offs] = c as u8;
      c >>= 8;
    }

    out
  }
}
475
476/// Generic/shared implementation for `Writer`s with `StorageBackend`s
477/// (ie, `Encoder`s and `Recorder`s)
478impl<S> Writer for WriterBase<S>
479where
480  WriterBase<S>: StorageBackend,
481{
482  /// Encode a single binary value.
483  /// `val`: The value to encode (0 or 1).
484  /// `f`: The probability that the val is one, scaled by 32768.
485  fn bool(&mut self, val: bool, f: u16) {
486    debug_assert!(0 < f);
487    debug_assert!(f < 32768);
488    self.symbol(u32::from(val), &[f, 0]);
489  }
490  /// Encode a single boolean value.
491  ///
492  /// - `val`: The value to encode (`false` or `true`).
493  /// - `f`: The probability that the `val` is `true`, scaled by `32768`.
494  fn bit(&mut self, bit: u16) {
495    self.bool(bit == 1, 16384);
496  }
497  // fake add bits
498  fn add_bits_frac(&mut self, bits_frac: u32) {
499    self.fake_bits_frac += bits_frac
500  }
501  /// Encode a literal bitstring, bit by bit in MSB order, with flat
502  /// probability.
503  ///
504  /// - 'bits': Length of bitstring
505  /// - 's': Bit string to encode
506  fn literal(&mut self, bits: u8, s: u32) {
507    for bit in (0..bits).rev() {
508      self.bit((1 & (s >> bit)) as u16);
509    }
510  }
511  /// Encodes a symbol given a cumulative distribution function (CDF) table in Q15.
512  ///
513  /// - `s`: The index of the symbol to encode.
514  /// - `cdf`: The CDF, such that symbol s falls in the range
515  ///        `[s > 0 ? cdf[s - 1] : 0, cdf[s])`.
516  ///       The values must be monotonically non-decreasing, and the last value
517  ///       must be greater than 32704. There should be at most 16 values.
518  ///       The lower 6 bits of the last value hold the count.
519  #[inline(always)]
520  fn symbol<const CDF_LEN: usize>(&mut self, s: u32, cdf: &[u16; CDF_LEN]) {
521    debug_assert!(cdf[cdf.len() - 1] < (1 << EC_PROB_SHIFT));
522    let s = s as usize;
523    debug_assert!(s < cdf.len());
524    // The above is stricter than the following overflow check: s <= cdf.len()
525    let nms = cdf.len() - s;
526    let fl = if s > 0 {
527      // SAFETY: We asserted that s is less than the length of the cdf
528      unsafe { *cdf.get_unchecked(s - 1) }
529    } else {
530      32768
531    };
532    // SAFETY: We asserted that s is less than the length of the cdf
533    let fh = unsafe { *cdf.get_unchecked(s) };
534    debug_assert!((fh >> EC_PROB_SHIFT) <= (fl >> EC_PROB_SHIFT));
535    debug_assert!(fl <= 32768);
536    self.store(fl, fh, nms as u16);
537  }
538  /// Encodes a symbol given a cumulative distribution function (CDF)
539  /// table in Q15, then updates the CDF probabilities to reflect we've
540  /// written one more symbol 's'.
541  ///
542  /// - `s`: The index of the symbol to encode.
543  /// - `cdf`: The CDF, such that symbol s falls in the range
544  ///        `[s > 0 ? cdf[s - 1] : 0, cdf[s])`.
545  ///       The values must be monotonically non-decreasing, and the last value
546  ///       must be greater 32704. There should be at most 16 values.
547  ///       The lower 6 bits of the last value hold the count.
548  fn symbol_with_update<const CDF_LEN: usize>(
549    &mut self, s: u32, cdf: CDFOffset<CDF_LEN>, log: &mut CDFContextLog,
550    fc: &mut CDFContext,
551  ) {
552    #[cfg(feature = "desync_finder")]
553    {
554      if self.debug {
555        self.print_backtrace(s);
556      }
557    }
558    let cdf = log.push(fc, cdf);
559    self.symbol(s, cdf);
560
561    update_cdf(cdf, s);
562  }
563  /// Returns approximate cost for a symbol given a cumulative
564  /// distribution function (CDF) table and current write state.
565  ///
566  /// - `s`: The index of the symbol to encode.
567  /// - `cdf`: The CDF, such that symbol s falls in the range
568  ///        `[s > 0 ? cdf[s - 1] : 0, cdf[s])`.
569  ///       The values must be monotonically non-decreasing, and the last value
570  ///       must be greater than 32704. There should be at most 16 values.
571  ///       The lower 6 bits of the last value hold the count.
572  fn symbol_bits(&self, s: u32, cdf: &[u16]) -> u32 {
573    let mut bits = 0;
574    debug_assert!(cdf[cdf.len() - 1] < (1 << EC_PROB_SHIFT));
575    debug_assert!(32768 <= self.rng);
576    let rng = (self.rng >> 8) as u32;
577    let fh = cdf[s as usize] as u32 >> EC_PROB_SHIFT;
578    let r: u32 = if s > 0 {
579      let fl = cdf[s as usize - 1] as u32 >> EC_PROB_SHIFT;
580      ((rng * fl) >> (7 - EC_PROB_SHIFT)) - ((rng * fh) >> (7 - EC_PROB_SHIFT))
581        + EC_MIN_PROB
582    } else {
583      let nms1 = cdf.len() as u32 - s - 1;
584      self.rng as u32
585        - ((rng * fh) >> (7 - EC_PROB_SHIFT))
586        - nms1 * EC_MIN_PROB
587    };
588
589    // The 9 here counteracts the offset of -9 baked into cnt.  Don't include a termination bit.
590    let pre = Self::frac_compute((self.cnt + 9) as u32, self.rng as u32);
591    let d = r.leading_zeros() - 16;
592    let mut c = self.cnt;
593    let mut sh = c + (d as i16);
594    if sh >= 0 {
595      c += 16;
596      if sh >= 8 {
597        bits += 8;
598        c -= 8;
599      }
600      bits += 8;
601      sh = c + (d as i16) - 24;
602    }
603    // The 9 here counteracts the offset of -9 baked into cnt.  Don't include a termination bit.
604    Self::frac_compute((bits + sh + 9) as u32, r << d) - pre
605  }
606  /// Encode a golomb to the bitstream.
607  ///
608  /// - 'level': passed in value to encode
609  fn write_golomb(&mut self, level: u32) {
610    let x = level + 1;
611    let length = 32 - x.leading_zeros();
612
613    for _ in 0..length - 1 {
614      self.bit(0);
615    }
616
617    for i in (0..length).rev() {
618      self.bit(((x >> i) & 0x01) as u16);
619    }
620  }
621  /// Write a value `v` in `[0, n-1]` quasi-uniformly
622  /// - `n`: size of interval
623  /// - `v`: value to encode
624  fn write_quniform(&mut self, n: u32, v: u32) {
625    if n > 1 {
626      let l = 32 - n.leading_zeros() as u8;
627      let m = (1 << l) - n;
628      if v < m {
629        self.literal(l - 1, v);
630      } else {
631        self.literal(l - 1, m + ((v - m) >> 1));
632        self.literal(1, (v - m) & 1);
633      }
634    }
635  }
636  /// Returns `QOD_BITRES` bits for a value `v` in `[0, n-1]` quasi-uniformly
637  /// - `n`: size of interval
638  /// - `v`: value to encode
639  fn count_quniform(&self, n: u32, v: u32) -> u32 {
640    let mut bits = 0;
641    if n > 1 {
642      let l = 32 - n.leading_zeros();
643      let m = (1 << l) - n;
644      bits += (l - 1) << OD_BITRES;
645      if v >= m {
646        bits += 1 << OD_BITRES;
647      }
648    }
649    bits
650  }
651  /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite subexponential
652  ///
653  /// - `n`: size of interval
654  /// - `k`: "parameter"
655  /// - `v`: value to encode
656  fn write_subexp(&mut self, n: u32, k: u8, v: u32) {
657    let mut i = 0;
658    let mut mk = 0;
659    loop {
660      let b = if i != 0 { k + i - 1 } else { k };
661      let a = 1 << b;
662      if n <= mk + 3 * a {
663        self.write_quniform(n - mk, v - mk);
664        break;
665      } else {
666        let t = v >= mk + a;
667        self.bool(t, 16384);
668        if t {
669          i += 1;
670          mk += a;
671        } else {
672          self.literal(b, v - mk);
673          break;
674        }
675      }
676    }
677  }
678  /// Returns `QOD_BITRES` bits for symbol `v` in `[0, n-1]` with parameter `k`
679  /// as finite subexponential
680  ///
681  /// - `n`: size of interval
682  /// - `k`: "parameter"
683  /// - `v`: value to encode
684  fn count_subexp(&self, n: u32, k: u8, v: u32) -> u32 {
685    let mut i = 0;
686    let mut mk = 0;
687    let mut bits = 0;
688    loop {
689      let b = if i != 0 { k + i - 1 } else { k };
690      let a = 1 << b;
691      if n <= mk + 3 * a {
692        bits += self.count_quniform(n - mk, v - mk);
693        break;
694      } else {
695        let t = v >= mk + a;
696        bits += 1 << OD_BITRES;
697        if t {
698          i += 1;
699          mk += a;
700        } else {
701          bits += (b as u32) << OD_BITRES;
702          break;
703        }
704      }
705    }
706    bits
707  }
708  /// Write symbol `v` in `[0, n-1]` with parameter `k` as finite
709  /// subexponential based on a reference `r` also in `[0, n-1]`.
710  ///
711  /// - `v`: value to encode
712  /// - `n`: size of interval
713  /// - `k`: "parameter"
714  /// - `r`: reference
715  fn write_unsigned_subexp_with_ref(&mut self, v: u32, n: u32, k: u8, r: u32) {
716    if (r << 1) <= n {
717      self.write_subexp(n, k, Self::recenter(r, v));
718    } else {
719      self.write_subexp(n, k, Self::recenter(n - 1 - r, n - 1 - v));
720    }
721  }
722  /// Returns `QOD_BITRES` bits for symbol `v` in `[0, n-1]`
723  /// with parameter `k` as finite subexponential based on a
724  /// reference `r` also in `[0, n-1]`.
725  ///
726  /// - `v`: value to encode
727  /// - `n`: size of interval
728  /// - `k`: "parameter"
729  /// - `r`: reference
730  fn count_unsigned_subexp_with_ref(
731    &self, v: u32, n: u32, k: u8, r: u32,
732  ) -> u32 {
733    if (r << 1) <= n {
734      self.count_subexp(n, k, Self::recenter(r, v))
735    } else {
736      self.count_subexp(n, k, Self::recenter(n - 1 - r, n - 1 - v))
737    }
738  }
739  /// Write symbol `v` in `[-(n-1), n-1]` with parameter `k` as finite
740  /// subexponential based on a reference `r` also in `[-(n-1), n-1]`.
741  ///
742  /// - `v`: value to encode
743  /// - `n`: size of interval
744  /// - `k`: "parameter"
745  /// - `r`: reference
746  fn write_signed_subexp_with_ref(
747    &mut self, v: i32, low: i32, high: i32, k: u8, r: i32,
748  ) {
749    self.write_unsigned_subexp_with_ref(
750      (v - low) as u32,
751      (high - low) as u32,
752      k,
753      (r - low) as u32,
754    );
755  }
756  /// Returns `QOD_BITRES` bits for symbol `v` in `[-(n-1), n-1]`
757  /// with parameter `k` as finite subexponential based on a
758  /// reference `r` also in `[-(n-1), n-1]`.
759  ///
760  /// - `v`: value to encode
761  /// - `n`: size of interval
762  /// - `k`: "parameter"
763  /// - `r`: reference
764
765  fn count_signed_subexp_with_ref(
766    &self, v: i32, low: i32, high: i32, k: u8, r: i32,
767  ) -> u32 {
768    self.count_unsigned_subexp_with_ref(
769      (v - low) as u32,
770      (high - low) as u32,
771      k,
772      (r - low) as u32,
773    )
774  }
775  /// Returns the number of bits "used" by the encoded symbols so far.
776  /// This same number can be computed in either the encoder or the
777  /// decoder, and is suitable for making coding decisions.  The value
778  /// will be the same whether using an `Encoder` or `Recorder`.
779  ///
780  /// Return: The integer number of bits.
781  ///         This will always be slightly larger than the exact value (e.g., all
782  ///          rounding error is in the positive direction).
783  fn tell(&mut self) -> u32 {
784    // The 10 here counteracts the offset of -9 baked into cnt, and adds 1 extra
785    // bit, which we reserve for terminating the stream.
786    (((self.stream_bits()) as i32) + (self.cnt as i32) + 10) as u32
787      + (self.fake_bits_frac >> 8)
788  }
789  /// Returns the number of bits "used" by the encoded symbols so far.
790  /// This same number can be computed in either the encoder or the
791  /// decoder, and is suitable for making coding decisions. The value
792  /// will be the same whether using an `Encoder` or `Recorder`.
793  ///
794  /// Return: The number of bits scaled by `2**OD_BITRES`.
795  ///         This will always be slightly larger than the exact value (e.g., all
796  ///          rounding error is in the positive direction).
797  fn tell_frac(&mut self) -> u32 {
798    Self::frac_compute(self.tell(), self.rng as u32) + self.fake_bits_frac
799  }
800  /// Save current point in coding/recording to a checkpoint that can
801  /// be restored later.  A `WriterCheckpoint` can be generated for an
802  /// `Encoder` or `Recorder`, but can only be used to rollback the `Writer`
803  /// instance from which it was generated.
804  fn checkpoint(&mut self) -> WriterCheckpoint {
805    StorageBackend::checkpoint(self)
806  }
807  /// Roll back a given `Writer` to the state saved in the `WriterCheckpoint`
808  ///
809  /// - 'wc': Saved `Writer` state/posiiton to restore
810  fn rollback(&mut self, wc: &WriterCheckpoint) {
811    StorageBackend::rollback(self, wc)
812  }
813}
814
/// Writer for finite subexponential codes and their recentering helpers,
/// emitted directly as bits rather than range-coded.
pub trait BCodeWriter {
  /// Recenter a non-negative literal `v` around a reference `r`.
  fn recenter_nonneg(&mut self, r: u16, v: u16) -> u16;
  /// Recenter a non-negative literal `v` in `[0, n-1]` around a reference
  /// `r` also in `[0, n-1]`.
  fn recenter_finite_nonneg(&mut self, n: u16, r: u16, v: u16) -> u16;
  /// Write `v` in `[0, n-1]` quasi-uniformly.
  ///
  /// # Errors
  ///
  /// - Returns `std::io::Error` if the writer cannot be written to.
  fn write_quniform(&mut self, n: u16, v: u16) -> Result<(), std::io::Error>;
  /// Write `v` in `[0, n-1]` as a finite subexponential code with
  /// parameter `k`.
  ///
  /// # Errors
  ///
  /// - Returns `std::io::Error` if the writer cannot be written to.
  fn write_subexpfin(
    &mut self, n: u16, k: u16, v: u16,
  ) -> Result<(), std::io::Error>;
  /// Write `v` in `[0, n-1]` as a finite subexponential code with
  /// parameter `k`, recentered around the reference `r`.
  ///
  /// # Errors
  ///
  /// - Returns `std::io::Error` if the writer cannot be written to.
  fn write_refsubexpfin(
    &mut self, n: u16, k: u16, r: i16, v: i16,
  ) -> Result<(), std::io::Error>;
  /// Signed variant of `write_refsubexpfin`.
  ///
  /// # Errors
  ///
  /// - Returns `std::io::Error` if the writer cannot be written to.
  fn write_s_refsubexpfin(
    &mut self, n: u16, k: u16, r: i16, v: i16,
  ) -> Result<(), std::io::Error>;
}
841
842impl<W: io::Write> BCodeWriter for BitWriter<W, BigEndian> {
843  fn recenter_nonneg(&mut self, r: u16, v: u16) -> u16 {
844    /* Recenters a non-negative literal v around a reference r */
845    if v > (r << 1) {
846      v
847    } else if v >= r {
848      (v - r) << 1
849    } else {
850      ((r - v) << 1) - 1
851    }
852  }
853  fn recenter_finite_nonneg(&mut self, n: u16, r: u16, v: u16) -> u16 {
854    /* Recenters a non-negative literal v in [0, n-1] around a
855    reference r also in [0, n-1] */
856    if (r << 1) <= n {
857      self.recenter_nonneg(r, v)
858    } else {
859      self.recenter_nonneg(n - 1 - r, n - 1 - v)
860    }
861  }
862  fn write_quniform(&mut self, n: u16, v: u16) -> Result<(), std::io::Error> {
863    if n > 1 {
864      let l = 16 - n.leading_zeros() as u8;
865      let m = (1 << l) - n;
866      if v < m {
867        self.write(l as u32 - 1, v)
868      } else {
869        self.write(l as u32 - 1, m + ((v - m) >> 1))?;
870        self.write(1, (v - m) & 1)
871      }
872    } else {
873      Ok(())
874    }
875  }
876  fn write_subexpfin(
877    &mut self, n: u16, k: u16, v: u16,
878  ) -> Result<(), std::io::Error> {
879    /* Finite subexponential code that codes a symbol v in [0, n-1] with parameter k */
880    let mut i = 0;
881    let mut mk = 0;
882    loop {
883      let b = if i > 0 { k + i - 1 } else { k };
884      let a = 1 << b;
885      if n <= mk + 3 * a {
886        return self.write_quniform(n - mk, v - mk);
887      } else {
888        let t = v >= mk + a;
889        self.write_bit(t)?;
890        if t {
891          i += 1;
892          mk += a;
893        } else {
894          return self.write(b as u32, v - mk);
895        }
896      }
897    }
898  }
899  fn write_refsubexpfin(
900    &mut self, n: u16, k: u16, r: i16, v: i16,
901  ) -> Result<(), std::io::Error> {
902    /* Finite subexponential code that codes a symbol v in [0, n-1] with
903    parameter k based on a reference ref also in [0, n-1].
904    Recenters symbol around r first and then uses a finite subexponential code. */
905    let recentered_v = self.recenter_finite_nonneg(n, r as u16, v as u16);
906    self.write_subexpfin(n, k, recentered_v)
907  }
908  fn write_s_refsubexpfin(
909    &mut self, n: u16, k: u16, r: i16, v: i16,
910  ) -> Result<(), std::io::Error> {
911    /* Signed version of the above function */
912    self.write_refsubexpfin(
913      (n << 1) - 1,
914      k,
915      r + (n - 1) as i16,
916      v + (n - 1) as i16,
917    )
918  }
919}
920
921pub(crate) fn cdf_to_pdf<const CDF_LEN: usize>(
922  cdf: &[u16; CDF_LEN],
923) -> [u16; CDF_LEN] {
924  let mut pdf = [0; CDF_LEN];
925  let mut z = 32768u16 >> EC_PROB_SHIFT;
926  for (d, &a) in pdf.iter_mut().zip(cdf.iter()) {
927    *d = z - (a >> EC_PROB_SHIFT);
928    z = a >> EC_PROB_SHIFT;
929  }
930  pdf
931}
932
933pub(crate) mod rust {
934  // Function to update the CDF for Writer calls that do so.
935  #[inline]
936  pub fn update_cdf<const N: usize>(cdf: &mut [u16; N], val: u32) {
937    use crate::context::CDF_LEN_MAX;
938    let nsymbs = cdf.len();
939    let mut rate = 3 + (nsymbs >> 1).min(2);
940    if let Some(count) = cdf.last_mut() {
941      rate += (*count >> 4) as usize;
942      *count += 1 - (*count >> 5);
943    } else {
944      return;
945    }
946    // Single loop (faster)
947    for (i, v) in
948      cdf[..nsymbs - 1].iter_mut().enumerate().take(CDF_LEN_MAX - 1)
949    {
950      if i as u32 >= val {
951        *v -= *v >> rate;
952      } else {
953        *v += (32768 - *v) >> rate;
954      }
955    }
956  }
957}
958
#[cfg(test)]
mod test {
  use super::*;

  // Width in bits of the decoder window `dif` (an `ec_window`, i.e. u32).
  const WINDOW_SIZE: i16 = 32;
  // Value stored in `cnt` once the input is exhausted (see `refill`); it is
  // large enough that `normalize` effectively stops calling `refill`.
  const LOTS_OF_BITS: i16 = 0x4000;

  /// Minimal range decoder used only by these tests to check that the
  /// output of `WriterEncoder` decodes back to the values that were written.
  #[derive(Debug)]
  struct Reader<'a> {
    // Encoded input bytes.
    buf: &'a [u8],
    // Index of the next byte of `buf` to pull into `dif`.
    bptr: usize,
    // Decoder window; its top 16 bits are compared against `rng`.
    dif: ec_window,
    // Current range; kept >= 32768 after every `normalize`.
    rng: u16,
    // Bit count used to decide when `refill` is needed (refill on < 0).
    cnt: i16,
  }

  impl<'a> Reader<'a> {
    /// Creates a decoder over `buf` and immediately loads as many input
    /// bytes as fit into the window.
    fn new(buf: &'a [u8]) -> Self {
      let mut r = Reader {
        buf,
        bptr: 0,
        dif: (1 << (WINDOW_SIZE - 1)) - 1,
        rng: 0x8000,
        cnt: -15,
      };
      r.refill();
      r
    }

    /// XORs whole input bytes into `dif`, from bit position `s` downward in
    /// steps of 8, while there is room and input remains. Once the input is
    /// exhausted, `cnt` is set to `LOTS_OF_BITS`.
    fn refill(&mut self) {
      let mut s = WINDOW_SIZE - 9 - (self.cnt + 15);
      while s >= 0 && self.bptr < self.buf.len() {
        assert!(s <= WINDOW_SIZE - 8);
        self.dif ^= (self.buf[self.bptr] as ec_window) << s;
        self.cnt += 8;
        s -= 8;
        self.bptr += 1;
      }
      if self.bptr >= self.buf.len() {
        self.cnt = LOTS_OF_BITS;
      }
    }

    /// Stores the new window and range, shifting both left by `d` so that
    /// bit 15 of the range is set (`d` = leading zeros of `rng` beyond the
    /// upper 16 bits of a u32). `cnt` drops by `d`, and a refill happens if
    /// it goes negative.
    fn normalize(&mut self, dif: ec_window, rng: u32) {
      assert!(rng <= 65536);
      let d = rng.leading_zeros() - 16;
      //let d = 16 - (32-rng.leading_zeros());
      self.cnt -= d as i16;
      /*This is equivalent to shifting in 1's instead of 0's.*/
      self.dif = ((dif + 1) << d) - 1;
      self.rng = (rng << d) as u16;
      if self.cnt < 0 {
        self.refill()
      }
    }

    /// Decodes one boolean with probability parameter `f` (must be < 32768).
    ///
    /// The range is split at `v`. A window value below the split decodes as
    /// `true` and keeps `[0, v)`; otherwise it decodes as `false` and keeps
    /// the upper part `r - v`.
    fn bool(&mut self, f: u32) -> bool {
      assert!(f < 32768);
      let r = self.rng as u32;
      assert!(self.dif >> (WINDOW_SIZE - 16) < r);
      assert!(32768 <= r);
      let v = (((r >> 8) * (f >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT))
        + EC_MIN_PROB;
      let vw = v << (WINDOW_SIZE - 16);
      let (dif, rng, ret) = if self.dif >= vw {
        (self.dif - vw, r - v, false)
      } else {
        (self.dif, v, true)
      };
      self.normalize(dif, rng);
      ret
    }

    /// Decodes one symbol using the inverse CDF `icdf`.
    ///
    /// The last element of `icdf` is not counted as a symbol
    /// (`n = len - 1`). Starting at symbol 0, the scan advances while the
    /// window value `c` is below the current threshold `v`. `u` keeps the
    /// previous threshold, so the decoded symbol's subrange is `[v, u)`.
    fn symbol(&mut self, icdf: &[u16]) -> i32 {
      let r = self.rng as u32;
      assert!(self.dif >> (WINDOW_SIZE - 16) < r);
      assert!(32768 <= r);
      let n = icdf.len() as u32 - 1;
      let c = self.dif >> (WINDOW_SIZE - 16);
      let mut v = self.rng as u32;
      let mut ret = 0i32;
      let mut u = v;
      // Threshold for symbol `ret`. The EC_MIN_PROB term adds a fixed
      // amount per later symbol (presumably so every symbol keeps a nonzero
      // subrange; verify against the encoder).
      v = ((r >> 8) * (icdf[ret as usize] as u32 >> EC_PROB_SHIFT))
        >> (7 - EC_PROB_SHIFT);
      v += EC_MIN_PROB * (n - ret as u32);
      while c < v {
        u = v;
        ret += 1;
        v = ((r >> 8) * (icdf[ret as usize] as u32 >> EC_PROB_SHIFT))
          >> (7 - EC_PROB_SHIFT);
        v += EC_MIN_PROB * (n - ret as u32);
      }
      assert!(v < u);
      assert!(u <= r);
      let new_dif = self.dif - (v << (WINDOW_SIZE - 16));
      self.normalize(new_dif, u - v);
      ret
    }
  }

  // Round-trips a sequence of booleans through `WriterEncoder` and `Reader`.
  #[test]
  fn booleans() {
    let mut w = WriterEncoder::new();

    w.bool(false, 1);
    w.bool(true, 2);
    w.bool(false, 3);
    w.bool(true, 1);
    w.bool(true, 2);
    w.bool(false, 3);

    let b = w.done();

    let mut r = Reader::new(&b);

    assert!(!r.bool(1));
    assert!(r.bool(2));
    assert!(!r.bool(3));
    assert!(r.bool(1));
    assert!(r.bool(2));
    assert!(!r.bool(3));
  }

  // Round-trips multi-symbol values coded with a fixed (non-adapted) CDF.
  #[test]
  fn cdf() {
    let cdf = [7296, 3819, 1716, 0];

    let mut w = WriterEncoder::new();

    w.symbol(0, &cdf);
    w.symbol(0, &cdf);
    w.symbol(0, &cdf);
    w.symbol(1, &cdf);
    w.symbol(1, &cdf);
    w.symbol(1, &cdf);
    w.symbol(2, &cdf);
    w.symbol(2, &cdf);
    w.symbol(2, &cdf);

    let b = w.done();

    let mut r = Reader::new(&b);

    assert_eq!(r.symbol(&cdf), 0);
    assert_eq!(r.symbol(&cdf), 0);
    assert_eq!(r.symbol(&cdf), 0);
    assert_eq!(r.symbol(&cdf), 1);
    assert_eq!(r.symbol(&cdf), 1);
    assert_eq!(r.symbol(&cdf), 1);
    assert_eq!(r.symbol(&cdf), 2);
    assert_eq!(r.symbol(&cdf), 2);
    assert_eq!(r.symbol(&cdf), 2);
  }

  // Interleaves symbols and booleans in one stream and checks that they
  // decode back in order.
  #[test]
  fn mixed() {
    let cdf = [7296, 3819, 1716, 0];

    let mut w = WriterEncoder::new();

    w.symbol(0, &cdf);
    w.bool(true, 2);
    w.symbol(0, &cdf);
    w.bool(true, 2);
    w.symbol(0, &cdf);
    w.bool(true, 2);
    w.symbol(1, &cdf);
    w.bool(true, 1);
    w.symbol(1, &cdf);
    w.bool(false, 2);
    w.symbol(1, &cdf);
    w.symbol(2, &cdf);
    w.symbol(2, &cdf);
    w.symbol(2, &cdf);

    let b = w.done();

    let mut r = Reader::new(&b);

    assert_eq!(r.symbol(&cdf), 0);
    assert!(r.bool(2));
    assert_eq!(r.symbol(&cdf), 0);
    assert!(r.bool(2));
    assert_eq!(r.symbol(&cdf), 0);
    assert!(r.bool(2));
    assert_eq!(r.symbol(&cdf), 1);
    assert!(r.bool(1));
    assert_eq!(r.symbol(&cdf), 1);
    assert!(!r.bool(2));
    assert_eq!(r.symbol(&cdf), 1);
    assert_eq!(r.symbol(&cdf), 2);
    assert_eq!(r.symbol(&cdf), 2);
    assert_eq!(r.symbol(&cdf), 2);
  }
}