async_compression/tokio/bufread/generic/
encoder.rs1use crate::codecs::Encode;
2use crate::core::util::PartialBuffer;
3use core::{
4 pin::Pin,
5 task::{Context, Poll},
6};
7use futures_core::ready;
8use pin_project_lite::pin_project;
9use std::io::{IoSlice, Result};
10use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
11
/// Lifecycle of an [`Encoder`]; advanced by `do_poll_read` on every read.
#[derive(Debug)]
enum State {
    /// Pulling input from the reader and feeding it through the codec.
    Encoding,
    /// The reader went pending after some input was fed; draining the codec's
    /// pending output before going back to `Encoding`.
    Flushing,
    /// The reader reported EOF (an empty fill); emitting the codec's trailing output.
    Finishing,
    /// The encoded stream is complete; further reads produce no bytes.
    Done,
}

20pin_project! {
21 #[derive(Debug)]
22 pub struct Encoder<R, E> {
23 #[pin]
24 reader: R,
25 encoder: E,
26 state: State,
27 }
28}
29
30impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
31 pub fn new(reader: R, encoder: E) -> Self {
32 Self {
33 reader,
34 encoder,
35 state: State::Encoding,
36 }
37 }
38
39 pub fn with_capacity(reader: R, encoder: E, _cap: usize) -> Self {
40 Self::new(reader, encoder)
41 }
42}
43
44impl<R, E> Encoder<R, E> {
45 pub fn get_ref(&self) -> &R {
46 &self.reader
47 }
48
49 pub fn get_mut(&mut self) -> &mut R {
50 &mut self.reader
51 }
52
53 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
54 self.project().reader
55 }
56
57 pub(crate) fn get_encoder_ref(&self) -> &E {
58 &self.encoder
59 }
60
61 pub fn into_inner(self) -> R {
62 self.reader
63 }
64}
65impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
66 fn do_poll_read(
67 self: Pin<&mut Self>,
68 cx: &mut Context<'_>,
69 output: &mut PartialBuffer<&mut [u8]>,
70 ) -> Poll<Result<()>> {
71 let mut this = self.project();
72 let mut read = 0usize;
73
74 loop {
75 *this.state = match this.state {
76 State::Encoding => {
77 let res = this.reader.as_mut().poll_fill_buf(cx);
78
79 match res {
80 Poll::Pending => {
81 if read == 0 {
82 return Poll::Pending;
83 } else {
84 State::Flushing
85 }
86 }
87 Poll::Ready(res) => {
88 let input = res?;
89
90 if input.is_empty() {
91 State::Finishing
92 } else {
93 let mut input = PartialBuffer::new(input);
94 this.encoder.encode(&mut input, output)?;
95 let len = input.written().len();
96 this.reader.as_mut().consume(len);
97 read += len;
98
99 State::Encoding
100 }
101 }
102 }
103 }
104
105 State::Flushing => {
106 if this.encoder.flush(output)? {
107 read = 0;
108 State::Encoding
109 } else {
110 State::Flushing
111 }
112 }
113
114 State::Finishing => {
115 if this.encoder.finish(output)? {
116 State::Done
117 } else {
118 State::Finishing
119 }
120 }
121
122 State::Done => State::Done,
123 };
124
125 if let State::Done = *this.state {
126 return Poll::Ready(Ok(()));
127 }
128 if output.unwritten().is_empty() {
129 return Poll::Ready(Ok(()));
130 }
131 }
132 }
133}
134
135impl<R: AsyncBufRead, E: Encode> AsyncRead for Encoder<R, E> {
136 fn poll_read(
137 self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 buf: &mut ReadBuf<'_>,
140 ) -> Poll<Result<()>> {
141 if buf.remaining() == 0 {
142 return Poll::Ready(Ok(()));
143 }
144
145 let mut output = PartialBuffer::new(buf.initialize_unfilled());
146 match self.do_poll_read(cx, &mut output)? {
147 Poll::Pending if output.written().is_empty() => Poll::Pending,
148 _ => {
149 let len = output.written().len();
150 buf.advance(len);
151 Poll::Ready(Ok(()))
152 }
153 }
154 }
155}
156
157impl<R: AsyncWrite, E> AsyncWrite for Encoder<R, E> {
158 fn poll_write(
159 mut self: Pin<&mut Self>,
160 cx: &mut Context<'_>,
161 buf: &[u8],
162 ) -> Poll<Result<usize>> {
163 self.get_pin_mut().poll_write(cx, buf)
164 }
165
166 fn poll_write_vectored(
167 mut self: Pin<&mut Self>,
168 cx: &mut Context<'_>,
169 mut bufs: &[IoSlice<'_>],
170 ) -> Poll<Result<usize>> {
171 self.get_pin_mut().poll_write_vectored(cx, bufs)
172 }
173
174 fn is_write_vectored(&self) -> bool {
175 self.get_ref().is_write_vectored()
176 }
177
178 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
179 self.get_pin_mut().poll_flush(cx)
180 }
181
182 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
183 self.get_pin_mut().poll_shutdown(cx)
184 }
185}