async_compression/tokio/bufread/generic/encoder.rs

1use crate::codecs::Encode;
2use crate::core::util::PartialBuffer;
3use core::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7use futures_core::ready;
8use pin_project_lite::pin_project;
9use std::io::{IoSlice, Result};
10use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
11
/// Lifecycle phase of an [`Encoder`] while it is being read from.
#[derive(Debug)]
enum State {
    /// Pulling bytes from the inner reader and feeding them to the codec.
    Encoding,
    /// The inner reader reported EOF; draining trailing codec output via `finish`.
    Flushing,
    /// `finish` reported completion; further reads produce no bytes.
    Done,
}
18
pin_project! {
    /// Read adapter that runs the bytes of an inner buffered reader through an
    /// encoder `E`; the `AsyncRead` impl below (for `R: AsyncBufRead, E: Encode`)
    /// yields the encoded output.
    #[derive(Debug)]
    pub struct Encoder<R, E> {
        // Inner reader; structurally pinned so it can be polled through `Pin<&mut R>`.
        #[pin]
        reader: R,
        // Codec that turns input bytes into encoded output (`encode` / `finish`).
        encoder: E,
        // Current phase: Encoding -> Flushing -> Done.
        state: State,
    }
}
28
29impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
30    pub fn new(reader: R, encoder: E) -> Self {
31        Self {
32            reader,
33            encoder,
34            state: State::Encoding,
35        }
36    }
37
38    pub fn with_capacity(reader: R, encoder: E, _cap: usize) -> Self {
39        Self::new(reader, encoder)
40    }
41}
42
43impl<R, E> Encoder<R, E> {
44    pub fn get_ref(&self) -> &R {
45        &self.reader
46    }
47
48    pub fn get_mut(&mut self) -> &mut R {
49        &mut self.reader
50    }
51
52    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
53        self.project().reader
54    }
55
56    pub(crate) fn get_encoder_ref(&self) -> &E {
57        &self.encoder
58    }
59
60    pub fn into_inner(self) -> R {
61        self.reader
62    }
63}
64impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
65    fn do_poll_read(
66        self: Pin<&mut Self>,
67        cx: &mut Context<'_>,
68        output: &mut PartialBuffer<&mut [u8]>,
69    ) -> Poll<Result<()>> {
70        let mut this = self.project();
71
72        loop {
73            *this.state = match this.state {
74                State::Encoding => {
75                    let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
76                    if input.is_empty() {
77                        State::Flushing
78                    } else {
79                        let mut input = PartialBuffer::new(input);
80                        this.encoder.encode(&mut input, output)?;
81                        let len = input.written().len();
82                        this.reader.as_mut().consume(len);
83                        State::Encoding
84                    }
85                }
86
87                State::Flushing => {
88                    if this.encoder.finish(output)? {
89                        State::Done
90                    } else {
91                        State::Flushing
92                    }
93                }
94
95                State::Done => State::Done,
96            };
97
98            if let State::Done = *this.state {
99                return Poll::Ready(Ok(()));
100            }
101            if output.unwritten().is_empty() {
102                return Poll::Ready(Ok(()));
103            }
104        }
105    }
106}
107
108impl<R: AsyncBufRead, E: Encode> AsyncRead for Encoder<R, E> {
109    fn poll_read(
110        self: Pin<&mut Self>,
111        cx: &mut Context<'_>,
112        buf: &mut ReadBuf<'_>,
113    ) -> Poll<Result<()>> {
114        if buf.remaining() == 0 {
115            return Poll::Ready(Ok(()));
116        }
117
118        let mut output = PartialBuffer::new(buf.initialize_unfilled());
119        match self.do_poll_read(cx, &mut output)? {
120            Poll::Pending if output.written().is_empty() => Poll::Pending,
121            _ => {
122                let len = output.written().len();
123                buf.advance(len);
124                Poll::Ready(Ok(()))
125            }
126        }
127    }
128}
129
130impl<R: AsyncWrite, E> AsyncWrite for Encoder<R, E> {
131    fn poll_write(
132        mut self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134        buf: &[u8],
135    ) -> Poll<Result<usize>> {
136        self.get_pin_mut().poll_write(cx, buf)
137    }
138
139    fn poll_write_vectored(
140        mut self: Pin<&mut Self>,
141        cx: &mut Context<'_>,
142        mut bufs: &[IoSlice<'_>],
143    ) -> Poll<Result<usize>> {
144        self.get_pin_mut().poll_write_vectored(cx, bufs)
145    }
146
147    fn is_write_vectored(&self) -> bool {
148        self.get_ref().is_write_vectored()
149    }
150
151    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
152        self.get_pin_mut().poll_flush(cx)
153    }
154
155    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
156        self.get_pin_mut().poll_shutdown(cx)
157    }
158}