zune_jpeg/bitstream.rs
1/*
2 * Copyright (c) 2023.
3 *
4 * This software is free software;
5 *
6 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
7 */
8
9#![allow(
10 clippy::if_not_else,
11 clippy::similar_names,
12 clippy::inline_always,
13 clippy::doc_markdown,
14 clippy::cast_sign_loss,
15 clippy::cast_possible_truncation
16)]
17
18//! This file exposes a single struct that can decode a huffman encoded
19//! Bitstream in a JPEG file
20//!
21//! This code is optimized for speed.
22//! It's meant to be super duper super fast, because everyone else depends on this being fast.
//! It's (annoyingly) serial, hence we can't use parallel bitstreams (it's variable-length coding).
24//!
//! Furthermore, in the case of refills, we have to do bytewise processing because the standard decided
//! that we want to support markers in the middle of streams (seriously, few people use RST markers).
27//!
28//! So we pull in all optimization steps:
//! - use `#[inline(always)]`? ✅,
30//! - pre-execute most common cases ✅,
31//! - add random comments ✅
32//! - fast paths ✅.
33//!
34//! Speed-wise: It is probably the fastest JPEG BitStream decoder to ever sail the seven seas because of
35//! a couple of optimization tricks.
36//! 1. Fast refills from libjpeg-turbo
37//! 2. As few as possible branches in decoder fast paths.
//! 3. Accelerated AC table decoding borrowed from stb_image.h written by Fabian Giesen (@rygorous),
//! improved by me to handle more cases.
//! 4. Safe and extensible routines (e.g. cool ways to eliminate bounds checks)
41//! 5. No unsafe here
42//!
//! Readability comes second (I tried with variable names this time, and we are way better than libjpeg).
44//!
//! Anyway, if you are reading this it means you're cool, and I hope you find whatever part of the code
//! you are looking for (or learn something cool).
47//!
48//! Knock yourself out.
49use alloc::string::ToString;
50use core::cmp::min;
51use alloc::format;
52
53use zune_core::bytestream::{ZByteReader, ZReaderTrait};
54
55use crate::errors::DecodeErrors;
56use crate::huffman::{HuffmanTable, HUFF_LOOKAHEAD};
57use crate::marker::Marker;
58use crate::mcu::DCT_BLOCK;
59use crate::misc::UN_ZIGZAG;
60
// Finish decoding a Huffman symbol whose fast-lookup entry is already stored in `$symbol`.
//
// On entry `$symbol` holds the `$table.lookup` entry for the next `HUFF_LOOKAHEAD` bits.
// The bits above `HUFF_LOOKAHEAD` are the code length. The low `HUFF_LOOKAHEAD` bits
// are the decoded value when the code fits in the lookahead window.
//
// If the code is longer than `HUFF_LOOKAHEAD`, the next 16 bits are compared against
// `$table.maxcode` to find the true length, and the value is read from `$table.values`.
//
// On exit `$symbol` holds the decoded value and the code's bits have been dropped from
// `$stream`. If no length up to 16 matches, this `return`s
// `Err(DecodeErrors::Format(..))` from the *enclosing* function.
macro_rules! decode_huff {
    ($stream:tt,$symbol:tt,$table:tt) => {
        // upper bits of the lookup entry: the code length
        let mut code_length = $symbol >> HUFF_LOOKAHEAD;

        // lower bits of the lookup entry: the decoded value (if resolved)
        ($symbol) &= (1 << HUFF_LOOKAHEAD) - 1;

        if code_length > i32::from(HUFF_LOOKAHEAD)
        {
            // If the symbol cannot be resolved in the first HUFF_LOOKAHEAD bits,
            // it lies somewhere between HUFF_LOOKAHEAD and 16 bits, since JPEG imposes
            // a 16-bit limit. We therefore look 16 bits ahead and try to resolve the
            // symbol starting from 1+HUFF_LOOKAHEAD bits.
            $symbol = ($stream).peek_bits::<16>() as i32;
            // (Credits to Sean T. Barrett's stb library for this optimization)
            // maxcode is pre-shifted to be 16 bits wide, so it has (16-code_length)
            // zeroes at the end and we do not need to shift in the inner loop.
            while code_length < 17{
                if $symbol < $table.maxcode[code_length as usize] {
                    break;
                }
                code_length += 1;
            }

            if code_length == 17{
                // Symbol could not be decoded.
                //
                // Faking zeroes is tempting, but Huffman codes are sensitive:
                // everything after this point is probably corrupt, so report an
                // error instead of continuing.
                return Err(DecodeErrors::Format(format!("Bad Huffman Code 0x{:X}, corrupt JPEG",$symbol)))
            }

            // keep only the top `code_length` bits, i.e. the actual code
            $symbol >>= (16-code_length);
            ($symbol) = i32::from(
                ($table).values
                    [(($symbol + ($table).offset[code_length as usize]) & 0xFF) as usize],
            );
        }
        // drop the bits belonging to the code we just decoded
        ($stream).drop_bits(code_length as u8);
    };
}
104
/// A bit-by-bit reader over the entropy-coded data of a JPEG scan.
///
/// Bytes are shifted into `buffer` most-significant-bit first. `aligned_buffer` holds
/// the unread bits shifted up so that the next bit to read is bit 63.
pub(crate) struct BitStream {
    /// MSB-first bit accumulator: each newly read byte is shifted in at the low end.
    pub buffer: u64,
    /// Top-aligned copy of the unread bits.
    ///
    /// Bit 63 is the next bit in the stream. This is what `peek_bits`, `get_bits`,
    /// `get_bit` and `drop_bits` operate on.
    aligned_buffer: u64,
    /// Number of unread bits currently held in the buffers.
    pub(crate) bits_left: u8,
    /// Marker found while refilling (e.g. RST or EOI), if any.
    pub marker: Option<Marker>,

    /// Progressive decoding: successive approximation high bit (Ah).
    pub successive_high: u8,
    /// Progressive decoding: successive approximation low bit (Al).
    pub successive_low: u8,
    /// Progressive decoding: first coefficient index of the spectral band.
    spec_start: u8,
    /// Progressive decoding: last coefficient index of the spectral band (inclusive).
    spec_end: u8,
    /// Remaining end-of-band run for progressive AC scans.
    pub eob_run: i32,
    /// Count of refill byte reads made while the reader reported end of input.
    pub overread_by: usize,
    /// True if we have seen the end of image marker.
    /// Don't read anything after that.
    pub seen_eoi: bool,
}
131
132impl BitStream {
133 /// Create a new BitStream
134 pub(crate) const fn new() -> BitStream {
135 BitStream {
136 buffer: 0,
137 aligned_buffer: 0,
138 bits_left: 0,
139 marker: None,
140 successive_high: 0,
141 successive_low: 0,
142 spec_start: 0,
143 spec_end: 0,
144 eob_run: 0,
145 overread_by: 0,
146 seen_eoi: false,
147 }
148 }
149
150 /// Create a new Bitstream for progressive decoding
151 #[allow(clippy::redundant_field_names)]
152 pub(crate) fn new_progressive(ah: u8, al: u8, spec_start: u8, spec_end: u8) -> BitStream {
153 BitStream {
154 buffer: 0,
155 aligned_buffer: 0,
156 bits_left: 0,
157 marker: None,
158 successive_high: ah,
159 successive_low: al,
160 spec_start: spec_start,
161 spec_end: spec_end,
162 eob_run: 0,
163 overread_by: 0,
164 seen_eoi: false,
165 }
166 }
167
168 /// Refill the bit buffer by (a maximum of) 32 bits
169 ///
170 /// # Arguments
171 /// - `reader`:`&mut BufReader<R>`: A mutable reference to an underlying
172 /// File/Memory buffer containing a valid JPEG stream
173 ///
174 /// This function will only refill if `self.count` is less than 32
175 #[inline(always)] // to many call sites? ( perf improvement by 4%)
176 pub(crate) fn refill<T>(&mut self, reader: &mut ZByteReader<T>) -> Result<bool, DecodeErrors>
177 where
178 T: ZReaderTrait
179 {
180 /// Macro version of a single byte refill.
181 /// Arguments
182 /// buffer-> our io buffer, because rust macros cannot get values from
183 /// the surrounding environment bits_left-> number of bits left
184 /// to full refill
185 macro_rules! refill {
186 ($buffer:expr,$byte:expr,$bits_left:expr) => {
187 // read a byte from the stream
188 $byte = u64::from(reader.get_u8());
189 self.overread_by += usize::from(reader.eof());
190 // append to the buffer
191 // JPEG is a MSB type buffer so that means we append this
192 // to the lower end (0..8) of the buffer and push the rest bits above..
193 $buffer = ($buffer << 8) | $byte;
194 // Increment bits left
195 $bits_left += 8;
196 // Check for special case of OxFF, to see if it's a stream or a marker
197 if $byte == 0xff {
198 // read next byte
199 let mut next_byte = u64::from(reader.get_u8());
200 // Byte snuffing, if we encounter byte snuff, we skip the byte
201 if next_byte != 0x00 {
202 // skip that byte we read
203 while next_byte == 0xFF {
204 next_byte = u64::from(reader.get_u8());
205 }
206
207 if next_byte != 0x00 {
208 // Undo the byte append and return
209 $buffer >>= 8;
210 $bits_left -= 8;
211
212 if $bits_left != 0 {
213 self.aligned_buffer = $buffer << (64 - $bits_left);
214 }
215
216 self.marker = Marker::from_u8(next_byte as u8);
217 if next_byte == 0xD9 {
218 // special handling for eoi, fill some bytes,even if its zero,
219 // removes some panics
220 self.buffer <<= 8;
221 self.bits_left += 8;
222 self.aligned_buffer = self.buffer << (64 - self.bits_left);
223
224 }
225
226 return Ok(false);
227 }
228 }
229 }
230 };
231 }
232
233
234 // 32 bits is enough for a decode(16 bits) and receive_extend(max 16 bits)
235 if self.bits_left < 32 {
236 if self.marker.is_some() || self.overread_by > 0 || self.seen_eoi{
237 // found a marker, or we are in EOI
238 // also we are in over-reading mode, where we fill it with zeroes
239
240 // fill with zeroes
241 self.buffer <<= 32;
242 self.bits_left += 32;
243 self.aligned_buffer = self.buffer << (64 - self.bits_left);
244 return Ok(true);
245 }
246
247
248 // we optimize for the case where we don't have 255 in the stream and have 4 bytes left
249 // as it is the common case
250 //
251 // so we always read 4 bytes, if read_fixed_bytes errors out, the cursor is
252 // guaranteed not to advance in case of failure (is this true), so
253 // we revert the read later on (if we have 255), if this fails, we use the normal
254 // byte at a time read
255
256 if let Ok(bytes) = reader.get_fixed_bytes_or_err::<4>() {
257 // we have 4 bytes to spare, read the 4 bytes into a temporary buffer
258 // create buffer
259 let msb_buf = u32::from_be_bytes(bytes);
260 // check if we have 0xff
261 if !has_byte(msb_buf, 255) {
262 self.bits_left += 32;
263 self.buffer <<= 32;
264 self.buffer |= u64::from(msb_buf);
265 self.aligned_buffer = self.buffer << (64 - self.bits_left);
266 return Ok(true);
267 }
268
269 reader.rewind(4);
270 }
271 // This serves two reasons,
272 // 1: Make clippy shut up
273 // 2: Favour register reuse
274 let mut byte;
275 // 4 refills, if all succeed the stream should contain enough bits to decode a
276 // value
277 refill!(self.buffer, byte, self.bits_left);
278 refill!(self.buffer, byte, self.bits_left);
279 refill!(self.buffer, byte, self.bits_left);
280 refill!(self.buffer, byte, self.bits_left);
281 // Construct an MSB buffer whose top bits are the bitstream we are currently holding.
282 self.aligned_buffer = self.buffer << (64 - self.bits_left);
283 }
284 return Ok(true);
285 }
286 /// Decode the DC coefficient in a MCU block.
287 ///
288 /// The decoded coefficient is written to `dc_prediction`
289 ///
290 #[allow(
291 clippy::cast_possible_truncation,
292 clippy::cast_sign_loss,
293 clippy::unwrap_used
294 )]
295 #[inline(always)]
296 fn decode_dc<T>(
297 &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, dc_prediction: &mut i32
298 ) -> Result<bool, DecodeErrors>
299 where
300 T: ZReaderTrait
301 {
302 let (mut symbol, r);
303
304 if self.bits_left < 32 {
305 self.refill(reader)?;
306 };
307 // look a head HUFF_LOOKAHEAD bits into the bitstream
308 symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
309 symbol = dc_table.lookup[symbol as usize];
310
311 decode_huff!(self, symbol, dc_table);
312
313 if symbol != 0 {
314 r = self.get_bits(symbol as u8);
315 symbol = huff_extend(r, symbol);
316 }
317 // Update DC prediction
318 *dc_prediction = dc_prediction.wrapping_add(symbol);
319
320 return Ok(true);
321 }
322
323 /// Decode a Minimum Code Unit(MCU) as quickly as possible
324 ///
325 /// # Arguments
326 /// - reader: The bitstream from where we read more bits.
327 /// - dc_table: The Huffman table used to decode the DC coefficient
328 /// - ac_table: The Huffman table used to decode AC values
329 /// - block: A memory region where we will write out the decoded values
330 /// - DC prediction: Last DC value for this component
331 ///
332 #[allow(
333 clippy::many_single_char_names,
334 clippy::cast_possible_truncation,
335 clippy::cast_sign_loss
336 )]
337 #[inline(never)]
338 pub fn decode_mcu_block<T>(
339 &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, ac_table: &HuffmanTable,
340 qt_table: &[i32; DCT_BLOCK], block: &mut [i32; 64], dc_prediction: &mut i32
341 ) -> Result<(), DecodeErrors>
342 where
343 T: ZReaderTrait
344 {
345 // Get fast AC table as a reference before we enter the hot path
346 let ac_lookup = ac_table.ac_lookup.as_ref().unwrap();
347
348 let (mut symbol, mut r, mut fast_ac);
349 // Decode AC coefficients
350 let mut pos: usize = 1;
351
352 // decode DC, dc prediction will contain the value
353 self.decode_dc(reader, dc_table, dc_prediction)?;
354
355 // set dc to be the dc prediction.
356 block[0] = *dc_prediction * qt_table[0];
357
358 while pos < 64 {
359 self.refill(reader)?;
360 symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
361 fast_ac = ac_lookup[symbol as usize];
362 symbol = ac_table.lookup[symbol as usize];
363
364 if fast_ac != 0 {
365 // FAST AC path
366 pos += ((fast_ac >> 4) & 15) as usize; // run
367 let t_pos = UN_ZIGZAG[min(pos, 63)] & 63;
368
369 block[t_pos] = i32::from(fast_ac >> 8) * (qt_table[t_pos]); // Value
370 self.drop_bits((fast_ac & 15) as u8);
371 pos += 1;
372 } else {
373 decode_huff!(self, symbol, ac_table);
374
375 r = symbol >> 4;
376 symbol &= 15;
377
378 if symbol != 0 {
379 pos += r as usize;
380 r = self.get_bits(symbol as u8);
381 symbol = huff_extend(r, symbol);
382 let t_pos = UN_ZIGZAG[pos & 63] & 63;
383
384 block[t_pos] = symbol * qt_table[t_pos];
385
386 pos += 1;
387 } else if r != 15 {
388 return Ok(());
389 } else {
390 pos += 16;
391 }
392 }
393 }
394 return Ok(());
395 }
396
397 /// Peek `look_ahead` bits ahead without discarding them from the buffer
398 #[inline(always)]
399 #[allow(clippy::cast_possible_truncation)]
400 const fn peek_bits<const LOOKAHEAD: u8>(&self) -> i32 {
401 (self.aligned_buffer >> (64 - LOOKAHEAD)) as i32
402 }
403
404 /// Discard the next `N` bits without checking
405 #[inline]
406 fn drop_bits(&mut self, n: u8) {
407 debug_assert!(self.bits_left >= n);
408 //self.bits_left -= n;
409 self.bits_left = self.bits_left.saturating_sub(n);
410 self.aligned_buffer <<= n;
411 }
412
413 /// Read `n_bits` from the buffer and discard them
414 #[inline(always)]
415 #[allow(clippy::cast_possible_truncation)]
416 fn get_bits(&mut self, n_bits: u8) -> i32 {
417 let mask = (1_u64 << n_bits) - 1;
418
419 self.aligned_buffer = self.aligned_buffer.rotate_left(u32::from(n_bits));
420 let bits = (self.aligned_buffer & mask) as i32;
421 self.bits_left = self.bits_left.wrapping_sub(n_bits);
422 bits
423 }
424
425 /// Decode a DC block
426 #[allow(clippy::cast_possible_truncation)]
427 #[inline]
428 pub(crate) fn decode_prog_dc_first<T>(
429 &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, block: &mut i16,
430 dc_prediction: &mut i32
431 ) -> Result<(), DecodeErrors>
432 where
433 T: ZReaderTrait
434 {
435 self.decode_dc(reader, dc_table, dc_prediction)?;
436 *block = (*dc_prediction as i16).wrapping_mul(1_i16 << self.successive_low);
437 return Ok(());
438 }
439 #[inline]
440 pub(crate) fn decode_prog_dc_refine<T>(
441 &mut self, reader: &mut ZByteReader<T>, block: &mut i16
442 ) -> Result<(), DecodeErrors>
443 where
444 T: ZReaderTrait
445 {
446 // refinement scan
447 if self.bits_left < 1 {
448 self.refill(reader)?;
449 }
450
451 if self.get_bit() == 1 {
452 *block = block.wrapping_add(1 << self.successive_low);
453 }
454
455 Ok(())
456 }
457
458 /// Get a single bit from the bitstream
459 fn get_bit(&mut self) -> u8 {
460 let k = (self.aligned_buffer >> 63) as u8;
461 // discard a bit
462 self.drop_bits(1);
463 return k;
464 }
465 pub(crate) fn decode_mcu_ac_first<T>(
466 &mut self, reader: &mut ZByteReader<T>, ac_table: &HuffmanTable, block: &mut [i16; 64]
467 ) -> Result<bool, DecodeErrors>
468 where
469 T: ZReaderTrait
470 {
471 let shift = self.successive_low;
472 let fast_ac = ac_table.ac_lookup.as_ref().unwrap();
473
474 let mut k = self.spec_start as usize;
475 let (mut symbol, mut r, mut fac);
476
477 // EOB runs are handled in mcu_prog.rs
478 'block: loop {
479 self.refill(reader)?;
480
481 symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
482 fac = fast_ac[symbol as usize];
483 symbol = ac_table.lookup[symbol as usize];
484
485 if fac != 0 {
486 // fast ac path
487 k += ((fac >> 4) & 15) as usize; // run
488 block[UN_ZIGZAG[min(k, 63)] & 63] = (fac >> 8).wrapping_mul(1 << shift); // value
489 self.drop_bits((fac & 15) as u8);
490 k += 1;
491 } else {
492 decode_huff!(self, symbol, ac_table);
493
494 r = symbol >> 4;
495 symbol &= 15;
496
497 if symbol != 0 {
498 k += r as usize;
499 r = self.get_bits(symbol as u8);
500 symbol = huff_extend(r, symbol);
501 block[UN_ZIGZAG[k & 63] & 63] = (symbol as i16).wrapping_mul(1 << shift);
502 k += 1;
503 } else {
504 if r != 15 {
505 self.eob_run = 1 << r;
506 self.eob_run += self.get_bits(r as u8);
507 self.eob_run -= 1;
508 break;
509 }
510
511 k += 16;
512 }
513 }
514
515 if k > self.spec_end as usize {
516 break 'block;
517 }
518 }
519 return Ok(true);
520 }
521 #[allow(clippy::too_many_lines, clippy::op_ref)]
522 pub(crate) fn decode_mcu_ac_refine<T>(
523 &mut self, reader: &mut ZByteReader<T>, table: &HuffmanTable, block: &mut [i16; 64]
524 ) -> Result<bool, DecodeErrors>
525 where
526 T: ZReaderTrait
527 {
528 let bit = (1 << self.successive_low) as i16;
529
530 let mut k = self.spec_start;
531 let (mut symbol, mut r);
532
533 if self.eob_run == 0 {
534 'no_eob: loop {
535 // Decode a coefficient from the bit stream
536 self.refill(reader)?;
537
538 symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
539 symbol = table.lookup[symbol as usize];
540
541 decode_huff!(self, symbol, table);
542
543 r = symbol >> 4;
544 symbol &= 15;
545
546 if symbol == 0 {
547 if r != 15 {
548 // EOB run is 2^r + bits
549 self.eob_run = 1 << r;
550 self.eob_run += self.get_bits(r as u8);
551 // EOB runs are handled by the eob logic
552 break 'no_eob;
553 }
554 } else {
555 if symbol != 1 {
556 return Err(DecodeErrors::HuffmanDecode(
557 "Bad Huffman code, corrupt JPEG?".to_string()
558 ));
559 }
560 // get sign bit
561 // We assume we have enough bits, which should be correct for sane images
562 // since we refill by 32 above
563 if self.get_bit() == 1 {
564 symbol = i32::from(bit);
565 } else {
566 symbol = i32::from(-bit);
567 }
568 }
569
570 // Advance over already nonzero coefficients appending
571 // correction bits to the non-zeroes.
572 // A correction bit is 1 if the absolute value of the coefficient must be increased
573
574 if k <= self.spec_end {
575 'advance_nonzero: loop {
576 let coefficient = &mut block[UN_ZIGZAG[k as usize & 63] & 63];
577
578 if *coefficient != 0 {
579 if self.get_bit() == 1 && (*coefficient & bit) == 0 {
580 if *coefficient >= 0 {
581 *coefficient += bit;
582 } else {
583 *coefficient -= bit;
584 }
585 }
586
587 if self.bits_left < 1 {
588 self.refill(reader)?;
589 }
590 } else {
591 r -= 1;
592
593 if r < 0 {
594 // reached target zero coefficient.
595 break 'advance_nonzero;
596 }
597 };
598
599 if k == self.spec_end {
600 break 'advance_nonzero;
601 }
602
603 k += 1;
604 }
605 }
606
607 if symbol != 0 {
608 let pos = UN_ZIGZAG[k as usize & 63];
609 // output new non-zero coefficient.
610 block[pos & 63] = symbol as i16;
611 }
612
613 k += 1;
614
615 if k > self.spec_end {
616 break 'no_eob;
617 }
618 }
619 }
620 if self.eob_run > 0 {
621 // only run if block does not consists of purely zeroes
622 if &block[1..] != &[0; 63] {
623 self.refill(reader)?;
624
625 while k <= self.spec_end {
626 let coefficient = &mut block[UN_ZIGZAG[k as usize & 63] & 63];
627
628 if *coefficient != 0 && self.get_bit() == 1 {
629 // check if we already modified it, if so do nothing, otherwise
630 // append the correction bit.
631 if (*coefficient & bit) == 0 {
632 if *coefficient >= 0 {
633 *coefficient = coefficient.wrapping_add(bit);
634 } else {
635 *coefficient = coefficient.wrapping_sub(bit);
636 }
637 }
638 }
639 if self.bits_left < 1 {
640 // refill at the last possible moment
641 self.refill(reader)?;
642 }
643 k += 1;
644 }
645 }
646 // count a block completed in EOB run
647 self.eob_run -= 1;
648 }
649 return Ok(true);
650 }
651
652 pub fn update_progressive_params(&mut self, ah: u8, al: u8, spec_start: u8, spec_end: u8) {
653 self.successive_high = ah;
654 self.successive_low = al;
655 self.spec_start = spec_start;
656 self.spec_end = spec_end;
657 }
658
659 /// Reset the stream if we have a restart marker
660 ///
661 /// Restart markers indicate drop those bits in the stream and zero out
662 /// everything
663 #[cold]
664 pub fn reset(&mut self) {
665 self.bits_left = 0;
666 self.marker = None;
667 self.buffer = 0;
668 self.aligned_buffer = 0;
669 self.eob_run = 0;
670 }
671}
672
/// Do the equivalent of JPEG `HUFF_EXTEND`: turn the `s`-bit value `x` into a signed value.
///
/// If `x < 2^(s-1)`, the result is `x - (2^s - 1)` (negative). Otherwise `x` is
/// returned unchanged. Branch-free; `s` must be at least 1.
#[inline(always)]
fn huff_extend(x: i32, s: i32) -> i32 {
    let half_range = 1 << (s - 1);
    // arithmetic shift of the sign: all ones when x < half_range, zero otherwise
    let below_half = (x - half_range) >> 31;
    // (-1 << s) + 1 == -(2^s - 1)
    let negative_offset = ((-1) << s) + 1;
    x + (below_half & negative_offset)
}
679
/// Return `true` if any of the four bytes of `v` is zero.
///
/// Branch-free test retrieved from the Stanford bit-twiddling hacks
/// (<https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord>).
const fn has_zero(v: u32) -> bool {
    const LOW_SEVEN: u32 = 0x7F7F_7F7F;
    // For each byte: adding 0x7F to its low 7 bits sets bit 7 iff those bits are
    // nonzero; OR-ing in `v` covers bytes whose top bit is set, and OR-ing in
    // LOW_SEVEN fills the rest. A byte ends up 0xFF iff it was nonzero.
    let spread = ((v & LOW_SEVEN) + LOW_SEVEN) | v | LOW_SEVEN;
    spread != u32::MAX
}

/// Return `true` if any of the four bytes of `b` equals `val`.
///
/// Uses the same bit-twiddling-hacks trick: XOR zeroes every matching byte.
const fn has_byte(b: u32, val: u8) -> bool {
    // broadcast `val` into all four byte lanes (0x0101_0101 * val)
    let repeated = 0x0101_0101_u32 * (val as u32);
    has_zero(b ^ repeated)
}
691
692// mod tests {
693// use zune_core::bytestream::ZCursor;
694// use crate::JpegDecoder;
695//
696// #[test]
697// fn test_image() {
698// let img = "/Users/etemesi/Downloads/PHO00008.JPG";
699// let data = std::fs::read(img).unwrap();
700// let mut decoder = JpegDecoder::new(ZCursor::new(&data[..]));
701// decoder.decode().unwrap();
702// }
703// }