async_compression/tokio/write/generic/decoder.rs

1use crate::codecs::Decode;
2use crate::core::util::PartialBuffer;
3use crate::tokio::write::{AsyncBufWrite, BufWriter};
4use futures_core::ready;
5use pin_project_lite::pin_project;
6use std::{
7    io,
8    pin::Pin,
9    task::{Context, Poll},
10};
11use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
12
/// Progress of a `Decoder` through its input.
///
/// Transitions only ever move forward: `Decoding` -> `Finishing` -> `Done`.
/// The enum is a plain fieldless tag, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Bytes written by the caller are being fed to `Decode::decode`.
    Decoding,
    /// `decode` returned `true` (presumably the end of the encoded stream was
    /// reached), or a shutdown was requested. `Decode::finish` is being
    /// driven until it reports completion.
    Finishing,
    /// `finish` has completed; any further write is rejected with an error.
    Done,
}
19
20pin_project! {
21    #[derive(Debug)]
22    pub struct Decoder<W, D> {
23        #[pin]
24        writer: BufWriter<W>,
25        decoder: D,
26        state: State,
27    }
28}
29
30impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
31    pub fn new(writer: W, decoder: D) -> Self {
32        Self {
33            writer: BufWriter::new(writer),
34            decoder,
35            state: State::Decoding,
36        }
37    }
38}
39
40impl<W, D> Decoder<W, D> {
41    pub fn get_ref(&self) -> &W {
42        self.writer.get_ref()
43    }
44
45    pub fn get_mut(&mut self) -> &mut W {
46        self.writer.get_mut()
47    }
48
49    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
50        self.project().writer.get_pin_mut()
51    }
52
53    pub fn into_inner(self) -> W {
54        self.writer.into_inner()
55    }
56}
57
58impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
59    fn do_poll_write(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62        input: &mut PartialBuffer<&[u8]>,
63    ) -> Poll<io::Result<()>> {
64        let mut this = self.project();
65
66        loop {
67            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
68            let mut output = PartialBuffer::new(output);
69
70            *this.state = match this.state {
71                State::Decoding => {
72                    if this.decoder.decode(input, &mut output)? {
73                        State::Finishing
74                    } else {
75                        State::Decoding
76                    }
77                }
78
79                State::Finishing => {
80                    if this.decoder.finish(&mut output)? {
81                        State::Done
82                    } else {
83                        State::Finishing
84                    }
85                }
86
87                State::Done => {
88                    return Poll::Ready(Err(io::Error::other("Write after end of stream")))
89                }
90            };
91
92            let produced = output.written().len();
93            this.writer.as_mut().produce(produced);
94
95            if let State::Done = this.state {
96                return Poll::Ready(Ok(()));
97            }
98
99            if input.unwritten().is_empty() {
100                return Poll::Ready(Ok(()));
101            }
102        }
103    }
104
105    fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
106        let mut this = self.project();
107
108        loop {
109            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
110            let mut output = PartialBuffer::new(output);
111
112            let (state, done) = match this.state {
113                State::Decoding => {
114                    let done = this.decoder.flush(&mut output)?;
115                    (State::Decoding, done)
116                }
117
118                State::Finishing => {
119                    if this.decoder.finish(&mut output)? {
120                        (State::Done, false)
121                    } else {
122                        (State::Finishing, false)
123                    }
124                }
125
126                State::Done => (State::Done, true),
127            };
128
129            *this.state = state;
130
131            let produced = output.written().len();
132            this.writer.as_mut().produce(produced);
133
134            if done {
135                return Poll::Ready(Ok(()));
136            }
137        }
138    }
139}
140
141impl<W: AsyncWrite, D: Decode> AsyncWrite for Decoder<W, D> {
142    fn poll_write(
143        self: Pin<&mut Self>,
144        cx: &mut Context<'_>,
145        buf: &[u8],
146    ) -> Poll<io::Result<usize>> {
147        if buf.is_empty() {
148            return Poll::Ready(Ok(0));
149        }
150
151        let mut input = PartialBuffer::new(buf);
152
153        match self.do_poll_write(cx, &mut input)? {
154            Poll::Pending if input.written().is_empty() => Poll::Pending,
155            _ => Poll::Ready(Ok(input.written().len())),
156        }
157    }
158
159    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
160        ready!(self.as_mut().do_poll_flush(cx))?;
161        ready!(self.project().writer.as_mut().poll_flush(cx))?;
162        Poll::Ready(Ok(()))
163    }
164
165    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166        if let State::Decoding = self.as_mut().project().state {
167            *self.as_mut().project().state = State::Finishing;
168        }
169
170        ready!(self.as_mut().do_poll_flush(cx))?;
171
172        if let State::Done = self.as_mut().project().state {
173            ready!(self.as_mut().project().writer.as_mut().poll_shutdown(cx))?;
174            Poll::Ready(Ok(()))
175        } else {
176            Poll::Ready(Err(io::Error::other(
177                "Attempt to shutdown before finishing input",
178            )))
179        }
180    }
181}
182
183impl<W: AsyncRead, D> AsyncRead for Decoder<W, D> {
184    fn poll_read(
185        self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        buf: &mut ReadBuf<'_>,
188    ) -> Poll<io::Result<()>> {
189        self.get_pin_mut().poll_read(cx, buf)
190    }
191}
192
193impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
194    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
195        self.get_pin_mut().poll_fill_buf(cx)
196    }
197
198    fn consume(self: Pin<&mut Self>, amt: usize) {
199        self.get_pin_mut().consume(amt)
200    }
201}