async_compression/tokio/write/generic/
encoder.rs1use crate::codecs::Encode;
2use crate::core::util::PartialBuffer;
3use crate::tokio::write::{AsyncBufWrite, BufWriter};
4use futures_core::ready;
5use pin_project_lite::pin_project;
6use std::{
7 io,
8 pin::Pin,
9 task::{Context, Poll},
10};
11use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
12
/// Lifecycle of an [`Encoder`]: writing, an in-progress flush, an
/// in-progress shutdown, or fully finished.
#[derive(Debug)]
enum State {
    /// Normal operation: written bytes are passed to `Encode::encode`.
    Encoding,
    /// A flush was started but `Encode::flush` has not reported completion yet.
    Flushing,
    /// Shutdown has started; `Encode::finish` is being driven to completion.
    Finishing,
    /// `Encode::finish` has completed; further writes and flushes are rejected.
    Done,
}
20
21pin_project! {
22 #[derive(Debug)]
23 pub struct Encoder<W, E> {
24 #[pin]
25 writer: BufWriter<W>,
26 encoder: E,
27 state: State,
28 }
29}
30
31impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
32 pub fn new(writer: W, encoder: E) -> Self {
33 Self {
34 writer: BufWriter::new(writer),
35 encoder,
36 state: State::Encoding,
37 }
38 }
39
40 pub fn with_capacity(writer: W, encoder: E, cap: usize) -> Self {
41 Self {
42 writer: BufWriter::with_capacity(cap, writer),
43 encoder,
44 state: State::Encoding,
45 }
46 }
47}
48
49impl<W, E> Encoder<W, E> {
50 pub fn get_ref(&self) -> &W {
51 self.writer.get_ref()
52 }
53
54 pub fn get_mut(&mut self) -> &mut W {
55 self.writer.get_mut()
56 }
57
58 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
59 self.project().writer.get_pin_mut()
60 }
61
62 pub(crate) fn get_encoder_ref(&self) -> &E {
63 &self.encoder
64 }
65
66 pub fn into_inner(self) -> W {
67 self.writer.into_inner()
68 }
69}
70
71impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
72 fn do_poll_write(
73 self: Pin<&mut Self>,
74 cx: &mut Context<'_>,
75 input: &mut PartialBuffer<&[u8]>,
76 ) -> Poll<io::Result<()>> {
77 let mut this = self.project();
78
79 loop {
80 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
81 let mut output = PartialBuffer::new(output);
82
83 *this.state = match this.state {
84 State::Encoding => {
85 this.encoder.encode(input, &mut output)?;
86 State::Encoding
87 }
88
89 State::Flushing => match this.encoder.flush(&mut output)? {
91 true => State::Encoding,
92 false => State::Flushing,
93 },
94
95 State::Finishing | State::Done => {
96 return Poll::Ready(Err(io::Error::other("Write after shutdown")))
97 }
98 };
99
100 let produced = output.written().len();
101 this.writer.as_mut().produce(produced);
102
103 if input.unwritten().is_empty() {
104 return Poll::Ready(Ok(()));
105 }
106 }
107 }
108
109 fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
110 let mut this = self.project();
111
112 loop {
113 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
114 let mut output = PartialBuffer::new(output);
115
116 let done = match this.state {
117 State::Encoding | State::Flushing => this.encoder.flush(&mut output)?,
118
119 State::Finishing | State::Done => {
120 return Poll::Ready(Err(io::Error::other("Flush after shutdown")))
121 }
122 };
123 *this.state = State::Flushing;
124
125 let produced = output.written().len();
126 this.writer.as_mut().produce(produced);
127
128 if done {
129 *this.state = State::Encoding;
130 return Poll::Ready(Ok(()));
131 }
132 }
133 }
134
135 fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136 let mut this = self.project();
137
138 loop {
139 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
140 let mut output = PartialBuffer::new(output);
141
142 *this.state = match this.state {
143 State::Encoding | State::Finishing => {
144 if this.encoder.finish(&mut output)? {
145 State::Done
146 } else {
147 State::Finishing
148 }
149 }
150
151 State::Flushing => match this.encoder.flush(&mut output)? {
153 true => State::Finishing,
154 false => State::Flushing,
155 },
156
157 State::Done => State::Done,
158 };
159
160 let produced = output.written().len();
161 this.writer.as_mut().produce(produced);
162
163 if let State::Done = this.state {
164 return Poll::Ready(Ok(()));
165 }
166 }
167 }
168}
169
170impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
171 fn poll_write(
172 self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 buf: &[u8],
175 ) -> Poll<io::Result<usize>> {
176 if buf.is_empty() {
177 return Poll::Ready(Ok(0));
178 }
179
180 let mut input = PartialBuffer::new(buf);
181
182 match self.do_poll_write(cx, &mut input)? {
183 Poll::Pending if input.written().is_empty() => Poll::Pending,
184 _ => Poll::Ready(Ok(input.written().len())),
185 }
186 }
187
188 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
189 ready!(self.as_mut().do_poll_flush(cx))?;
190 ready!(self.project().writer.as_mut().poll_flush(cx))?;
191 Poll::Ready(Ok(()))
192 }
193
194 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
195 ready!(self.as_mut().do_poll_shutdown(cx))?;
196 ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
197 Poll::Ready(Ok(()))
198 }
199}
200
201impl<W: AsyncRead, E> AsyncRead for Encoder<W, E> {
202 fn poll_read(
203 self: Pin<&mut Self>,
204 cx: &mut Context<'_>,
205 buf: &mut ReadBuf<'_>,
206 ) -> Poll<io::Result<()>> {
207 self.get_pin_mut().poll_read(cx, buf)
208 }
209}
210
211impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
212 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
213 self.get_pin_mut().poll_fill_buf(cx)
214 }
215
216 fn consume(self: Pin<&mut Self>, amt: usize) {
217 self.get_pin_mut().consume(amt)
218 }
219}