async_compression/generic/bufread/
decoder.rs

1use crate::{
2    codecs::DecodeV2,
3    core::util::{PartialBuffer, WriteBuffer},
4};
5use std::{io::Result, ops::ControlFlow, panic::AssertUnwindSafe};
6
/// Phases of the decoding state machine driven by `Decoder::do_poll_read`.
#[derive(Debug)]
enum State {
    /// Input is being handed to the codec's `decode` method.
    Decoding,
    /// The codec's `finish` method is being called until it reports completion.
    Flushing,
    /// Terminal phase: every further poll returns `Ok(())` immediately.
    Done,
    /// A member has finished and the codec was reinitialised; the machine is
    /// waiting to learn whether more input follows.
    Next,
    /// An error is held back here so that output produced before the failure
    /// can be returned first; the error itself is reported on a later pass.
    Error(AssertUnwindSafe<std::io::Error>),
}
15
16#[derive(Debug)]
17pub struct Decoder {
18    state: State,
19    multiple_members: bool,
20}
21
22impl Default for Decoder {
23    fn default() -> Self {
24        Self {
25            state: State::Decoding,
26            multiple_members: false,
27        }
28    }
29}
30
31impl Decoder {
32    pub fn multiple_members(&mut self, enabled: bool) {
33        self.multiple_members = enabled;
34    }
35
36    pub fn do_poll_read(
37        &mut self,
38        output: &mut WriteBuffer<'_>,
39        decoder: &mut dyn DecodeV2,
40        input: &mut PartialBuffer<&[u8]>,
41        mut first: bool,
42    ) -> ControlFlow<Result<()>> {
43        loop {
44            self.state = match self.state {
45                State::Decoding => {
46                    if input.unwritten().is_empty() && !first {
47                        // Avoid attempting to reinitialise the decoder if the
48                        // reader has returned EOF.
49                        self.multiple_members = false;
50
51                        State::Flushing
52                    } else {
53                        match decoder.decode(input, output) {
54                            Ok(true) => State::Flushing,
55                            // ignore the first error, occurs when input is empty
56                            // but we need to run decode to flush
57                            Err(err) if !first => {
58                                self.state = State::Error(AssertUnwindSafe(err));
59                                if output.written_len() > 0 {
60                                    return ControlFlow::Break(Ok(()));
61                                } else {
62                                    continue;
63                                }
64                            }
65                            // poll for more data for the next decode
66                            _ => break,
67                        }
68                    }
69                }
70
71                State::Flushing => {
72                    match decoder.finish(output) {
73                        Ok(true) => {
74                            if self.multiple_members {
75                                if let Err(err) = decoder.reinit() {
76                                    self.state = State::Error(AssertUnwindSafe(err));
77                                    if output.written_len() > 0 {
78                                        return ControlFlow::Break(Ok(()));
79                                    } else {
80                                        continue;
81                                    }
82                                }
83
84                                // The decode stage might consume all the input,
85                                // the next stage might need to poll again if it's empty.
86                                first = true;
87                                State::Next
88                            } else {
89                                State::Done
90                            }
91                        }
92                        Ok(false) => State::Flushing,
93                        Err(err) => {
94                            self.state = State::Error(AssertUnwindSafe(err));
95                            if output.written_len() > 0 {
96                                return ControlFlow::Break(Ok(()));
97                            } else {
98                                continue;
99                            }
100                        }
101                    }
102                }
103
104                State::Done => return ControlFlow::Break(Ok(())),
105
106                State::Next => {
107                    if input.unwritten().is_empty() {
108                        if first {
109                            // poll for more data to check if there's another stream
110                            break;
111                        }
112                        State::Done
113                    } else {
114                        State::Decoding
115                    }
116                }
117
118                State::Error(_) => {
119                    let State::Error(err) = std::mem::replace(&mut self.state, State::Done) else {
120                        unreachable!()
121                    };
122                    return ControlFlow::Break(Err(err.0));
123                }
124            };
125
126            if output.has_no_spare_space() {
127                return ControlFlow::Break(Ok(()));
128            }
129        }
130
131        if output.has_no_spare_space() {
132            ControlFlow::Break(Ok(()))
133        } else {
134            ControlFlow::Continue(())
135        }
136    }
137}
138
/// Expands to a `Decoder<R, D>` type that couples a buffered async reader
/// with a codec, together with the polling glue around the generic state
/// machine.
///
/// The expansion refers to `AsyncBufRead`, `DecodeV2`, `PartialBuffer`,
/// `WriteBuffer`, `Pin`, `Context`, `Poll` and `Result` without importing
/// them, so those names must be in scope where the macro is invoked.
macro_rules! impl_decoder {
    () => {
        use crate::generic::bufread::Decoder as GenericDecoder;

        use std::{ops::ControlFlow, task::ready};

        use pin_project_lite::pin_project;

        pin_project! {
            /// Decoder that pulls encoded bytes from `reader` and decodes
            /// them with `decoder`.
            #[derive(Debug)]
            pub struct Decoder<R, D> {
                #[pin]
                reader: R,
                decoder: D,
                inner: GenericDecoder,
            }
        }

        impl<R, D> Decoder<R, D>
        where
            R: AsyncBufRead,
            D: DecodeV2,
        {
            /// Wraps `reader`, decoding its contents with `decoder`.
            pub fn new(reader: R, decoder: D) -> Self {
                let inner = GenericDecoder::default();

                Self {
                    reader,
                    decoder,
                    inner,
                }
            }
        }

        impl<R, D> Decoder<R, D> {
            /// Returns a shared reference to the wrapped reader.
            pub fn get_ref(&self) -> &R {
                &self.reader
            }

            /// Returns a mutable reference to the wrapped reader.
            pub fn get_mut(&mut self) -> &mut R {
                &mut self.reader
            }

            /// Returns a pinned mutable reference to the wrapped reader.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
                let projected = self.project();
                projected.reader
            }

            /// Consumes the decoder and hands back the wrapped reader.
            pub fn into_inner(self) -> R {
                self.reader
            }

            /// Enables or disables decoding of several concatenated members.
            pub fn multiple_members(&mut self, enabled: bool) {
                self.inner.multiple_members(enabled);
            }
        }

        /// Drives `inner` with data obtained from `reader` until the state
        /// machine asks to stop or the reader is not ready.
        fn do_poll_read(
            inner: &mut GenericDecoder,
            decoder: &mut dyn DecodeV2,
            mut reader: Pin<&mut dyn AsyncBufRead>,
            cx: &mut Context<'_>,
            output: &mut WriteBuffer<'_>,
        ) -> Poll<Result<()>> {
            // First pass with an empty buffer, before the reader is touched.
            if let ControlFlow::Break(res) =
                inner.do_poll_read(output, decoder, &mut PartialBuffer::new(&[][..]), true)
            {
                return Poll::Ready(res);
            }

            loop {
                let chunk = match reader.as_mut().poll_fill_buf(cx)? {
                    Poll::Ready(chunk) => chunk,
                    Poll::Pending => {
                        // Only report `Pending` when nothing has been
                        // produced; otherwise hand over what is there.
                        return if output.written().is_empty() {
                            Poll::Pending
                        } else {
                            Poll::Ready(Ok(()))
                        };
                    }
                };
                let mut input = PartialBuffer::new(chunk);

                let control_flow = inner.do_poll_read(output, decoder, &mut input, false);

                // Tell the reader how many bytes the state machine used.
                let consumed = input.written().len();
                reader.as_mut().consume(consumed);

                if let ControlFlow::Break(res) = control_flow {
                    return Poll::Ready(res);
                }
            }
        }

        impl<R, D> Decoder<R, D>
        where
            R: AsyncBufRead,
            D: DecodeV2,
        {
            /// Projects the pinned fields and forwards to the free
            /// `do_poll_read` function.
            fn do_poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                output: &mut WriteBuffer<'_>,
            ) -> Poll<Result<()>> {
                let me = self.project();

                do_poll_read(me.inner, me.decoder, me.reader, cx, output)
            }
        }
    };
}
233pub(crate) use impl_decoder;