zune_jpeg/
mcu_prog.rs

1/*
2 * Copyright (c) 2023.
3 *
4 * This software is free software;
5 *
6 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
7 */
8
9//!Routines for progressive decoding
10/*
11This file is needlessly complicated,
12
13It is that way to ensure we don't burn memory anyhow
14
15Memory is a scarce resource in some environments, I would like this to be viable
16in such environments
17
18Half of the complexity comes from the jpeg spec, because progressive decoding,
19is one hell of a ride.
20
21*/
22use alloc::string::ToString;
23use alloc::vec::Vec;
24use alloc::{format, vec};
25use core::cmp::min;
26use zune_core::bytestream::{ZByteReader, ZReaderTrait};
27use zune_core::colorspace::ColorSpace;
28use zune_core::log::{debug, error, warn};
29
30use crate::bitstream::BitStream;
31use crate::components::{ComponentID, SampleRatios};
32use crate::decoder::{JpegDecoder, MAX_COMPONENTS};
33use crate::errors::DecodeErrors;
34use crate::headers::{parse_sos};
35use crate::marker::Marker;
36use crate::mcu::DCT_BLOCK;
37use crate::misc::{calculate_padded_width, setup_component_params};
38
39impl<T: ZReaderTrait> JpegDecoder<T> {
40    /// Decode a progressive image
41    ///
42    /// This routine decodes a progressive image, stopping if it finds any error.
43    #[allow(
44        clippy::needless_range_loop,
45        clippy::cast_sign_loss,
46        clippy::redundant_else,
47        clippy::too_many_lines
48    )]
49    #[inline(never)]
50    pub(crate) fn decode_mcu_ycbcr_progressive(
51        &mut self, pixels: &mut [u8],
52    ) -> Result<(), DecodeErrors> {
53        setup_component_params(self)?;
54
55        let mut mcu_height;
56
57        // memory location for decoded pixels for components
58        let mut block: [Vec<i16>; MAX_COMPONENTS] = [vec![], vec![], vec![], vec![]];
59        let mut mcu_width;
60
61        let mut seen_scans = 1;
62
63        if self.input_colorspace == ColorSpace::Luma && self.is_interleaved {
64            warn!("Grayscale image with down-sampled component, resetting component details");
65            self.reset_params();
66        }
67
68        if self.is_interleaved {
69            // this helps us catch component errors.
70            self.set_upsampling()?;
71        }
72        if self.is_interleaved {
73            mcu_width = self.mcu_x;
74            mcu_height = self.mcu_y;
75        } else {
76            mcu_width = (self.info.width as usize + 7) / 8;
77            mcu_height = (self.info.height as usize + 7) / 8;
78        }
79        if self.is_interleaved
80            && self.input_colorspace.num_components() > 1
81            && self.options.jpeg_get_out_colorspace().num_components() == 1
82            && (self.sub_sample_ratio == SampleRatios::V
83                || self.sub_sample_ratio == SampleRatios::HV)
84        {
85            // For a specific set of images, e.g interleaved,
86            // when converting from YcbCr to grayscale, we need to
87            // take into account mcu height since the MCU decoding needs to take
88            // it into account for padding purposes and the post processor
89            // parses two rows per mcu width.
90            //
91            // set coeff to be 2 to ensure that we increment two rows
92            // for every mcu processed also
93            mcu_height *= self.v_max;
94            mcu_height /= self.h_max;
95            self.coeff = 2;
96        }
97
98        mcu_width *= 64;
99
100
101        for i in 0..self.input_colorspace.num_components() {
102            let comp = &self.components[i];
103            let len = mcu_width * comp.vertical_sample * comp.horizontal_sample * mcu_height;
104
105            block[i] = vec![0; len];
106        }
107
108        let mut stream = BitStream::new_progressive(
109            self.succ_high,
110            self.succ_low,
111            self.spec_start,
112            self.spec_end
113        );
114
115        // there are multiple scans in the stream, this should resolve the first scan
116        self.parse_entropy_coded_data(&mut stream, &mut block)?;
117
118        // extract marker
119        let mut marker = stream
120            .marker
121            .take()
122            .ok_or(DecodeErrors::FormatStatic("Marker missing where expected"))?;
123
124        // if marker is EOI, we are done, otherwise continue scanning.
125        //
126        // In case we have a premature image, we print a warning or return
127        // an error, depending on the strictness of the decoder, so there
128        // is that logic to handle too
129        'eoi: while marker != Marker::EOI {
130            match marker {
131                Marker::SOS => {
132                    parse_sos(self)?;
133
134                    stream.update_progressive_params(
135                        self.succ_high,
136                        self.succ_low,
137                        self.spec_start,
138                        self.spec_end
139                    );
140                    // after every SOS, marker, parse data for that scan.
141                    self.parse_entropy_coded_data(&mut stream, &mut block)?;
142                    // extract marker, might either indicate end of image or we continue
143                    // scanning(hence the continue statement to determine).
144                    match get_marker(&mut self.stream, &mut stream) {
145                        Ok(marker_n) => {
146                            marker = marker_n;
147                            seen_scans += 1;
148                            if seen_scans > self.options.jpeg_get_max_scans() {
149                                return Err(DecodeErrors::Format(format!(
150                                    "Too many scans, exceeded limit of {}",
151                                    self.options.jpeg_get_max_scans()
152                                )));
153                            }
154
155                            stream.reset();
156                            continue 'eoi;
157                        }
158                        Err(msg) => {
159                            if self.options.get_strict_mode() {
160                                return Err(msg);
161                            }
162                            error!("{:?}", msg);
163                            break 'eoi;
164                        }
165                    }
166                }
167                Marker::RST(_n) => {
168                    self.handle_rst(&mut stream)?;
169                }
170                _ => {
171                    self.parse_marker_inner(marker)?;
172                }
173            }
174
175            match get_marker(&mut self.stream, &mut stream) {
176                Ok(marker_n) => {
177                    marker = marker_n;
178                }
179                Err(e) => {
180                    if self.options.get_strict_mode() {
181                        return Err(e);
182                    }
183                    error!("{}", e);
184                }
185            }
186        }
187
188        self.finish_progressive_decoding(&block, mcu_width, pixels)
189    }
190
191    /// Reset progressive parameters
192    fn reset_prog_params(&mut self, stream: &mut BitStream) {
193        stream.reset();
194        self.components.iter_mut().for_each(|x| x.dc_pred = 0);
195
196        // Also reset JPEG restart intervals
197        self.todo = if self.restart_interval != 0 { self.restart_interval } else { usize::MAX };
198    }
199
200    #[allow(clippy::too_many_lines, clippy::cast_sign_loss)]
201    fn parse_entropy_coded_data(
202        &mut self, stream: &mut BitStream, buffer: &mut [Vec<i16>; MAX_COMPONENTS]
203    ) -> Result<(), DecodeErrors> {
204        self.reset_prog_params(stream);
205
206        if usize::from(self.num_scans) > self.input_colorspace.num_components() {
207            return Err(DecodeErrors::Format(format!(
208                "Number of scans {} cannot be greater than number of components, {}",
209                self.num_scans,
210                self.input_colorspace.num_components()
211            )));
212        }
213        if self.num_scans == 1 {
214            // Safety checks
215            if self.spec_end != 0 && self.spec_start == 0 {
216                return Err(DecodeErrors::FormatStatic(
217                    "Can't merge DC and AC corrupt jpeg"
218                ));
219            }
220            // non interleaved data, process one block at a time in trivial scanline order
221
222            let k = self.z_order[0];
223
224            if k >= self.components.len() {
225                return Err(DecodeErrors::Format(format!(
226                    "Cannot find component {k}, corrupt image"
227                )));
228            }
229
230            let (mcu_width, mcu_height);
231
232            if self.components[k].component_id == ComponentID::Y
233                && (self.components[k].vertical_sample != 1
234                    || self.components[k].horizontal_sample != 1)
235                || !self.is_interleaved
236            {
237                // For Y channel  or non interleaved scans ,
238                // mcu's is the image dimensions divided by 8
239                mcu_width = ((self.info.width + 7) / 8) as usize;
240                mcu_height = ((self.info.height + 7) / 8) as usize;
241            } else {
242                // For other channels, in an interleaved mcu, number of MCU's
243                // are determined by some weird maths done in headers.rs->parse_sos()
244                mcu_width = self.mcu_x;
245                mcu_height = self.mcu_y;
246            }
247
248            for i in 0..mcu_height {
249                for j in 0..mcu_width {
250                    if self.spec_start != 0 && self.succ_high == 0 && stream.eob_run > 0 {
251                        // handle EOB runs here.
252                        stream.eob_run -= 1;
253                    } else {
254                        let start = 64 * (j + i * (self.components[k].width_stride / 8));
255
256                        let data: &mut [i16; 64] = buffer
257                            .get_mut(k)
258                            .unwrap()
259                            .get_mut(start..start + 64)
260                            .unwrap()
261                            .try_into()
262                            .unwrap();
263
264                        if self.spec_start == 0 {
265                            let pos = self.components[k].dc_huff_table & (MAX_COMPONENTS - 1);
266                            let dc_table = self
267                                .dc_huffman_tables
268                                .get(pos)
269                                .ok_or(DecodeErrors::FormatStatic(
270                                    "No huffman table for DC component"
271                                ))?
272                                .as_ref()
273                                .ok_or(DecodeErrors::FormatStatic(
274                                    "Huffman table at index  {} not initialized"
275                                ))?;
276
277                            let dc_pred = &mut self.components[k].dc_pred;
278
279                            if self.succ_high == 0 {
280                                // first scan for this mcu
281                                stream.decode_prog_dc_first(
282                                    &mut self.stream,
283                                    dc_table,
284                                    &mut data[0],
285                                    dc_pred
286                                )?;
287                            } else {
288                                // refining scans for this MCU
289                                stream.decode_prog_dc_refine(&mut self.stream, &mut data[0])?;
290                            }
291                        } else {
292                            let pos = self.components[k].ac_huff_table;
293                            let ac_table = self
294                                .ac_huffman_tables
295                                .get(pos)
296                                .ok_or_else(|| {
297                                    DecodeErrors::Format(format!(
298                                        "No huffman table for component:{pos}"
299                                    ))
300                                })?
301                                .as_ref()
302                                .ok_or_else(|| {
303                                    DecodeErrors::Format(format!(
304                                        "Huffman table at index  {pos} not initialized"
305                                    ))
306                                })?;
307
308                            if self.succ_high == 0 {
309                                debug_assert!(stream.eob_run == 0, "EOB run is not zero");
310
311                                stream.decode_mcu_ac_first(&mut self.stream, ac_table, data)?;
312                            } else {
313                                // refinement scan
314                                stream.decode_mcu_ac_refine(&mut self.stream, ac_table, data)?;
315                            }
316                        }
317                    }
318
319                    // + EOB and investigate effect.
320                    self.todo -= 1;
321
322                    self.handle_rst_main(stream)?;
323                }
324            }
325        } else {
326            if self.spec_end != 0 {
327                return Err(DecodeErrors::HuffmanDecode(
328                    "Can't merge dc and AC corrupt jpeg".to_string()
329                ));
330            }
331            // process scan n elements in order
332
333            // Do the error checking with allocs here.
334            // Make the one in the inner loop free of allocations.
335            for k in 0..self.num_scans {
336                let n = self.z_order[k as usize];
337
338                if n >= self.components.len() {
339                    return Err(DecodeErrors::Format(format!(
340                        "Cannot find component {n}, corrupt image"
341                    )));
342                }
343
344                let component = &mut self.components[n];
345                let _ = self
346                    .dc_huffman_tables
347                    .get(component.dc_huff_table)
348                    .ok_or_else(|| {
349                        DecodeErrors::Format(format!(
350                            "No huffman table for component:{}",
351                            component.dc_huff_table
352                        ))
353                    })?
354                    .as_ref()
355                    .ok_or_else(|| {
356                        DecodeErrors::Format(format!(
357                            "Huffman table at index  {} not initialized",
358                            component.dc_huff_table
359                        ))
360                    })?;
361            }
362            // Interleaved scan
363
364            // Components shall not be interleaved in progressive mode, except for
365            // the DC coefficients in the first scan for each component of a progressive frame.
366            for i in 0..self.mcu_y {
367                for j in 0..self.mcu_x {
368                    // process scan n elements in order
369                    for k in 0..self.num_scans {
370                        let n = self.z_order[k as usize];
371                        let component = &mut self.components[n];
372                        let huff_table = self
373                            .dc_huffman_tables
374                            .get(component.dc_huff_table)
375                            .ok_or(DecodeErrors::FormatStatic("No huffman table for component"))?
376                            .as_ref()
377                            .ok_or(DecodeErrors::FormatStatic(
378                                "Huffman table at index not initialized"
379                            ))?;
380
381                        for v_samp in 0..component.vertical_sample {
382                            for h_samp in 0..component.horizontal_sample {
383                                let x2 = j * component.horizontal_sample + h_samp;
384                                let y2 = i * component.vertical_sample + v_samp;
385                                let position = 64 * (x2 + y2 * component.width_stride / 8);
386                                let buf_n = &mut buffer[n];
387
388                                let Some(data) = &mut buf_n.get_mut(position) else {
389                                    // TODO: (CAE), this is another weird sub-sampling bug, so on fix
390                                    // remove this
391                                    return Err(DecodeErrors::FormatStatic("Invalid image"));
392                                };
393
394                                if self.succ_high == 0 {
395                                    stream.decode_prog_dc_first(
396                                        &mut self.stream,
397                                        huff_table,
398                                        data,
399                                        &mut component.dc_pred
400                                    )?;
401                                } else {
402                                    stream.decode_prog_dc_refine(&mut self.stream, data)?;
403                                }
404                            }
405                        }
406                    }
407                    // We want wrapping subtraction here because it means
408                    // we get a higher number in the case this underflows
409                    self.todo -= 1;
410                    // after every scan that's a mcu, count down restart markers.
411                    self.handle_rst_main(stream)?;
412                }
413            }
414        }
415        return Ok(());
416    }
417
418    pub(crate) fn handle_rst_main(&mut self, stream: &mut BitStream) -> Result<(), DecodeErrors> {
419        if self.todo == 0 {
420            stream.refill(&mut self.stream)?;
421        }
422
423        if self.todo == 0
424            && self.restart_interval != 0
425            && stream.marker.is_none()
426            && !stream.seen_eoi
427        {
428            // if no marker and we are to reset RST, look for the marker, this matches
429            // libjpeg-turbo behaviour and allows us to decode images in
430            // https://github.com/etemesi254/zune-image/issues/261
431            let _start = self.stream.get_position();
432            // skip bytes until we find marker
433            let marker = get_marker(&mut self.stream, stream)?;
434            let _end = self.stream.get_position();
435            stream.marker = Some(marker);
436            // NB some warnings may be false positives.
437            warn!(
438                "{} Extraneous bytes before marker {:?}",
439                _end - _start,
440                marker
441            );
442        }
443        if self.todo == 0 {
444            self.handle_rst(stream)?
445        }
446        Ok(())
447    }
448    #[allow(clippy::too_many_lines)]
449    #[allow(clippy::needless_range_loop, clippy::cast_sign_loss)]
450    fn finish_progressive_decoding(
451        &mut self, block: &[Vec<i16>; MAX_COMPONENTS], _mcu_width: usize, pixels: &mut [u8]
452    ) -> Result<(), DecodeErrors> {
453        // This function is complicated because we need to replicate
454        // the function in mcu.rs
455        //
456        // The advantage is that we do very little allocation and very lot
457        // channel reusing.
458        // The trick is to notice that we repeat the same procedure per MCU
459        // width.
460        //
461        // So we can set it up that we only allocate temporary storage large enough
462        // to store a single mcu width, then reuse it per invocation.
463        //
464        // This is advantageous to us.
465        //
466        // Remember we need to have the whole MCU buffer so we store 3 unprocessed
467        // channels in memory, and then we allocate the whole output buffer in memory, both of
468        // which are huge.
469        //
470        //
471
472        let mcu_height = if self.is_interleaved {
473            self.mcu_y
474        } else {
475            // For non-interleaved images( (1*1) subsampling)
476            // number of MCU's are the widths (+7 to account for paddings) divided by 8.
477            ((self.info.height + 7) / 8) as usize
478        };
479
480        // Size of our output image(width*height)
481        let is_hv = usize::from(self.is_interleaved);
482        let upsampler_scratch_size = is_hv * self.components[0].width_stride;
483        let width = usize::from(self.info.width);
484        let padded_width = calculate_padded_width(width, self.sub_sample_ratio);
485
486        //let mut pixels = vec![0; capacity * out_colorspace_components];
487        let mut upsampler_scratch_space = vec![0; upsampler_scratch_size];
488        let mut tmp = [0_i32; DCT_BLOCK];
489
490        for (pos, comp) in self.components.iter_mut().enumerate() {
491            // Allocate only needed components.
492            //
493            // For special colorspaces i.e YCCK and CMYK, just allocate all of the needed
494            // components.
495            if min(
496                self.options.jpeg_get_out_colorspace().num_components() - 1,
497                pos,
498            ) == pos
499                || self.input_colorspace == ColorSpace::YCCK
500                || self.input_colorspace == ColorSpace::CMYK
501            {
502                // allocate enough space to hold a whole MCU width
503                // this means we should take into account sampling ratios
504                // `*8` is because each MCU spans 8 widths.
505                let len = comp.width_stride * comp.vertical_sample * 8;
506
507                comp.needed = true;
508                comp.raw_coeff = vec![0; len];
509            } else {
510                comp.needed = false;
511            }
512        }
513
514        let mut pixels_written = 0;
515
516        // dequantize, idct and color convert.
517        for i in 0..mcu_height {
518            'component: for (position, component) in &mut self.components.iter_mut().enumerate() {
519                if !component.needed {
520                    continue 'component;
521                }
522                let qt_table = &component.quantization_table;
523
524                // step is the number of pixels this iteration wil be handling
525                // Given by the number of mcu's height and the length of the component block
526                // Since the component block contains the whole channel as raw pixels
527                // we this evenly divides the pixels into MCU blocks
528                //
529                // For interleaved images, this gives us the exact pixels comprising a whole MCU
530                // block
531                let step = block[position].len() / mcu_height;
532                // where we will be reading our pixels from.
533                let start = i * step;
534
535                let slice = &block[position][start..start + step];
536
537                let temp_channel = &mut component.raw_coeff;
538
539                // The next logical step is to iterate width wise.
540                // To figure out how many pixels we iterate by we use effective pixels
541                // Given to us by component.x
542                // iterate per effective pixels.
543                let mcu_x = component.width_stride / 8;
544
545                // iterate per every vertical sample.
546                for k in 0..component.vertical_sample {
547                    for j in 0..mcu_x {
548                        // after writing a single stride, we need to skip 8 rows.
549                        // This does the row calculation
550                        let width_stride = k * 8 * component.width_stride;
551                        let start = j * 64 + width_stride;
552
553                        // See https://github.com/etemesi254/zune-image/issues/262 sample 3.
554                        let Some(qt_slice) = slice.get(start..start + 64) else {
555                            return Err(DecodeErrors::FormatStatic("Invalid slice , would panic, invalid image"))
556                        };
557                        // dequantize
558                        for ((x, out), qt_val) in qt_slice
559                            .iter()
560                            .zip(tmp.iter_mut())
561                            .zip(qt_table.iter())
562                        {
563                            *out = i32::from(*x) * qt_val;
564                        }
565                        // determine where to write.
566                        let sl = &mut temp_channel[component.idct_pos..];
567
568                        component.idct_pos += 8;
569                        // tmp now contains a dequantized block so idct it
570                        (self.idct_func)(&mut tmp, sl, component.width_stride);
571                    }
572                    // after every write of 8, skip 7 since idct write stride wise 8 times.
573                    //
574                    // Remember each MCU is 8x8 block, so each idct will write 8 strides into
575                    // sl
576                    //
577                    // and component.idct_pos is one stride long
578                    component.idct_pos += 7 * component.width_stride;
579                }
580                component.idct_pos = 0;
581            }
582
583            // process that width up until it's impossible
584            self.post_process(
585                pixels,
586                i,
587                mcu_height,
588                width,
589                padded_width,
590                &mut pixels_written,
591                &mut upsampler_scratch_space
592            )?;
593        }
594
595        debug!("Finished decoding image");
596
597        return Ok(());
598    }
599    pub(crate) fn reset_params(&mut self) {
600        /*
601        Apparently, grayscale images which can be down sampled exists, which is weird in the sense
602        that it has one component Y, which is not usually down sampled.
603
604        This means some calculations will be wrong, so for that we explicitly reset params
605        for such occurrences, warn and reset the image info to appear as if it were
606        a non-sampled image to ensure decoding works
607        */
608        self.h_max = 1;
609        self.options = self.options.jpeg_set_out_colorspace(ColorSpace::Luma);
610        self.v_max = 1;
611        self.sub_sample_ratio = SampleRatios::None;
612        self.is_interleaved = false;
613        self.components[0].vertical_sample = 1;
614        self.components[0].width_stride = (((self.info.width as usize) + 7) / 8) * 8;
615        self.components[0].horizontal_sample = 1;
616    }
617}
618
619///Get a marker from the bit-stream.
620///
621/// This reads until it gets a marker or end of file is encountered
622pub fn get_marker<T>(
623    reader: &mut ZByteReader<T>, stream: &mut BitStream
624) -> Result<Marker, DecodeErrors>
625where
626    T: ZReaderTrait
627{
628    if let Some(marker) = stream.marker {
629        stream.marker = None;
630        return Ok(marker);
631    }
632
633    // read until we get a marker
634
635    while !reader.eof() {
636        let marker = reader.get_u8_err()?;
637
638        if marker == 255 {
639            let mut r = reader.get_u8_err()?;
640            // 0xFF 0XFF(some images may be like that)
641            while r == 0xFF {
642                r = reader.get_u8_err()?;
643            }
644
645            if r != 0 {
646                return Marker::from_u8(r)
647                    .ok_or_else(|| DecodeErrors::Format(format!("Unknown marker 0xFF{r:X}")));
648            }
649        }
650    }
651    return Err(DecodeErrors::ExhaustedData);
652}