async_compression/generic/write/
decoder.rs

1use crate::{
2    codecs::DecodeV2,
3    core::util::{PartialBuffer, WriteBuffer},
4    generic::write::AsyncBufWrite,
5};
6use futures_core::ready;
7use std::{
8    io,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
/// Progress of the write-side decoding state machine driven by [`Decoder`].
#[derive(Debug, Default)]
enum State {
    /// Written input is still being passed to the codec's `decode`.
    #[default]
    Decoding,
    /// The codec's `finish` is being polled until it reports completion.
    Finishing,
    /// Decoding has completed; any further write is rejected with an error.
    Done,
}

/// Codec-agnostic state for a writer that decodes the bytes written to it.
///
/// A freshly constructed (`Default`) decoder starts in the `Decoding` state.
#[derive(Debug, Default)]
pub struct Decoder {
    state: State,
}
32
33impl Decoder {
34    fn do_poll_write(
35        &mut self,
36        cx: &mut Context<'_>,
37        input: &mut PartialBuffer<&[u8]>,
38        mut writer: Pin<&mut dyn AsyncBufWrite>,
39        decoder: &mut dyn DecodeV2,
40    ) -> Poll<io::Result<()>> {
41        loop {
42            let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
43            let output = &mut output.write_buffer;
44
45            self.state = match self.state {
46                State::Decoding => {
47                    if decoder.decode(input, output)? {
48                        State::Finishing
49                    } else {
50                        State::Decoding
51                    }
52                }
53
54                State::Finishing => {
55                    if decoder.finish(output)? {
56                        State::Done
57                    } else {
58                        State::Finishing
59                    }
60                }
61
62                State::Done => {
63                    return Poll::Ready(Err(io::Error::other("Write after end of stream")))
64                }
65            };
66
67            if let State::Done = self.state {
68                return Poll::Ready(Ok(()));
69            }
70
71            if input.unwritten().is_empty() {
72                return Poll::Ready(Ok(()));
73            }
74        }
75    }
76
77    pub fn poll_write(
78        &mut self,
79        cx: &mut Context<'_>,
80        buf: &[u8],
81        writer: Pin<&mut dyn AsyncBufWrite>,
82        decoder: &mut dyn DecodeV2,
83    ) -> Poll<io::Result<usize>> {
84        if buf.is_empty() {
85            return Poll::Ready(Ok(0));
86        }
87
88        let mut input = PartialBuffer::new(buf);
89
90        match self.do_poll_write(cx, &mut input, writer, decoder)? {
91            Poll::Pending if input.written().is_empty() => Poll::Pending,
92            _ => Poll::Ready(Ok(input.written().len())),
93        }
94    }
95
96    pub fn do_poll_flush(
97        &mut self,
98        cx: &mut Context<'_>,
99        mut writer: Pin<&mut dyn AsyncBufWrite>,
100        decoder: &mut dyn DecodeV2,
101    ) -> Poll<io::Result<()>> {
102        loop {
103            let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
104            let output = &mut output.write_buffer;
105
106            let (state, done) = match self.state {
107                State::Decoding => {
108                    let done = decoder.flush(output)?;
109                    (State::Decoding, done)
110                }
111
112                State::Finishing => {
113                    if decoder.finish(output)? {
114                        (State::Done, false)
115                    } else {
116                        (State::Finishing, false)
117                    }
118                }
119
120                State::Done => (State::Done, true),
121            };
122
123            self.state = state;
124
125            if done {
126                break Poll::Ready(Ok(()));
127            }
128        }
129    }
130
131    pub fn do_close(&mut self) {
132        if let State::Decoding = self.state {
133            self.state = State::Finishing;
134        }
135    }
136
137    pub fn is_done(&self) -> bool {
138        matches!(self.state, State::Done)
139    }
140}
141
/// Generates a writer-side `Decoder<W, D>` adapter around the codec-agnostic
/// `generic::write::Decoder` state machine.
///
/// `$poll_close` names the close method of the `AsyncWrite` trait in scope at
/// the invocation site (presumably `poll_close` or `poll_shutdown` depending on
/// the runtime flavour — confirm at the call sites). The invoking module must
/// itself have `AsyncWrite`, `AsyncBufRead`, `BufWriter`, `Pin`, `Context`,
/// `Poll` and `io` in scope, since the generated code uses them unqualified.
macro_rules! impl_decoder {
    ($poll_close: tt) => {
        use crate::{
            codecs::DecodeV2, core::util::PartialBuffer, generic::write::Decoder as GenericDecoder,
        };
        use futures_core::ready;
        use pin_project_lite::pin_project;

        pin_project! {
            #[derive(Debug)]
            pub struct Decoder<W, D> {
                #[pin]
                writer: BufWriter<W>,
                decoder: D,
                inner: GenericDecoder,
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> Decoder<W, D> {
            /// Wraps `writer` in a `BufWriter`; bytes written to the returned
            /// adapter are run through `decoder` before reaching `writer`.
            pub fn new(writer: W, decoder: D) -> Self {
                let writer = BufWriter::new(writer);
                let inner = GenericDecoder::default();
                Self {
                    writer,
                    decoder,
                    inner,
                }
            }
        }

        impl<W, D> Decoder<W, D> {
            /// Returns a shared reference to the underlying writer.
            pub fn get_ref(&self) -> &W {
                self.writer.get_ref()
            }

            /// Returns a mutable reference to the underlying writer.
            ///
            /// Writing to it directly may interleave with output still held in
            /// this adapter's buffer.
            pub fn get_mut(&mut self) -> &mut W {
                self.writer.get_mut()
            }

            /// Returns a pinned mutable reference to the underlying writer.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
                self.project().writer.get_pin_mut()
            }

            /// Consumes the adapter and returns the underlying writer.
            ///
            /// NOTE(review): nothing is flushed or finished here; whether
            /// buffered bytes survive depends on `BufWriter::into_inner` — confirm.
            pub fn into_inner(self) -> W {
                self.writer.into_inner()
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> Decoder<W, D> {
            /// Runs the generic flush/finish logic against the buffered writer.
            fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let projected = self.project();
                projected
                    .inner
                    .do_poll_flush(cx, projected.writer, projected.decoder)
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> AsyncWrite for Decoder<W, D> {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let projected = self.project();
                projected
                    .inner
                    .poll_write(cx, buf, projected.writer, projected.decoder)
            }

            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                // First move decoded output into the buffer, then flush the buffer itself.
                match self.as_mut().do_poll_flush(cx) {
                    Poll::Ready(Ok(())) => self.project().writer.poll_flush(cx),
                    not_ready_or_failed => not_ready_or_failed,
                }
            }

            fn $poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                // Signal end of input, then drive the decoder to completion.
                self.as_mut().project().inner.do_close();
                ready!(self.as_mut().do_poll_flush(cx))?;

                let projected = self.project();
                if !projected.inner.is_done() {
                    return Poll::Ready(Err(io::Error::other(
                        "Attempt to close before finishing input",
                    )));
                }
                projected.writer.$poll_close(cx)
            }
        }

        impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
            fn poll_fill_buf(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<&[u8]>> {
                let writer = self.get_pin_mut();
                writer.poll_fill_buf(cx)
            }

            fn consume(self: Pin<&mut Self>, amt: usize) {
                let writer = self.get_pin_mut();
                writer.consume(amt)
            }
        }
    };
}
pub(crate) use impl_decoder;