async_compression/generic/write/encoder.rs

1use crate::{
2    codecs::EncodeV2,
3    core::util::{PartialBuffer, WriteBuffer},
4    generic::write::AsyncBufWrite,
5};
6use futures_core::ready;
7use std::{
8    io,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
/// Lifecycle of an [`Encoder`].
#[derive(Debug, Default)]
enum State {
    /// Initial state: writes and flushes are accepted.
    #[default]
    Encoding,
    /// A close has started but `EncodeV2::finish` has not yet returned `true`;
    /// writes and flushes are rejected.
    Finishing,
    /// `EncodeV2::finish` returned `true`; writes and flushes are rejected.
    Done,
}

/// Runtime-independent write/flush/close state machine shared by the encoders
/// generated with `impl_encoder!`.
///
/// It owns neither the codec nor the output writer: both are passed in on every
/// call, and the generated wrapper types hold them.
#[derive(Debug, Default)]
pub struct Encoder {
    state: State,
}
32
33impl Encoder {
34    fn do_poll_write(
35        &mut self,
36        cx: &mut Context<'_>,
37        input: &mut PartialBuffer<&[u8]>,
38        mut writer: Pin<&mut dyn AsyncBufWrite>,
39        encoder: &mut dyn EncodeV2,
40    ) -> Poll<io::Result<()>> {
41        loop {
42            let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
43            let mut output = WriteBuffer::new_initialized(output);
44
45            self.state = match self.state {
46                State::Encoding => {
47                    encoder.encode(input, &mut output)?;
48                    State::Encoding
49                }
50
51                State::Finishing | State::Done => {
52                    break Poll::Ready(Err(io::Error::other("Write after close")))
53                }
54            };
55
56            let produced = output.written_len();
57            writer.as_mut().produce(produced);
58
59            if input.unwritten().is_empty() {
60                break Poll::Ready(Ok(()));
61            }
62        }
63    }
64
65    pub fn poll_write(
66        &mut self,
67        cx: &mut Context<'_>,
68        buf: &[u8],
69        writer: Pin<&mut dyn AsyncBufWrite>,
70        encoder: &mut dyn EncodeV2,
71    ) -> Poll<io::Result<usize>> {
72        if buf.is_empty() {
73            return Poll::Ready(Ok(0));
74        }
75
76        let mut input = PartialBuffer::new(buf);
77
78        match self.do_poll_write(cx, &mut input, writer, encoder)? {
79            Poll::Pending if input.written().is_empty() => Poll::Pending,
80            _ => Poll::Ready(Ok(input.written().len())),
81        }
82    }
83
84    pub fn do_poll_flush(
85        &mut self,
86        cx: &mut Context<'_>,
87        mut writer: Pin<&mut dyn AsyncBufWrite>,
88        encoder: &mut dyn EncodeV2,
89    ) -> Poll<io::Result<()>> {
90        loop {
91            let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
92            let mut output = WriteBuffer::new_initialized(output);
93
94            let done = match self.state {
95                State::Encoding => encoder.flush(&mut output)?,
96
97                State::Finishing | State::Done => {
98                    break Poll::Ready(Err(io::Error::other("Flush after close")))
99                }
100            };
101
102            let produced = output.written_len();
103            writer.as_mut().produce(produced);
104
105            if done {
106                break Poll::Ready(Ok(()));
107            }
108        }
109    }
110
111    pub fn do_poll_close(
112        &mut self,
113        cx: &mut Context<'_>,
114        mut writer: Pin<&mut dyn AsyncBufWrite>,
115        encoder: &mut dyn EncodeV2,
116    ) -> Poll<io::Result<()>> {
117        loop {
118            let output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
119            let mut output = WriteBuffer::new_initialized(output);
120
121            self.state = match self.state {
122                State::Encoding | State::Finishing => {
123                    if encoder.finish(&mut output)? {
124                        State::Done
125                    } else {
126                        State::Finishing
127                    }
128                }
129
130                State::Done => State::Done,
131            };
132
133            let produced = output.written_len();
134            writer.as_mut().produce(produced);
135
136            if let State::Done = self.state {
137                break Poll::Ready(Ok(()));
138            }
139        }
140    }
141}
142
/// Defines a concrete `Encoder<W, E>` writer on top of the runtime-independent
/// `Encoder` state machine in this module.
///
/// Bytes written to the generated type are encoded with `E: EncodeV2`. The
/// output is written to `W` through an internal `BufWriter`.
///
/// `$poll_close` names the close method twice: in the generated `AsyncWrite`
/// impl, and in the call on the inner `BufWriter`. This lets the same body
/// target traits whose close method is named differently. NOTE(review):
/// presumably `poll_close` vs `poll_shutdown`; confirm at the invocation sites.
///
/// The expansion site must already have `AsyncWrite`, `AsyncBufRead`,
/// `BufWriter`, `Pin`, `Context`, `Poll` and `io` in scope. The macro body
/// uses them unqualified.
macro_rules! impl_encoder {
    ($poll_close: tt) => {
        use crate::{codecs::EncodeV2, generic::write::Encoder as GenericEncoder};
        use futures_core::ready;
        use pin_project_lite::pin_project;

        pin_project! {
            #[derive(Debug)]
            pub struct Encoder<W, E> {
                // Holds encoded output until it is flushed to `W`.
                #[pin]
                writer: BufWriter<W>,
                encoder: E,
                // Write/flush/close state machine.
                inner: GenericEncoder,
            }
        }

        impl<W, E> Encoder<W, E> {
            /// Returns a shared reference to the underlying writer.
            pub fn get_ref(&self) -> &W {
                self.writer.get_ref()
            }

            /// Returns a mutable reference to the underlying writer.
            ///
            /// Writing to it directly bypasses both the encoder and any output
            /// still held in the internal `BufWriter`.
            pub fn get_mut(&mut self) -> &mut W {
                self.writer.get_mut()
            }

            /// Returns a pinned mutable reference to the underlying writer.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
                self.project().writer.get_pin_mut()
            }

            /// Returns a shared reference to the codec.
            pub(crate) fn get_encoder_ref(&self) -> &E {
                &self.encoder
            }

            /// Consumes this wrapper and returns the underlying writer; the codec is
            /// dropped.
            ///
            /// NOTE(review): whether output still buffered in the internal
            /// `BufWriter` is lost depends on `BufWriter::into_inner`. Close the
            /// encoder first to make sure everything reached `W`.
            pub fn into_inner(self) -> W {
                self.writer.into_inner()
            }

            /// Wraps `writer` (via `BufWriter::new`) so that data written to the
            /// result is encoded by `encoder`.
            pub fn new(writer: W, encoder: E) -> Self {
                Self {
                    writer: BufWriter::new(writer),
                    encoder,
                    inner: Default::default(),
                }
            }

            /// Like [`Self::new`], but the internal `BufWriter` is created with
            /// capacity `cap`.
            pub fn with_capacity(writer: W, encoder: E, cap: usize) -> Self {
                Self {
                    writer: BufWriter::with_capacity(cap, writer),
                    encoder,
                    inner: Default::default(),
                }
            }
        }

        impl<W: AsyncWrite, E: EncodeV2> AsyncWrite for Encoder<W, E> {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let this = self.project();
                this.inner.poll_write(cx, buf, this.writer, this.encoder)
            }

            /// Runs the encoder's flush step into the internal buffer, then
            /// flushes the internal `BufWriter`.
            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let mut this = self.project();

                ready!(this
                    .inner
                    .do_poll_flush(cx, this.writer.as_mut(), this.encoder))?;
                this.writer.poll_flush(cx)
            }

            /// Finishes the encoded stream into the internal buffer, then closes
            /// the internal `BufWriter` with the same close method.
            fn $poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let mut this = self.project();

                ready!(this
                    .inner
                    .do_poll_close(cx, this.writer.as_mut(), this.encoder))?;
                this.writer.$poll_close(cx)
            }
        }

        // Reads are passed straight through to the underlying writer's own
        // `AsyncBufRead` impl; they never go through the encoder.
        impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
            fn poll_fill_buf(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<&[u8]>> {
                self.get_pin_mut().poll_fill_buf(cx)
            }

            fn consume(self: Pin<&mut Self>, amt: usize) {
                self.get_pin_mut().consume(amt)
            }
        }
    };
}
pub(crate) use impl_encoder;