async_compression/generic/write/decoder.rs

1use crate::{
2    codecs::DecodeV2,
3    core::util::{PartialBuffer, WriteBuffer},
4    generic::write::AsyncBufWrite,
5};
6use std::{
7    io,
8    pin::Pin,
9    task::{ready, Context, Poll},
10};
11
/// Where the write-side decoder is in its lifecycle.
#[derive(Debug)]
enum State {
    /// Written bytes are handed to the codec's `decode`.
    Decoding,
    /// Input is over (the codec's `decode` said so, or the writer is being closed);
    /// the codec's `finish` is called until it reports completion.
    Finishing,
    /// `finish` has completed; any further write is rejected.
    Done,
}
18
19#[derive(Debug)]
20pub struct Decoder {
21    state: State,
22}
23
24impl Default for Decoder {
25    fn default() -> Self {
26        Self {
27            state: State::Decoding,
28        }
29    }
30}
31
32impl Decoder {
33    fn do_poll_write(
34        &mut self,
35        cx: &mut Context<'_>,
36        input: &mut PartialBuffer<&[u8]>,
37        mut writer: Pin<&mut dyn AsyncBufWrite>,
38        decoder: &mut dyn DecodeV2,
39    ) -> Poll<io::Result<()>> {
40        loop {
41            let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
42            let output = &mut output.write_buffer;
43
44            self.state = match self.state {
45                State::Decoding => {
46                    if decoder.decode(input, output)? {
47                        State::Finishing
48                    } else {
49                        State::Decoding
50                    }
51                }
52
53                State::Finishing => {
54                    if decoder.finish(output)? {
55                        State::Done
56                    } else {
57                        State::Finishing
58                    }
59                }
60
61                State::Done => {
62                    return Poll::Ready(Err(io::Error::other("Write after end of stream")));
63                }
64            };
65
66            if let State::Done = self.state {
67                return Poll::Ready(Ok(()));
68            }
69
70            if input.unwritten().is_empty() {
71                return Poll::Ready(Ok(()));
72            }
73        }
74    }
75
76    pub fn poll_write(
77        &mut self,
78        cx: &mut Context<'_>,
79        buf: &[u8],
80        writer: Pin<&mut dyn AsyncBufWrite>,
81        decoder: &mut dyn DecodeV2,
82    ) -> Poll<io::Result<usize>> {
83        if buf.is_empty() {
84            return Poll::Ready(Ok(0));
85        }
86
87        let mut input = PartialBuffer::new(buf);
88
89        match self.do_poll_write(cx, &mut input, writer, decoder)? {
90            Poll::Pending if input.written().is_empty() => Poll::Pending,
91            _ => Poll::Ready(Ok(input.written().len())),
92        }
93    }
94
95    pub fn do_poll_flush(
96        &mut self,
97        cx: &mut Context<'_>,
98        mut writer: Pin<&mut dyn AsyncBufWrite>,
99        decoder: &mut dyn DecodeV2,
100    ) -> Poll<io::Result<()>> {
101        loop {
102            let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
103            let output = &mut output.write_buffer;
104
105            let (state, done) = match self.state {
106                State::Decoding => {
107                    let done = decoder.flush(output)?;
108                    (State::Decoding, done)
109                }
110
111                State::Finishing => {
112                    if decoder.finish(output)? {
113                        (State::Done, false)
114                    } else {
115                        (State::Finishing, false)
116                    }
117                }
118
119                State::Done => (State::Done, true),
120            };
121
122            self.state = state;
123
124            if done {
125                break Poll::Ready(Ok(()));
126            }
127        }
128    }
129
130    pub fn do_close(&mut self) {
131        if let State::Decoding = self.state {
132            self.state = State::Finishing;
133        }
134    }
135
136    pub fn is_done(&self) -> bool {
137        matches!(self.state, State::Done)
138    }
139}
140
/// Generates a runtime-specific `Decoder<W, D>`: an `AsyncWrite` adapter that decodes the bytes
/// written to it with a `DecodeV2` codec and writes the decoded output to `W`.
///
/// The generated type stages decoded output in a `BufWriter<W>` and delegates its state handling
/// to the shared `generic::write::Decoder` (imported as `GenericDecoder`).
///
/// `$poll_close` is the name of the target runtime's `AsyncWrite` close method. The generated
/// impl defines a method of that name and calls the same method on the inner writer.
/// NOTE(review): presumably `poll_close` for futures-io and `poll_shutdown` for tokio; confirm
/// at the expansion sites.
///
/// The expansion site must already have `AsyncWrite`, `AsyncBufRead`, `BufWriter`, `Pin`,
/// `Context`, `Poll` and `io` in scope, because the macro uses them unqualified.
macro_rules! impl_decoder {
    ($poll_close: tt) => {
        use crate::{
            codecs::DecodeV2, core::util::PartialBuffer, generic::write::Decoder as GenericDecoder,
        };
        use pin_project_lite::pin_project;
        use std::task::ready;

        pin_project! {
            // Bytes written to this type are decoded by `decoder`. The decoded output is staged
            // in `writer`, a `BufWriter` around the user's `W`. `inner` tracks progress.
            #[derive(Debug)]
            pub struct Decoder<W, D> {
                #[pin]
                writer: BufWriter<W>,
                decoder: D,
                inner: GenericDecoder,
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> Decoder<W, D> {
            /// Wraps `writer` so that bytes written to the result are decoded with `decoder`.
            pub fn new(writer: W, decoder: D) -> Self {
                Self {
                    writer: BufWriter::new(writer),
                    decoder,
                    inner: Default::default(),
                }
            }
        }

        impl<W, D> Decoder<W, D> {
            /// Returns a shared reference to the wrapped writer.
            ///
            /// Decoded output still staged in the internal buffer may not have reached it yet.
            pub fn get_ref(&self) -> &W {
                self.writer.get_ref()
            }

            /// Returns a mutable reference to the wrapped writer.
            pub fn get_mut(&mut self) -> &mut W {
                self.writer.get_mut()
            }

            /// Returns a pinned mutable reference to the wrapped writer.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
                self.project().writer.get_pin_mut()
            }

            /// Consumes the adapter and returns the wrapped writer.
            ///
            /// This does not drive the decoder; flush or close first if pending output matters.
            pub fn into_inner(self) -> W {
                self.writer.into_inner()
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> Decoder<W, D> {
            // Projects the pinned fields and runs the shared flush state machine on them.
            fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                // The projected fields are only moved or reborrowed, so the binding needs no `mut`.
                let this = self.project();

                this.inner.do_poll_flush(cx, this.writer, this.decoder)
            }
        }

        impl<W: AsyncWrite, D: DecodeV2> AsyncWrite for Decoder<W, D> {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let this = self.project();

                this.inner.poll_write(cx, buf, this.writer, this.decoder)
            }

            // Pushes the codec's pending output into the buffer, then flushes through to `W`.
            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                ready!(self.as_mut().do_poll_flush(cx))?;
                self.project().writer.poll_flush(cx)
            }

            // Marks the input as complete, drives the codec to completion, then closes `W`.
            fn $poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                self.as_mut().project().inner.do_close();

                ready!(self.as_mut().do_poll_flush(cx))?;

                let this = self.project();
                if this.inner.is_done() {
                    this.writer.$poll_close(cx)
                } else {
                    // Defensive: after `do_close`, the shared flush is expected to complete only
                    // once the decoder is done.
                    Poll::Ready(Err(io::Error::other(
                        "Attempt to close before finishing input",
                    )))
                }
            }
        }

        // When `W` is also buffered-readable, reads go straight to `W`, bypassing the decoder.
        impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
            fn poll_fill_buf(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<&[u8]>> {
                self.get_pin_mut().poll_fill_buf(cx)
            }

            fn consume(self: Pin<&mut Self>, amt: usize) {
                self.get_pin_mut().consume(amt)
            }
        }
    };
}
241pub(crate) use impl_decoder;