// async_compression/tokio/write/generic/encoder.rs

1use crate::codecs::Encode;
2use crate::core::util::PartialBuffer;
3use crate::tokio::write::{AsyncBufWrite, BufWriter};
4use futures_core::ready;
5use pin_project_lite::pin_project;
6use std::{
7    io,
8    pin::Pin,
9    task::{Context, Poll},
10};
11use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
12
/// Progress of an `Encoder` through writing, flushing and shutting down.
///
/// A flush or finish that has been started is resumed by later polls instead
/// of being restarted, which is why the in-progress variants exist.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Accepting input; no flush or finish is in progress.
    Encoding,
    /// A codec flush has started but has not yet reported completion.
    Flushing,
    /// Finishing the stream has started but has not yet reported completion;
    /// writes and flushes are rejected from this point on.
    Finishing,
    /// The stream has been fully finished; writes and flushes are rejected.
    Done,
}
20
21pin_project! {
22    #[derive(Debug)]
23    pub struct Encoder<W, E> {
24        #[pin]
25        writer: BufWriter<W>,
26        encoder: E,
27        state: State,
28    }
29}
30
31impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
32    pub fn new(writer: W, encoder: E) -> Self {
33        Self {
34            writer: BufWriter::new(writer),
35            encoder,
36            state: State::Encoding,
37        }
38    }
39
40    pub fn with_capacity(writer: W, encoder: E, cap: usize) -> Self {
41        Self {
42            writer: BufWriter::with_capacity(cap, writer),
43            encoder,
44            state: State::Encoding,
45        }
46    }
47}
48
49impl<W, E> Encoder<W, E> {
50    pub fn get_ref(&self) -> &W {
51        self.writer.get_ref()
52    }
53
54    pub fn get_mut(&mut self) -> &mut W {
55        self.writer.get_mut()
56    }
57
58    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
59        self.project().writer.get_pin_mut()
60    }
61
62    pub(crate) fn get_encoder_ref(&self) -> &E {
63        &self.encoder
64    }
65
66    pub fn into_inner(self) -> W {
67        self.writer.into_inner()
68    }
69}
70
71impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
72    fn do_poll_write(
73        self: Pin<&mut Self>,
74        cx: &mut Context<'_>,
75        input: &mut PartialBuffer<&[u8]>,
76    ) -> Poll<io::Result<()>> {
77        let mut this = self.project();
78
79        loop {
80            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
81            let mut output = PartialBuffer::new(output);
82
83            *this.state = match this.state {
84                State::Encoding => {
85                    this.encoder.encode(input, &mut output)?;
86                    State::Encoding
87                }
88
89                // Once a flush has been started, it must be completed.
90                State::Flushing => match this.encoder.flush(&mut output)? {
91                    true => State::Encoding,
92                    false => State::Flushing,
93                },
94
95                State::Finishing | State::Done => {
96                    return Poll::Ready(Err(io::Error::other("Write after shutdown")))
97                }
98            };
99
100            let produced = output.written().len();
101            this.writer.as_mut().produce(produced);
102
103            if input.unwritten().is_empty() {
104                return Poll::Ready(Ok(()));
105            }
106        }
107    }
108
109    fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
110        let mut this = self.project();
111
112        loop {
113            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
114            let mut output = PartialBuffer::new(output);
115
116            let done = match this.state {
117                State::Encoding | State::Flushing => this.encoder.flush(&mut output)?,
118
119                State::Finishing | State::Done => {
120                    return Poll::Ready(Err(io::Error::other("Flush after shutdown")))
121                }
122            };
123            *this.state = State::Flushing;
124
125            let produced = output.written().len();
126            this.writer.as_mut().produce(produced);
127
128            if done {
129                *this.state = State::Encoding;
130                return Poll::Ready(Ok(()));
131            }
132        }
133    }
134
135    fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136        let mut this = self.project();
137
138        loop {
139            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
140            let mut output = PartialBuffer::new(output);
141
142            *this.state = match this.state {
143                State::Encoding | State::Finishing => {
144                    if this.encoder.finish(&mut output)? {
145                        State::Done
146                    } else {
147                        State::Finishing
148                    }
149                }
150
151                // Once a flush has been started, it must be completed.
152                State::Flushing => match this.encoder.flush(&mut output)? {
153                    true => State::Finishing,
154                    false => State::Flushing,
155                },
156
157                State::Done => State::Done,
158            };
159
160            let produced = output.written().len();
161            this.writer.as_mut().produce(produced);
162
163            if let State::Done = this.state {
164                return Poll::Ready(Ok(()));
165            }
166        }
167    }
168}
169
170impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
171    fn poll_write(
172        self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        buf: &[u8],
175    ) -> Poll<io::Result<usize>> {
176        if buf.is_empty() {
177            return Poll::Ready(Ok(0));
178        }
179
180        let mut input = PartialBuffer::new(buf);
181
182        match self.do_poll_write(cx, &mut input)? {
183            Poll::Pending if input.written().is_empty() => Poll::Pending,
184            _ => Poll::Ready(Ok(input.written().len())),
185        }
186    }
187
188    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
189        ready!(self.as_mut().do_poll_flush(cx))?;
190        ready!(self.project().writer.as_mut().poll_flush(cx))?;
191        Poll::Ready(Ok(()))
192    }
193
194    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
195        ready!(self.as_mut().do_poll_shutdown(cx))?;
196        ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
197        Poll::Ready(Ok(()))
198    }
199}
200
201impl<W: AsyncRead, E> AsyncRead for Encoder<W, E> {
202    fn poll_read(
203        self: Pin<&mut Self>,
204        cx: &mut Context<'_>,
205        buf: &mut ReadBuf<'_>,
206    ) -> Poll<io::Result<()>> {
207        self.get_pin_mut().poll_read(cx, buf)
208    }
209}
210
211impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
212    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
213        self.get_pin_mut().poll_fill_buf(cx)
214    }
215
216    fn consume(self: Pin<&mut Self>, amt: usize) {
217        self.get_pin_mut().consume(amt)
218    }
219}