// async_compression/tokio/bufread/generic/encoder.rs

1use crate::codecs::Encode;
2use crate::core::util::PartialBuffer;
3use core::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7use futures_core::ready;
8use pin_project_lite::pin_project;
9use std::io::{IoSlice, Result};
10use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
11
/// Phase of the read-side state machine advanced by `Encoder::do_poll_read`.
#[derive(Debug)]
enum State {
    /// Pulling input from the reader and feeding it to the encoder.
    Encoding,
    /// The reader went pending after some input was encoded; flushing the encoder.
    Flushing,
    /// The reader reported EOF; asking the encoder to finish the stream.
    Finishing,
    /// The encoder has finished; reads produce no further bytes.
    Done,
}
19
20pin_project! {
21    #[derive(Debug)]
22    pub struct Encoder<R, E> {
23        #[pin]
24        reader: R,
25        encoder: E,
26        state: State,
27    }
28}
29
30impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
31    pub fn new(reader: R, encoder: E) -> Self {
32        Self {
33            reader,
34            encoder,
35            state: State::Encoding,
36        }
37    }
38
39    pub fn with_capacity(reader: R, encoder: E, _cap: usize) -> Self {
40        Self::new(reader, encoder)
41    }
42}
43
44impl<R, E> Encoder<R, E> {
45    pub fn get_ref(&self) -> &R {
46        &self.reader
47    }
48
49    pub fn get_mut(&mut self) -> &mut R {
50        &mut self.reader
51    }
52
53    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
54        self.project().reader
55    }
56
57    pub(crate) fn get_encoder_ref(&self) -> &E {
58        &self.encoder
59    }
60
61    pub fn into_inner(self) -> R {
62        self.reader
63    }
64}
65impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
66    fn do_poll_read(
67        self: Pin<&mut Self>,
68        cx: &mut Context<'_>,
69        output: &mut PartialBuffer<&mut [u8]>,
70    ) -> Poll<Result<()>> {
71        let mut this = self.project();
72        let mut read = 0usize;
73
74        loop {
75            *this.state = match this.state {
76                State::Encoding => {
77                    let res = this.reader.as_mut().poll_fill_buf(cx);
78
79                    match res {
80                        Poll::Pending => {
81                            if read == 0 {
82                                return Poll::Pending;
83                            } else {
84                                State::Flushing
85                            }
86                        }
87                        Poll::Ready(res) => {
88                            let input = res?;
89
90                            if input.is_empty() {
91                                State::Finishing
92                            } else {
93                                let mut input = PartialBuffer::new(input);
94                                this.encoder.encode(&mut input, output)?;
95                                let len = input.written().len();
96                                this.reader.as_mut().consume(len);
97                                read += len;
98
99                                State::Encoding
100                            }
101                        }
102                    }
103                }
104
105                State::Flushing => {
106                    if this.encoder.flush(output)? {
107                        read = 0;
108                        State::Encoding
109                    } else {
110                        State::Flushing
111                    }
112                }
113
114                State::Finishing => {
115                    if this.encoder.finish(output)? {
116                        State::Done
117                    } else {
118                        State::Finishing
119                    }
120                }
121
122                State::Done => State::Done,
123            };
124
125            if let State::Done = *this.state {
126                return Poll::Ready(Ok(()));
127            }
128            if output.unwritten().is_empty() {
129                return Poll::Ready(Ok(()));
130            }
131        }
132    }
133}
134
135impl<R: AsyncBufRead, E: Encode> AsyncRead for Encoder<R, E> {
136    fn poll_read(
137        self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        buf: &mut ReadBuf<'_>,
140    ) -> Poll<Result<()>> {
141        if buf.remaining() == 0 {
142            return Poll::Ready(Ok(()));
143        }
144
145        let mut output = PartialBuffer::new(buf.initialize_unfilled());
146        match self.do_poll_read(cx, &mut output)? {
147            Poll::Pending if output.written().is_empty() => Poll::Pending,
148            _ => {
149                let len = output.written().len();
150                buf.advance(len);
151                Poll::Ready(Ok(()))
152            }
153        }
154    }
155}
156
157impl<R: AsyncWrite, E> AsyncWrite for Encoder<R, E> {
158    fn poll_write(
159        mut self: Pin<&mut Self>,
160        cx: &mut Context<'_>,
161        buf: &[u8],
162    ) -> Poll<Result<usize>> {
163        self.get_pin_mut().poll_write(cx, buf)
164    }
165
166    fn poll_write_vectored(
167        mut self: Pin<&mut Self>,
168        cx: &mut Context<'_>,
169        mut bufs: &[IoSlice<'_>],
170    ) -> Poll<Result<usize>> {
171        self.get_pin_mut().poll_write_vectored(cx, bufs)
172    }
173
174    fn is_write_vectored(&self) -> bool {
175        self.get_ref().is_write_vectored()
176    }
177
178    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
179        self.get_pin_mut().poll_flush(cx)
180    }
181
182    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
183        self.get_pin_mut().poll_shutdown(cx)
184    }
185}