async_compression/generic/bufread/encoder.rs

1use crate::{
2    codecs::EncodeV2,
3    core::util::{PartialBuffer, WriteBuffer},
4};
5use std::{io::Result, ops::ControlFlow, panic::AssertUnwindSafe};
6
/// Position of the generic buffered-read encoder in its encode / flush / finish cycle.
#[derive(Debug)]
enum State {
    /// Feeding input to the codec. Carries the number of input bytes the codec has
    /// consumed since the state machine was created or last completed a flush.
    Encoding(usize),
    /// Input was consumed and the caller then reported the reader as pending, so
    /// `flush` is being driven until it reports completion.
    Flushing,
    /// The caller supplied an empty input buffer (end of input), so `finish` is being
    /// driven until it reports completion.
    Finishing,
    /// `finish` completed, or a stored error has already been reported.
    Done,
    /// A codec error held back until output produced before it has been handed out.
    /// NOTE(review): the `AssertUnwindSafe` wrapper presumably keeps the encoder
    /// `UnwindSafe` despite `io::Error`'s boxed payload — confirm.
    Error(std::panic::AssertUnwindSafe<std::io::Error>),
}
15
16#[derive(Debug)]
17pub struct Encoder {
18    state: State,
19}
20
21impl Default for Encoder {
22    fn default() -> Self {
23        Self {
24            state: State::Encoding(0),
25        }
26    }
27}
28
29impl Encoder {
30    /// `input` - should be `None` if `Poll::Pending`.
31    pub fn do_poll_read(
32        &mut self,
33        output: &mut WriteBuffer<'_>,
34        encoder: &mut dyn EncodeV2,
35        mut input: Option<&mut PartialBuffer<&[u8]>>,
36    ) -> ControlFlow<Result<()>> {
37        loop {
38            self.state = match &mut self.state {
39                State::Encoding(read) => match input.as_mut() {
40                    None => {
41                        if *read == 0 {
42                            if output.written().is_empty() {
43                                // Poll for more data
44                                break;
45                            } else {
46                                return ControlFlow::Break(Ok(()));
47                            }
48                        } else {
49                            State::Flushing
50                        }
51                    }
52                    Some(input) => {
53                        if input.unwritten().is_empty() {
54                            State::Finishing
55                        } else {
56                            if let Err(err) = encoder.encode(input, output) {
57                                self.state = State::Error(AssertUnwindSafe(err));
58                                if output.written_len() > 0 {
59                                    return ControlFlow::Break(Ok(()));
60                                } else {
61                                    continue;
62                                }
63                            }
64
65                            *read += input.written().len();
66
67                            // Poll for more data
68                            break;
69                        }
70                    }
71                },
72
73                State::Flushing => match encoder.flush(output) {
74                    Ok(true) => {
75                        self.state = State::Encoding(0);
76
77                        // Poll for more data
78                        break;
79                    }
80                    Ok(false) => State::Flushing,
81                    Err(err) => {
82                        self.state = State::Error(AssertUnwindSafe(err));
83                        if output.written_len() > 0 {
84                            return ControlFlow::Break(Ok(()));
85                        } else {
86                            continue;
87                        }
88                    }
89                },
90
91                State::Finishing => match encoder.finish(output) {
92                    Ok(true) => State::Done,
93                    Ok(false) => State::Finishing,
94                    Err(err) => {
95                        self.state = State::Error(AssertUnwindSafe(err));
96                        if output.written_len() > 0 {
97                            return ControlFlow::Break(Ok(()));
98                        } else {
99                            continue;
100                        }
101                    }
102                },
103
104                State::Done => return ControlFlow::Break(Ok(())),
105
106                State::Error(_) => {
107                    let State::Error(err) = std::mem::replace(&mut self.state, State::Done) else {
108                        unreachable!()
109                    };
110                    return ControlFlow::Break(Err(err.0));
111                }
112            };
113
114            if output.has_no_spare_space() {
115                return ControlFlow::Break(Ok(()));
116            }
117        }
118
119        if output.has_no_spare_space() {
120            ControlFlow::Break(Ok(()))
121        } else {
122            ControlFlow::Continue(())
123        }
124    }
125}
126
/// Expands, in the invoking module, an `Encoder<R, E>` that encodes the bytes of an
/// `AsyncBufRead` reader with an `EncodeV2` codec, on top of
/// `crate::generic::bufread::Encoder`.
///
/// The invoking module must already have these names in scope: `AsyncBufRead`,
/// `Context`, `EncodeV2`, `PartialBuffer`, `Pin`, `Poll`, `Result` and `WriteBuffer`.
macro_rules! impl_encoder {
    () => {
        use crate::generic::bufread::Encoder as GenericEncoder;

        use std::{ops::ControlFlow, task::ready};

        use pin_project_lite::pin_project;

        pin_project! {
            /// Encodes the bytes read from `reader` with `encoder`.
            #[derive(Debug)]
            pub struct Encoder<R, E> {
                #[pin]
                reader: R,
                encoder: E,
                inner: GenericEncoder,
            }
        }

        impl<R: AsyncBufRead, E: EncodeV2> Encoder<R, E> {
            /// Wraps `reader`; everything read from it is passed through `encoder`.
            pub fn new(reader: R, encoder: E) -> Self {
                Self {
                    reader,
                    encoder,
                    inner: Default::default(),
                }
            }

            /// Equivalent to `new`; `_cap` is accepted for API compatibility but ignored.
            pub fn with_capacity(reader: R, encoder: E, _cap: usize) -> Self {
                Self::new(reader, encoder)
            }
        }

        impl<R, E> Encoder<R, E> {
            /// Returns a shared reference to the underlying reader.
            pub fn get_ref(&self) -> &R {
                &self.reader
            }

            /// Returns a mutable reference to the underlying reader.
            pub fn get_mut(&mut self) -> &mut R {
                &mut self.reader
            }

            /// Returns a pinned mutable reference to the underlying reader.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
                self.project().reader
            }

            /// Returns a shared reference to the codec.
            pub(crate) fn get_encoder_ref(&self) -> &E {
                &self.encoder
            }

            /// Consumes the wrapper and returns the underlying reader; the codec and any
            /// state it holds are dropped.
            pub fn into_inner(self) -> R {
                self.reader
            }
        }

        /// Pulls input from `reader` and drives `inner` until it asks the caller to
        /// return (output full, stream finished, or an error) or the reader is pending.
        ///
        /// The reader is polled before the state machine is advanced on every iteration.
        /// As a result, `inner.do_poll_read` only sees `None` when `poll_fill_buf` really
        /// returned `Poll::Pending`, as its contract requires.
        ///
        /// Advancing it with `None` speculatively is wrong. In the `Encoding` state,
        /// `None` means "reader pending" and triggers a codec flush. That would fire at
        /// the start of every call that followed a full output buffer, even when the
        /// reader still had data.
        ///
        /// States that do not need input (flushing, finishing, done, error) ignore
        /// whatever the reader returns.
        fn do_poll_read(
            inner: &mut GenericEncoder,
            encoder: &mut dyn EncodeV2,
            mut reader: Pin<&mut dyn AsyncBufRead>,
            cx: &mut Context<'_>,
            output: &mut WriteBuffer<'_>,
        ) -> Poll<Result<()>> {
            loop {
                let mut input = match reader.as_mut().poll_fill_buf(cx) {
                    Poll::Pending => None,
                    Poll::Ready(res) => Some(PartialBuffer::new(res?)),
                };

                let control_flow = inner.do_poll_read(output, encoder, input.as_mut());

                let is_pending = input.is_none();
                if let Some(input) = input {
                    // Tell the reader how much of its buffer the codec actually used.
                    let len = input.written().len();
                    reader.as_mut().consume(len);
                }

                if let ControlFlow::Break(res) = control_flow {
                    break Poll::Ready(res);
                }

                if is_pending {
                    // The reader registered the waker; only report `Pending` when this
                    // call produced nothing the caller could use.
                    if output.written().is_empty() {
                        return Poll::Pending;
                    } else {
                        return Poll::Ready(Ok(()));
                    }
                }
            }
        }

        impl<R: AsyncBufRead, E: EncodeV2> Encoder<R, E> {
            /// Encodes as much input as fits into `output`; see the free `do_poll_read`.
            fn do_poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                output: &mut WriteBuffer<'_>,
            ) -> Poll<Result<()>> {
                let this = self.project();

                do_poll_read(this.inner, this.encoder, this.reader, cx, output)
            }
        }
    };
}
pub(crate) use impl_encoder;