async_compression/generic/write/encoder.rs

1use crate::{
2    codecs::EncodeV2,
3    core::util::{PartialBuffer, WriteBuffer},
4    generic::write::AsyncBufWrite,
5};
6use std::{
7    io,
8    pin::Pin,
9    task::{ready, Context, Poll},
10};
11
/// Lifecycle of the write-side encoder: it starts in `Encoding`, and the
/// first call to `encoder.finish` moves it to `Finishing` or `Done`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Input is accepted by `poll_write` and `do_poll_flush`.
    Encoding,
    /// `do_poll_close` has called `encoder.finish`, but it has not yet
    /// reported completion; writes and flushes are rejected.
    Finishing,
    /// `encoder.finish` reported completion; writes and flushes are rejected.
    Done,
}
18
19#[derive(Debug)]
20pub struct Encoder {
21    state: State,
22}
23
24impl Default for Encoder {
25    fn default() -> Self {
26        Self {
27            state: State::Encoding,
28        }
29    }
30}
31
32impl Encoder {
33    fn do_poll_write(
34        &mut self,
35        cx: &mut Context<'_>,
36        input: &mut PartialBuffer<&[u8]>,
37        mut writer: Pin<&mut dyn AsyncBufWrite>,
38        encoder: &mut dyn EncodeV2,
39    ) -> Poll<io::Result<()>> {
40        loop {
41            let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
42            let output = &mut output.write_buffer;
43
44            self.state = match self.state {
45                State::Encoding => {
46                    encoder.encode(input, output)?;
47                    State::Encoding
48                }
49
50                State::Finishing | State::Done => {
51                    break Poll::Ready(Err(io::Error::other("Write after close")));
52                }
53            };
54
55            if input.unwritten().is_empty() {
56                break Poll::Ready(Ok(()));
57            }
58        }
59    }
60
61    pub fn poll_write(
62        &mut self,
63        cx: &mut Context<'_>,
64        buf: &[u8],
65        writer: Pin<&mut dyn AsyncBufWrite>,
66        encoder: &mut dyn EncodeV2,
67    ) -> Poll<io::Result<usize>> {
68        if buf.is_empty() {
69            return Poll::Ready(Ok(0));
70        }
71
72        let mut input = PartialBuffer::new(buf);
73
74        match self.do_poll_write(cx, &mut input, writer, encoder)? {
75            Poll::Pending if input.written().is_empty() => Poll::Pending,
76            _ => Poll::Ready(Ok(input.written().len())),
77        }
78    }
79
80    pub fn do_poll_flush(
81        &mut self,
82        cx: &mut Context<'_>,
83        mut writer: Pin<&mut dyn AsyncBufWrite>,
84        encoder: &mut dyn EncodeV2,
85    ) -> Poll<io::Result<()>> {
86        loop {
87            let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
88            let output = &mut output.write_buffer;
89
90            let done = match self.state {
91                State::Encoding => encoder.flush(output)?,
92
93                State::Finishing | State::Done => {
94                    break Poll::Ready(Err(io::Error::other("Flush after close")));
95                }
96            };
97
98            if done {
99                break Poll::Ready(Ok(()));
100            }
101        }
102    }
103
104    pub fn do_poll_close(
105        &mut self,
106        cx: &mut Context<'_>,
107        mut writer: Pin<&mut dyn AsyncBufWrite>,
108        encoder: &mut dyn EncodeV2,
109    ) -> Poll<io::Result<()>> {
110        loop {
111            let mut output = ready!(writer.as_mut().poll_partial_flush_buf(cx))?;
112            let output = &mut output.write_buffer;
113
114            self.state = match self.state {
115                State::Encoding | State::Finishing => {
116                    if encoder.finish(output)? {
117                        State::Done
118                    } else {
119                        State::Finishing
120                    }
121                }
122
123                State::Done => State::Done,
124            };
125
126            if let State::Done = self.state {
127                break Poll::Ready(Ok(()));
128            }
129        }
130    }
131}
132
/// Generates a write-side `Encoder<W, E>` wrapper that compresses everything
/// written to it with codec `E` and forwards the output to `W` through a
/// `BufWriter`.
///
/// `$poll_close` is the name of the close method of the `AsyncWrite` trait in
/// scope at the call site. It is presumably `poll_close` for futures-io and
/// `poll_shutdown` for tokio. The call site must also have `AsyncWrite`,
/// `AsyncBufRead`, `BufWriter`, `Pin`, `Context`, `Poll` and `io` in scope.
macro_rules! impl_encoder {
    ($poll_close: tt) => {
        use crate::{codecs::EncodeV2, generic::write::Encoder as GenericEncoder};
        use pin_project_lite::pin_project;
        use std::task::ready;

        pin_project! {
            // `writer` buffers the encoded output, `encoder` is the codec, and
            // `inner` tracks the encode/finish lifecycle.
            #[derive(Debug)]
            pub struct Encoder<W, E> {
                #[pin]
                writer: BufWriter<W>,
                encoder: E,
                inner: GenericEncoder,
            }
        }

        impl<W, E> Encoder<W, E> {
            /// Returns a shared reference to the underlying writer.
            pub fn get_ref(&self) -> &W {
                self.writer.get_ref()
            }

            /// Returns a mutable reference to the underlying writer.
            ///
            /// Writing to it directly bypasses the encoder and can interleave
            /// with buffered output that has not been flushed yet.
            pub fn get_mut(&mut self) -> &mut W {
                self.writer.get_mut()
            }

            /// Returns a pinned mutable reference to the underlying writer.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
                self.project().writer.get_pin_mut()
            }

            /// Returns a shared reference to the codec.
            pub(crate) fn get_encoder_ref(&self) -> &E {
                &self.encoder
            }

            /// Consumes the wrapper and returns the underlying writer.
            ///
            /// This is synchronous and does not finish the stream.
            /// NOTE(review): output still held in the internal buffer is
            /// presumably lost, so close the encoder first. Confirm against the
            /// `BufWriter` in use.
            pub fn into_inner(self) -> W {
                self.writer.into_inner()
            }

            /// Wraps `writer`, compressing written data with `encoder`, using
            /// the default `BufWriter` capacity.
            pub fn new(writer: W, encoder: E) -> Self {
                Self {
                    writer: BufWriter::new(writer),
                    encoder,
                    inner: Default::default(),
                }
            }

            /// Like `new`, but with an internal buffer of `cap` bytes.
            pub fn with_capacity(writer: W, encoder: E, cap: usize) -> Self {
                Self {
                    writer: BufWriter::with_capacity(cap, writer),
                    encoder,
                    inner: Default::default(),
                }
            }
        }

        impl<W: AsyncWrite, E: EncodeV2> AsyncWrite for Encoder<W, E> {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let this = self.project();
                this.inner.poll_write(cx, buf, this.writer, this.encoder)
            }

            fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let mut this = self.project();

                // First move the codec's pending output into the buffer, then
                // flush the buffer and the underlying writer.
                ready!(this
                    .inner
                    .do_poll_flush(cx, this.writer.as_mut(), this.encoder))?;
                this.writer.poll_flush(cx)
            }

            fn $poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let mut this = self.project();

                // Finish the encoded stream before closing the inner writer;
                // otherwise the codec's final output would never be written.
                ready!(this
                    .inner
                    .do_poll_close(cx, this.writer.as_mut(), this.encoder))?;
                this.writer.$poll_close(cx)
            }
        }

        // Reads pass straight through to the underlying writer when it is also
        // readable (presumably for duplex streams).
        impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
            fn poll_fill_buf(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<&[u8]>> {
                self.get_pin_mut().poll_fill_buf(cx)
            }

            fn consume(self: Pin<&mut Self>, amt: usize) {
                self.get_pin_mut().consume(amt)
            }
        }
    };
}
230pub(crate) use impl_encoder;