// async_compression/generic/bufread/decoder.rs

1use crate::{
2    codecs::DecodeV2,
3    core::util::{PartialBuffer, WriteBuffer},
4};
5
6use std::{io::Result, ops::ControlFlow};
7
/// Phases of the decoding state machine driven by `Decoder::do_poll_read`.
#[derive(Debug)]
enum State {
    /// Input is being fed to the codec via `decode`.
    Decoding,
    /// The codec reported the end of a member, or the reader hit EOF;
    /// `finish` is called until it reports completion.
    Flushing,
    /// A member finished while multiple-member decoding is enabled; the next
    /// step depends on whether more input follows.
    Next,
    /// Decoding is complete; every further call reports success immediately.
    Done,
}
15
16#[derive(Debug)]
17pub struct Decoder {
18    state: State,
19    multiple_members: bool,
20}
21
22impl Default for Decoder {
23    fn default() -> Self {
24        Self {
25            state: State::Decoding,
26            multiple_members: false,
27        }
28    }
29}
30
31impl Decoder {
32    pub fn multiple_members(&mut self, enabled: bool) {
33        self.multiple_members = enabled;
34    }
35
36    pub fn do_poll_read(
37        &mut self,
38        output: &mut WriteBuffer<'_>,
39        decoder: &mut dyn DecodeV2,
40        input: &mut PartialBuffer<&[u8]>,
41        mut first: bool,
42    ) -> ControlFlow<Result<()>> {
43        loop {
44            self.state = match self.state {
45                State::Decoding => {
46                    if input.unwritten().is_empty() && !first {
47                        // Avoid attempting to reinitialise the decoder if the
48                        // reader has returned EOF.
49                        self.multiple_members = false;
50
51                        State::Flushing
52                    } else {
53                        match decoder.decode(input, output) {
54                            Ok(true) => State::Flushing,
55                            // ignore the first error, occurs when input is empty
56                            // but we need to run decode to flush
57                            Err(err) if !first => return ControlFlow::Break(Err(err)),
58                            // poll for more data for the next decode
59                            _ => break,
60                        }
61                    }
62                }
63
64                State::Flushing => {
65                    match decoder.finish(output) {
66                        Ok(true) => {
67                            if self.multiple_members {
68                                if let Err(err) = decoder.reinit() {
69                                    return ControlFlow::Break(Err(err));
70                                }
71
72                                // The decode stage might consume all the input,
73                                // the next stage might need to poll again if it's empty.
74                                first = true;
75                                State::Next
76                            } else {
77                                State::Done
78                            }
79                        }
80                        Ok(false) => State::Flushing,
81                        Err(err) => return ControlFlow::Break(Err(err)),
82                    }
83                }
84
85                State::Done => return ControlFlow::Break(Ok(())),
86
87                State::Next => {
88                    if input.unwritten().is_empty() {
89                        if first {
90                            // poll for more data to check if there's another stream
91                            break;
92                        }
93                        State::Done
94                    } else {
95                        State::Decoding
96                    }
97                }
98            };
99
100            if output.has_no_spare_space() {
101                return ControlFlow::Break(Ok(()));
102            }
103        }
104
105        if output.has_no_spare_space() {
106            ControlFlow::Break(Ok(()))
107        } else {
108            ControlFlow::Continue(())
109        }
110    }
111}
112
/// Generates an `AsyncBufRead`-backed `Decoder<R, D>` adapter that delegates
/// its decoding logic to the reader-independent `generic::bufread::Decoder`.
///
/// The invoking module must have `AsyncBufRead`, `Pin`, `Context`, `Poll`,
/// `Result`, `DecodeV2`, `PartialBuffer` and `WriteBuffer` in scope.
macro_rules! impl_decoder {
    () => {
        use crate::generic::bufread::Decoder as GenericDecoder;

        use std::ops::ControlFlow;

        use futures_core::ready;
        use pin_project_lite::pin_project;

        pin_project! {
            #[derive(Debug)]
            pub struct Decoder<R, D> {
                #[pin]
                reader: R,
                decoder: D,
                inner: GenericDecoder,
            }
        }

        impl<R, D> Decoder<R, D>
        where
            R: AsyncBufRead,
            D: DecodeV2,
        {
            /// Wraps `reader`, decoding the bytes it yields with `decoder`.
            pub fn new(reader: R, decoder: D) -> Self {
                let inner = GenericDecoder::default();
                Self {
                    reader,
                    decoder,
                    inner,
                }
            }

            // Projects the pinned fields and hands them to the shared driver.
            fn do_poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                output: &mut WriteBuffer<'_>,
            ) -> Poll<Result<()>> {
                let projected = self.project();
                do_poll_read(
                    projected.inner,
                    projected.decoder,
                    projected.reader,
                    cx,
                    output,
                )
            }
        }

        impl<R, D> Decoder<R, D> {
            /// Returns a shared reference to the underlying reader.
            pub fn get_ref(&self) -> &R {
                &self.reader
            }

            /// Returns a mutable reference to the underlying reader.
            pub fn get_mut(&mut self) -> &mut R {
                &mut self.reader
            }

            /// Returns a pinned mutable reference to the underlying reader.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
                self.project().reader
            }

            /// Consumes this adapter and returns the underlying reader.
            pub fn into_inner(self) -> R {
                self.reader
            }

            /// Enables or disables decoding of multiple concatenated members.
            pub fn multiple_members(&mut self, enabled: bool) {
                self.inner.multiple_members(enabled);
            }
        }

        // Type-erased driver shared by every `Decoder<R, D>` instantiation:
        // makes one priming pass with an empty buffer, then alternates between
        // filling the reader's buffer and feeding it to the state machine.
        fn do_poll_read(
            inner: &mut GenericDecoder,
            decoder: &mut dyn DecodeV2,
            mut reader: Pin<&mut dyn AsyncBufRead>,
            cx: &mut Context<'_>,
            output: &mut WriteBuffer<'_>,
        ) -> Poll<Result<()>> {
            // Priming pass: lets the codec emit output it still holds before
            // the reader is polled.
            match inner.do_poll_read(output, decoder, &mut PartialBuffer::new(&[][..]), true) {
                ControlFlow::Break(res) => return Poll::Ready(res),
                ControlFlow::Continue(()) => {}
            }

            loop {
                let available = ready!(reader.as_mut().poll_fill_buf(cx))?;
                let mut input = PartialBuffer::new(available);

                let flow = inner.do_poll_read(output, decoder, &mut input, false);

                // Release the borrow of the reader's buffer before consuming,
                // and only consume what the state machine actually read.
                let consumed = input.written().len();
                reader.as_mut().consume(consumed);

                if let ControlFlow::Break(res) = flow {
                    return Poll::Ready(res);
                }
            }
        }
    };
}
204pub(crate) use impl_decoder;