async_compression/tokio/write/generic/
decoder.rs1use crate::codecs::Decode;
2use crate::core::util::PartialBuffer;
3use crate::tokio::write::{AsyncBufWrite, BufWriter};
4use futures_core::ready;
5use pin_project_lite::pin_project;
6use std::{
7 io,
8 pin::Pin,
9 task::{Context, Poll},
10};
11use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
12
/// Lifecycle of a [`Decoder`].
///
/// Transitions only ever move forward: `Decoding` -> `Finishing` -> `Done`.
#[derive(Debug)]
enum State {
    /// Initial state (set by `new`): written bytes are handed to `Decode::decode`.
    Decoding,
    /// Entered when `decode` returns `true` or when shutdown is requested;
    /// `Decode::finish` is called until it returns `true`.
    Finishing,
    /// `finish` has completed; further writes are rejected with an error.
    Done,
}
19
20pin_project! {
21 #[derive(Debug)]
22 pub struct Decoder<W, D> {
23 #[pin]
24 writer: BufWriter<W>,
25 decoder: D,
26 state: State,
27 }
28}
29
30impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
31 pub fn new(writer: W, decoder: D) -> Self {
32 Self {
33 writer: BufWriter::new(writer),
34 decoder,
35 state: State::Decoding,
36 }
37 }
38}
39
40impl<W, D> Decoder<W, D> {
41 pub fn get_ref(&self) -> &W {
42 self.writer.get_ref()
43 }
44
45 pub fn get_mut(&mut self) -> &mut W {
46 self.writer.get_mut()
47 }
48
49 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
50 self.project().writer.get_pin_mut()
51 }
52
53 pub fn into_inner(self) -> W {
54 self.writer.into_inner()
55 }
56}
57
58impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
59 fn do_poll_write(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 input: &mut PartialBuffer<&[u8]>,
63 ) -> Poll<io::Result<()>> {
64 let mut this = self.project();
65
66 loop {
67 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
68 let mut output = PartialBuffer::new(output);
69
70 *this.state = match this.state {
71 State::Decoding => {
72 if this.decoder.decode(input, &mut output)? {
73 State::Finishing
74 } else {
75 State::Decoding
76 }
77 }
78
79 State::Finishing => {
80 if this.decoder.finish(&mut output)? {
81 State::Done
82 } else {
83 State::Finishing
84 }
85 }
86
87 State::Done => {
88 return Poll::Ready(Err(io::Error::other("Write after end of stream")))
89 }
90 };
91
92 let produced = output.written().len();
93 this.writer.as_mut().produce(produced);
94
95 if let State::Done = this.state {
96 return Poll::Ready(Ok(()));
97 }
98
99 if input.unwritten().is_empty() {
100 return Poll::Ready(Ok(()));
101 }
102 }
103 }
104
105 fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
106 let mut this = self.project();
107
108 loop {
109 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
110 let mut output = PartialBuffer::new(output);
111
112 let (state, done) = match this.state {
113 State::Decoding => {
114 let done = this.decoder.flush(&mut output)?;
115 (State::Decoding, done)
116 }
117
118 State::Finishing => {
119 if this.decoder.finish(&mut output)? {
120 (State::Done, false)
121 } else {
122 (State::Finishing, false)
123 }
124 }
125
126 State::Done => (State::Done, true),
127 };
128
129 *this.state = state;
130
131 let produced = output.written().len();
132 this.writer.as_mut().produce(produced);
133
134 if done {
135 return Poll::Ready(Ok(()));
136 }
137 }
138 }
139}
140
141impl<W: AsyncWrite, D: Decode> AsyncWrite for Decoder<W, D> {
142 fn poll_write(
143 self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 buf: &[u8],
146 ) -> Poll<io::Result<usize>> {
147 if buf.is_empty() {
148 return Poll::Ready(Ok(0));
149 }
150
151 let mut input = PartialBuffer::new(buf);
152
153 match self.do_poll_write(cx, &mut input)? {
154 Poll::Pending if input.written().is_empty() => Poll::Pending,
155 _ => Poll::Ready(Ok(input.written().len())),
156 }
157 }
158
159 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
160 ready!(self.as_mut().do_poll_flush(cx))?;
161 ready!(self.project().writer.as_mut().poll_flush(cx))?;
162 Poll::Ready(Ok(()))
163 }
164
165 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166 if let State::Decoding = self.as_mut().project().state {
167 *self.as_mut().project().state = State::Finishing;
168 }
169
170 ready!(self.as_mut().do_poll_flush(cx))?;
171
172 if let State::Done = self.as_mut().project().state {
173 ready!(self.as_mut().project().writer.as_mut().poll_shutdown(cx))?;
174 Poll::Ready(Ok(()))
175 } else {
176 Poll::Ready(Err(io::Error::other(
177 "Attempt to shutdown before finishing input",
178 )))
179 }
180 }
181}
182
183impl<W: AsyncRead, D> AsyncRead for Decoder<W, D> {
184 fn poll_read(
185 self: Pin<&mut Self>,
186 cx: &mut Context<'_>,
187 buf: &mut ReadBuf<'_>,
188 ) -> Poll<io::Result<()>> {
189 self.get_pin_mut().poll_read(cx, buf)
190 }
191}
192
193impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
194 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
195 self.get_pin_mut().poll_fill_buf(cx)
196 }
197
198 fn consume(self: Pin<&mut Self>, amt: usize) {
199 self.get_pin_mut().consume(amt)
200 }
201}