async_compression/generic/write/decoder.rs

1use crate::{
2    codecs::DecodeV2,
3    core::util::{PartialBuffer, WriteBuffer},
4    generic::write::AsyncBufWrite,
5};
6use futures_core::ready;
7use std::{
8    io,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
/// Progress of the write-side decoding state machine.
///
/// The state only ever moves forward: `Decoding` -> `Finishing` -> `Done`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Input is still being fed to the codec (via `decode` / `flush`).
    #[default]
    Decoding,
    /// The codec is no longer given input; remaining output is drained with `finish`.
    Finishing,
    /// `finish` has completed; any further write is rejected with an error.
    Done,
}

/// Runtime-agnostic state for a write-side decoder.
///
/// It holds only the current decoding state. The buffered writer and the codec
/// are passed in on every call rather than owned by this type. A freshly
/// constructed (`Default`) decoder starts in the `Decoding` state.
#[derive(Debug, Default)]
pub struct Decoder {
    state: State,
}
32
33impl Decoder {
34    fn do_poll_write(
35        &mut self,
36        cx: &mut Context<'_>,
37        input: &mut PartialBuffer<&[u8]>,
38        mut writer: Pin<&mut dyn AsyncBufWrite>,
39        decoder: &mut dyn DecodeV2,
40    ) -> Poll<io::Result<()>> {
41        loop {
42            let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
43            let mut output = WriteBuffer::new_initialized(output);
44
45            self.state = match self.state {
46                State::Decoding => {
47                    if decoder.decode(input, &mut output)? {
48                        State::Finishing
49                    } else {
50                        State::Decoding
51                    }
52                }
53
54                State::Finishing => {
55                    if decoder.finish(&mut output)? {
56                        State::Done
57                    } else {
58                        State::Finishing
59                    }
60                }
61
62                State::Done => {
63                    return Poll::Ready(Err(io::Error::other("Write after end of stream")))
64                }
65            };
66
67            let produced = output.written_len();
68            writer.as_mut().produce(produced);
69
70            if let State::Done = self.state {
71                return Poll::Ready(Ok(()));
72            }
73
74            if input.unwritten().is_empty() {
75                return Poll::Ready(Ok(()));
76            }
77        }
78    }
79
80    pub fn poll_write(
81        &mut self,
82        cx: &mut Context<'_>,
83        buf: &[u8],
84        writer: Pin<&mut dyn AsyncBufWrite>,
85        decoder: &mut dyn DecodeV2,
86    ) -> Poll<io::Result<usize>> {
87        if buf.is_empty() {
88            return Poll::Ready(Ok(0));
89        }
90
91        let mut input = PartialBuffer::new(buf);
92
93        match self.do_poll_write(cx, &mut input, writer, decoder)? {
94            Poll::Pending if input.written().is_empty() => Poll::Pending,
95            _ => Poll::Ready(Ok(input.written().len())),
96        }
97    }
98
99    pub fn do_poll_flush(
100        &mut self,
101        cx: &mut Context<'_>,
102        mut writer: Pin<&mut dyn AsyncBufWrite>,
103        decoder: &mut dyn DecodeV2,
104    ) -> Poll<io::Result<()>> {
105        loop {
106            let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
107            let mut output = WriteBuffer::new_initialized(output);
108
109            let (state, done) = match self.state {
110                State::Decoding => {
111                    let done = decoder.flush(&mut output)?;
112                    (State::Decoding, done)
113                }
114
115                State::Finishing => {
116                    if decoder.finish(&mut output)? {
117                        (State::Done, false)
118                    } else {
119                        (State::Finishing, false)
120                    }
121                }
122
123                State::Done => (State::Done, true),
124            };
125
126            self.state = state;
127
128            let produced = output.written_len();
129            writer.as_mut().produce(produced);
130
131            if done {
132                break Poll::Ready(Ok(()));
133            }
134        }
135    }
136
137    pub fn do_close(&mut self) {
138        if let State::Decoding = self.state {
139            self.state = State::Finishing;
140        }
141    }
142
143    pub fn is_done(&self) -> bool {
144        matches!(self.state, State::Done)
145    }
146}
147
/// Expands to a runtime-specific `Decoder<W, D>` writer adapter built on top
/// of the runtime-agnostic `generic::write::Decoder` state machine.
///
/// `$poll_close` is the identifier of the close/shutdown method of the target
/// runtime's `AsyncWrite` trait. It is used both as the name of the generated
/// method and to close the inner `BufWriter`.
///
/// The invoking module must already have `io`, `Pin`, `Context`, `Poll`,
/// `AsyncWrite`, `AsyncBufRead` and `BufWriter` in scope, because the expansion
/// refers to them unqualified.
macro_rules! impl_decoder {
    ($poll_close: tt) => {
        use crate::{
            codecs::DecodeV2, core::util::PartialBuffer, generic::write::Decoder as GenericDecoder,
        };
        use futures_core::ready;
        use pin_project_lite::pin_project;

        pin_project! {
            /// Writer adapter that decodes the bytes written to it with the
            /// codec `D` and writes the decoded output to the wrapped writer `W`.
            #[derive(Debug)]
            pub struct Decoder<W, D> {
                // Decoded output is staged in this buffer before reaching `W`.
                #[pin]
                writer: BufWriter<W>,
                decoder: D,
                // Shared state machine that drives `decoder` into `writer`.
                inner: GenericDecoder,
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> Decoder<W, D> {
            /// Wraps `writer` so that bytes written to the result are decoded
            /// with `decoder` before being forwarded.
            pub fn new(writer: W, decoder: D) -> Self {
                Self {
                    writer: BufWriter::new(writer),
                    decoder,
                    inner: Default::default(),
                }
            }
        }

        impl<W, D> Decoder<W, D> {
            /// Returns a shared reference to the underlying writer.
            pub fn get_ref(&self) -> &W {
                self.writer.get_ref()
            }

            /// Returns a mutable reference to the underlying writer.
            pub fn get_mut(&mut self) -> &mut W {
                self.writer.get_mut()
            }

            /// Returns a pinned mutable reference to the underlying writer.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
                self.project().writer.get_pin_mut()
            }

            /// Consumes the adapter and returns the underlying writer.
            ///
            /// NOTE(review): decoded bytes still sitting in the internal
            /// `BufWriter` are presumably discarded here (TODO confirm);
            /// flush or close first to avoid losing output.
            pub fn into_inner(self) -> W {
                self.writer.into_inner()
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> Decoder<W, D> {
            /// Pushes output held inside the codec into the internal buffer,
            /// without flushing the wrapped writer.
            fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                // All mutation goes through the projected `&mut`/`Pin<&mut>`
                // fields, so the binding itself does not need to be `mut`.
                let this = self.project();

                this.inner.do_poll_flush(cx, this.writer, this.decoder)
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> AsyncWrite for Decoder<W, D> {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let this = self.project();

                this.inner.poll_write(cx, buf, this.writer, this.decoder)
            }

            /// Drains pending codec output into the buffer, then flushes the
            /// buffered writer.
            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                ready!(self.as_mut().do_poll_flush(cx))?;
                self.project().writer.poll_flush(cx)
            }

            /// Finishes decoding and then closes the wrapped writer.
            ///
            /// Steps:
            /// 1. Stop accepting input.
            /// 2. Drive the codec's `finish` until the state machine reports
            ///    done.
            /// 3. Only then close `writer`.
            fn $poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.as_mut().project().inner.do_close();

                ready!(self.as_mut().do_poll_flush(cx))?;

                let this = self.project();
                if this.inner.is_done() {
                    this.writer.$poll_close(cx)
                } else {
                    // Defensive: `do_poll_flush` should only return `Ready(Ok)`
                    // after reaching `Done` once `do_close` has been called.
                    Poll::Ready(Err(io::Error::other(
                        "Attempt to close before finishing input",
                    )))
                }
            }
        }

        // Read operations pass straight through to the wrapped writer.
        impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
            fn poll_fill_buf(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<&[u8]>> {
                self.get_pin_mut().poll_fill_buf(cx)
            }

            fn consume(self: Pin<&mut Self>, amt: usize) {
                self.get_pin_mut().consume(amt)
            }
        }
    };
}
pub(crate) use impl_decoder;