image_webp/
vp8_arithmetic_decoder.rs

use crate::decoder::DecodingError;

use super::vp8::TreeNode;

#[must_use]
#[repr(transparent)]
pub(crate) struct BitResult<T> {
    value_if_not_past_eof: T,
}

#[must_use]
pub(crate) struct BitResultAccumulator;

impl<T> BitResult<T> {
    const fn ok(value: T) -> Self {
        Self {
            value_if_not_past_eof: value,
        }
    }

    /// Instead of checking this result now, accumulate the burden of checking
    /// into an accumulator. This accumulator must be checked at the end.
    #[inline(always)]
    pub(crate) fn or_accumulate(self, acc: &mut BitResultAccumulator) -> T {
        let _ = acc;
        self.value_if_not_past_eof
    }
}

impl<T: Default> BitResult<T> {
    fn err() -> Self {
        Self {
            value_if_not_past_eof: T::default(),
        }
    }
}

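/// Arithmetic (boolean) decoder for the VP8 bitstream.
///
/// The data is stored as whole 4-byte chunks in `chunks`; up to three trailing
/// bytes that do not fill a chunk live in `final_bytes`, and
/// `final_bytes_remaining` counts how many of those are left (or holds an EOF
/// sentinel once reading has gone past the end of the data).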
#[cfg_attr(test, derive(Debug))]
pub(crate) struct ArithmeticDecoder {
    chunks: Box<[[u8; 4]]>,
    state: State,
    final_bytes: [u8; 3],
    final_bytes_remaining: i8,
}

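/// Core decoder state, kept `Copy` so `FastDecoder` can work on a speculative
/// copy and commit it back only when its reads turn out to stay within the
/// buffer.
///
/// `chunk_index` is the next chunk to load into `value`, `range` is the
/// current coder range (renormalized back into `128..=255` after every read),
/// and `bit_count` tracks how many buffered bits of `value` sit below the
/// comparison window; a fresh chunk is loaded once it drops below zero.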
#[cfg_attr(test, derive(Debug))]
#[derive(Clone, Copy)]
struct State {
    chunk_index: usize,
    value: u64,
    range: u32,
    bit_count: i32,
}

#[cfg_attr(test, derive(Debug))]
struct FastDecoder<'a> {
    chunks: &'a [[u8; 4]],
    uncommitted_state: State,
    save_state: &'a mut State,
}

impl ArithmeticDecoder {
    pub(crate) fn new() -> ArithmeticDecoder {
        let state = State {
            chunk_index: 0,
            value: 0,
            range: 255,
            bit_count: -8,
        };
        ArithmeticDecoder {
            chunks: Box::new([]),
            state,
            final_bytes: [0; 3],
            final_bytes_remaining: Self::FINAL_BYTES_REMAINING_EOF,
        }
    }

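    /// Initialize the decoder with new bitstream data.
    ///
    /// `buf` holds the data split into 4-byte chunks and `len` is the number
    /// of meaningful bytes; if `len` is not a multiple of four, the final
    /// chunk is only partially filled and its padding bytes must be zero.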
    pub(crate) fn init(&mut self, mut buf: Vec<[u8; 4]>, len: usize) -> Result<(), DecodingError> {
        let mut final_bytes = [0; 3];
        let final_bytes_remaining = if len == 4 * buf.len() {
            0
        } else {
            // Pop the last chunk (which is partial), then compute how many of
            // its bytes are valid.
            let Some(last_chunk) = buf.pop() else {
                return Err(DecodingError::NotEnoughInitData);
            };
            let len_rounded_down = 4 * buf.len();
            let num_bytes_popped = len - len_rounded_down;
            debug_assert!(num_bytes_popped <= 3);
            final_bytes[..num_bytes_popped].copy_from_slice(&last_chunk[..num_bytes_popped]);
            for i in num_bytes_popped..4 {
                debug_assert_eq!(last_chunk[i], 0, "unexpected {last_chunk:?}");
            }
            num_bytes_popped as i8
        };

        let chunks = buf.into_boxed_slice();
        let state = State {
            chunk_index: 0,
            value: 0,
            range: 255,
            bit_count: -8,
        };
        *self = Self {
            chunks,
            state,
            final_bytes,
            final_bytes_remaining,
        };
        Ok(())
    }

    /// Start a span of reading operations from the buffer, without stopping
    /// when the buffer runs out. For all valid webp images, the buffer will not
    /// run out prematurely. Conversely, if the buffer ends early, the webp image
    /// cannot be correctly decoded and any intermediate results need to be
    /// discarded anyway.
    ///
    /// Each call to `start_accumulated_result` must be followed by a call to
    /// `check` on the *same* `ArithmeticDecoder`.
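    ///
    /// A sketch of the intended pattern (mirroring the tests at the bottom of
    /// this file; `decoder` stands for any initialized `ArithmeticDecoder`):
    ///
    /// ```ignore
    /// let mut acc = decoder.start_accumulated_result();
    /// let flag = decoder.read_flag().or_accumulate(&mut acc);
    /// let bits = decoder.read_literal(3).or_accumulate(&mut acc);
    /// let (flag, bits) = decoder.check(acc, (flag, bits))?;
    /// ```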
    #[inline(always)]
    pub(crate) fn start_accumulated_result(&mut self) -> BitResultAccumulator {
        BitResultAccumulator
    }

    /// Check that the read operations done so far were all valid.
    #[inline(always)]
    pub(crate) fn check<T>(
        &self,
        acc: BitResultAccumulator,
        value_if_not_past_eof: T,
    ) -> Result<T, DecodingError> {
        // The accumulator does not store any state because doing so is
        // too computationally expensive. Passing it around is a bit of a
        // formality (that is optimized out) to ensure we call `check`.
        // Instead we check whether we have read past the end of the file.
        let BitResultAccumulator = acc;

        if self.is_past_eof() {
            Err(DecodingError::BitStreamError)
        } else {
            Ok(value_if_not_past_eof)
        }
    }

    fn keep_accumulating<T>(
        &self,
        acc: BitResultAccumulator,
        value_if_not_past_eof: T,
    ) -> BitResult<T> {
        // The BitResult will be checked later by a different accumulator.
        // Because it does not carry state, that is fine.
        let BitResultAccumulator = acc;

        BitResult::ok(value_if_not_past_eof)
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_bool(&mut self, probability: u8) -> BitResult<bool> {
        if let Some(b) = self.fast().read_bool(probability) {
            return BitResult::ok(b);
        }

        self.cold_read_bool(probability)
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_flag(&mut self) -> BitResult<bool> {
        if let Some(b) = self.fast().read_flag() {
            return BitResult::ok(b);
        }

        self.cold_read_flag()
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_literal(&mut self, n: u8) -> BitResult<u8> {
        if let Some(v) = self.fast().read_literal(n) {
            return BitResult::ok(v);
        }

        self.cold_read_literal(n)
    }

    // Do not inline this because inlining seems to worsen performance.
    #[inline(never)]
    pub(crate) fn read_optional_signed_value(&mut self, n: u8) -> BitResult<i32> {
        if let Some(v) = self.fast().read_optional_signed_value(n) {
            return BitResult::ok(v);
        }

        self.cold_read_optional_signed_value(n)
    }

    // This is generic and inlined just to skip the first bounds check.
    #[inline]
    pub(crate) fn read_with_tree<const N: usize>(&mut self, tree: &[TreeNode; N]) -> BitResult<i8> {
        let first_node = tree[0];
        self.read_with_tree_with_first_node(tree, first_node)
    }

    // Do not inline this because inlining significantly worsens performance.
    #[inline(never)]
    pub(crate) fn read_with_tree_with_first_node(
        &mut self,
        tree: &[TreeNode],
        first_node: TreeNode,
    ) -> BitResult<i8> {
        if let Some(v) = self.fast().read_with_tree(tree, first_node) {
            return BitResult::ok(v);
        }

        self.cold_read_with_tree(tree, usize::from(first_node.index))
    }

    // As a similar (but different) speedup to BitResult, the FastDecoder reads
    // bits under an assumption and validates it at the end.
    //
    // The idea here is that for normal-sized webp images, the vast majority
    // of bits are somewhere other than in the last four bytes. Therefore we
    // can pretend the buffer has infinite size. After we are done reading,
    // we check whether we actually read past the end of `self.chunks`.
    // If so, we backtrack (or rather we discard `uncommitted_state`)
    // and try again with the slow approach. This might mean doing double work
    // for those last few bytes (in fact we even keep retrying the fast method
    // to save an if-statement), but we more than make up for that by speeding
    // up reads from the other thousands or millions of bytes.
    fn fast(&mut self) -> FastDecoder<'_> {
        FastDecoder {
            chunks: &self.chunks,
            uncommitted_state: self.state,
            save_state: &mut self.state,
        }
    }

    const FINAL_BYTES_REMAINING_EOF: i8 = -0xE;

    fn load_from_final_bytes(&mut self) {
        match self.final_bytes_remaining {
            1.. => {
                self.final_bytes_remaining -= 1;
                let byte = self.final_bytes[0];
                self.final_bytes.rotate_left(1);
                self.state.value <<= 8;
                self.state.value |= u64::from(byte);
                self.state.bit_count += 8;
            }
            0 => {
                // libwebp seems to (sometimes?) allow bitstreams that read one byte past the end.
                // This replicates that logic.
                self.final_bytes_remaining -= 1;
                self.state.value <<= 8;
                self.state.bit_count += 8;
            }
            _ => {
                self.final_bytes_remaining = Self::FINAL_BYTES_REMAINING_EOF;
            }
        }
    }

    fn is_past_eof(&self) -> bool {
        self.final_bytes_remaining == Self::FINAL_BYTES_REMAINING_EOF
    }

    fn cold_read_bit(&mut self, probability: u8) -> BitResult<bool> {
        if self.state.bit_count < 0 {
            if let Some(chunk) = self.chunks.get(self.state.chunk_index).copied() {
                let v = u32::from_be_bytes(chunk);
                self.state.chunk_index += 1;
                self.state.value <<= 32;
                self.state.value |= u64::from(v);
                self.state.bit_count += 32;
            } else {
                self.load_from_final_bytes();
                if self.is_past_eof() {
                    return BitResult::err();
                }
            }
        }
        debug_assert!(self.state.bit_count >= 0);

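        // `split` carves the current range in proportion to `probability`
        // (out of 256); the `false` branch below keeps `split` as its new
        // range. As a worked illustration: with `range == 255` and
        // `probability == 128`, `split == 1 + ((254 * 128) >> 8) == 128`,
        // i.e. roughly half the range.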
        let probability = u32::from(probability);
        let split = 1 + (((self.state.range - 1) * probability) >> 8);
        let bigsplit = u64::from(split) << self.state.bit_count;

        let retval = if let Some(new_value) = self.state.value.checked_sub(bigsplit) {
            self.state.range -= split;
            self.state.value = new_value;
            true
        } else {
            self.state.range = split;
            false
        };
        debug_assert!(self.state.range > 0);

        // Compute shift required to satisfy `self.state.range >= 128`.
        // Apply that shift to `self.state.range` and `self.state.bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `self.state.range` which is a `u32`.
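        //
        // For example: if `range` has dropped to 35, `leading_zeros()` is 26,
        // so `shift == 2` and `range` becomes 140, back in the `128..=255`
        // window.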
        let shift = self.state.range.leading_zeros().saturating_sub(24);
        self.state.range <<= shift;
        self.state.bit_count -= shift as i32;
        debug_assert!(self.state.range >= 128);

        BitResult::ok(retval)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_bool(&mut self, probability: u8) -> BitResult<bool> {
        self.cold_read_bit(probability)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_flag(&mut self) -> BitResult<bool> {
        self.cold_read_bit(128)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_literal(&mut self, n: u8) -> BitResult<u8> {
        let mut v = 0u8;
        let mut res = self.start_accumulated_result();

        for _ in 0..n {
            let b = self.cold_read_flag().or_accumulate(&mut res);
            v = (v << 1) + u8::from(b);
        }

        self.keep_accumulating(res, v)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_optional_signed_value(&mut self, n: u8) -> BitResult<i32> {
        let mut res = self.start_accumulated_result();
        let flag = self.cold_read_flag().or_accumulate(&mut res);
        if !flag {
            // We should not read further bits if the flag is not set.
            return self.keep_accumulating(res, 0);
        }
        let magnitude = self.cold_read_literal(n).or_accumulate(&mut res);
        let sign = self.cold_read_flag().or_accumulate(&mut res);

        let value = if sign {
            -i32::from(magnitude)
        } else {
            i32::from(magnitude)
        };
        self.keep_accumulating(res, value)
    }

    #[cold]
    #[inline(never)]
    fn cold_read_with_tree(&mut self, tree: &[TreeNode], start: usize) -> BitResult<i8> {
        let mut index = start;
        let mut res = self.start_accumulated_result();

        loop {
            let node = tree[index];
            let prob = node.prob;
            let b = self.cold_read_bit(prob).or_accumulate(&mut res);
            let t = if b { node.right } else { node.left };
            let new_index = usize::from(t);
            if new_index < tree.len() {
                index = new_index;
            } else {
                let value = TreeNode::value_from_branch(t);
                return self.keep_accumulating(res, value);
            }
        }
    }
}

impl FastDecoder<'_> {
    fn commit_if_valid<T>(self, value_if_not_past_eof: T) -> Option<T> {
        // If `chunk_index > self.chunks.len()`, it means we used zeroes
        // instead of an actual chunk and `value_if_not_past_eof` is nonsense.
        if self.uncommitted_state.chunk_index <= self.chunks.len() {
            *self.save_state = self.uncommitted_state;
            Some(value_if_not_past_eof)
        } else {
            None
        }
    }

    fn read_bool(mut self, probability: u8) -> Option<bool> {
        let bit = self.fast_read_bit(probability);
        self.commit_if_valid(bit)
    }

    fn read_flag(mut self) -> Option<bool> {
        let value = self.fast_read_flag();
        self.commit_if_valid(value)
    }

    fn read_literal(mut self, n: u8) -> Option<u8> {
        let value = self.fast_read_literal(n);
        self.commit_if_valid(value)
    }

    fn read_optional_signed_value(mut self, n: u8) -> Option<i32> {
        let flag = self.fast_read_flag();
        if !flag {
            // We should not read further bits if the flag is not set.
            return self.commit_if_valid(0);
        }
        let magnitude = self.fast_read_literal(n);
        let sign = self.fast_read_flag();
        let value = if sign {
            -i32::from(magnitude)
        } else {
            i32::from(magnitude)
        };
        self.commit_if_valid(value)
    }

    fn read_with_tree(mut self, tree: &[TreeNode], first_node: TreeNode) -> Option<i8> {
        let value = self.fast_read_with_tree(tree, first_node);
        self.commit_if_valid(value)
    }

    fn fast_read_bit(&mut self, probability: u8) -> bool {
        let State {
            mut chunk_index,
            mut value,
            mut range,
            mut bit_count,
        } = self.uncommitted_state;

        if bit_count < 0 {
            let chunk = self.chunks.get(chunk_index).copied();
            // We ignore invalid data inside the `fast_` functions,
            // but we increase `chunk_index` below, so we can check
            // whether we read invalid data in `commit_if_valid`.
            let chunk = chunk.unwrap_or_default();

            let v = u32::from_be_bytes(chunk);
            chunk_index += 1;
            value <<= 32;
            value |= u64::from(v);
            bit_count += 32;
        }
        debug_assert!(bit_count >= 0);

        let probability = u32::from(probability);
        let split = 1 + (((range - 1) * probability) >> 8);
        let bigsplit = u64::from(split) << bit_count;

        let retval = if let Some(new_value) = value.checked_sub(bigsplit) {
            range -= split;
            value = new_value;
            true
        } else {
            range = split;
            false
        };
        debug_assert!(range > 0);

        // Compute shift required to satisfy `range >= 128`.
        // Apply that shift to `range` and `bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `range` which is a `u32`.
        let shift = range.leading_zeros().saturating_sub(24);
        range <<= shift;
        bit_count -= shift as i32;
        debug_assert!(range >= 128);

        self.uncommitted_state = State {
            chunk_index,
            value,
            range,
            bit_count,
        };
        retval
    }

    fn fast_read_flag(&mut self) -> bool {
        let State {
            mut chunk_index,
            mut value,
            mut range,
            mut bit_count,
        } = self.uncommitted_state;

        if bit_count < 0 {
            let chunk = self.chunks.get(chunk_index).copied();
            // We ignore invalid data inside the `fast_` functions,
            // but we increase `chunk_index` below, so we can check
            // whether we read invalid data in `commit_if_valid`.
            let chunk = chunk.unwrap_or_default();

            let v = u32::from_be_bytes(chunk);
            chunk_index += 1;
            value <<= 32;
            value |= u64::from(v);
            bit_count += 32;
        }
        debug_assert!(bit_count >= 0);

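        // This mirrors `fast_read_bit` with an implied `probability` of 128:
        // e.g. `range == 255` gives `half_range == 127` and `split == 128`,
        // the same as `1 + ((254 * 128) >> 8)`.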
        let half_range = range / 2;
        let split = range - half_range;
        let bigsplit = u64::from(split) << bit_count;

        let retval = if let Some(new_value) = value.checked_sub(bigsplit) {
            range = half_range;
            value = new_value;
            true
        } else {
            range = split;
            false
        };
        debug_assert!(range > 0);

        // Compute shift required to satisfy `range >= 128`.
        // Apply that shift to `range` and `bit_count`.
        //
        // Subtract 24 because we only care about leading zeros in the
        // lowest byte of `range` which is a `u32`.
        let shift = range.leading_zeros().saturating_sub(24);
        range <<= shift;
        bit_count -= shift as i32;
        debug_assert!(range >= 128);

        self.uncommitted_state = State {
            chunk_index,
            value,
            range,
            bit_count,
        };
        retval
    }

    fn fast_read_literal(&mut self, n: u8) -> u8 {
        let mut v = 0u8;
        for _ in 0..n {
            let b = self.fast_read_flag();
            v = (v << 1) + u8::from(b);
        }
        v
    }

    fn fast_read_with_tree(&mut self, tree: &[TreeNode], mut node: TreeNode) -> i8 {
        loop {
            let prob = node.prob;
            let b = self.fast_read_bit(prob);
            let i = if b { node.right } else { node.left };
            let Some(next_node) = tree.get(usize::from(i)) else {
                return TreeNode::value_from_branch(i);
            };
            node = *next_node;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arithmetic_decoder_hello_short() {
        let mut decoder = ArithmeticDecoder::new();
        let data = b"hel";
        let size = data.len();
        let mut buf = vec![[0u8; 4]; 1];
        buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
        decoder.init(buf, size).unwrap();
        let mut res = decoder.start_accumulated_result();
        assert_eq!(false, decoder.read_flag().or_accumulate(&mut res));
        assert_eq!(true, decoder.read_bool(10).or_accumulate(&mut res));
        assert_eq!(false, decoder.read_bool(250).or_accumulate(&mut res));
        assert_eq!(1, decoder.read_literal(1).or_accumulate(&mut res));
        assert_eq!(5, decoder.read_literal(3).or_accumulate(&mut res));
        assert_eq!(64, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(185, decoder.read_literal(8).or_accumulate(&mut res));
        decoder.check(res, ()).unwrap();
    }

    #[test]
    fn test_arithmetic_decoder_hello_long() {
        let mut decoder = ArithmeticDecoder::new();
        let data = b"hello world";
        let size = data.len();
        let mut buf = vec![[0u8; 4]; (size + 3) / 4];
        buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
        decoder.init(buf, size).unwrap();
        let mut res = decoder.start_accumulated_result();
        assert_eq!(false, decoder.read_flag().or_accumulate(&mut res));
        assert_eq!(true, decoder.read_bool(10).or_accumulate(&mut res));
        assert_eq!(false, decoder.read_bool(250).or_accumulate(&mut res));
        assert_eq!(1, decoder.read_literal(1).or_accumulate(&mut res));
        assert_eq!(5, decoder.read_literal(3).or_accumulate(&mut res));
        assert_eq!(64, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(185, decoder.read_literal(8).or_accumulate(&mut res));
        assert_eq!(31, decoder.read_literal(8).or_accumulate(&mut res));
        decoder.check(res, ()).unwrap();
    }

    #[test]
    fn test_arithmetic_decoder_uninit() {
        let mut decoder = ArithmeticDecoder::new();
        let mut res = decoder.start_accumulated_result();
        let _ = decoder.read_flag().or_accumulate(&mut res);
        let result = decoder.check(res, ());
        assert!(result.is_err());
    }
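
    // A small extra sanity check (a sketch): it exercises the
    // `len == 4 * buf.len()` branch of `init` and only verifies that reading
    // a few bits well within the buffer does not trip the end-of-file check;
    // it does not assert specific decoded values.
    #[test]
    fn test_arithmetic_decoder_full_chunk_within_bounds() {
        let mut decoder = ArithmeticDecoder::new();
        let data = b"hell";
        let size = data.len();
        let mut buf = vec![[0u8; 4]; 1];
        buf.as_mut_slice().as_flattened_mut()[..size].copy_from_slice(&data[..]);
        decoder.init(buf, size).unwrap();
        let mut res = decoder.start_accumulated_result();
        // A few reads that stay comfortably within the single 4-byte chunk.
        let _ = decoder.read_flag().or_accumulate(&mut res);
        let _ = decoder.read_bool(10).or_accumulate(&mut res);
        let _ = decoder.read_literal(8).or_accumulate(&mut res);
        assert!(decoder.check(res, ()).is_ok());
    }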
}