async_compression/tokio/bufread/generic/
encoder.rs1use crate::codecs::Encode;
2use crate::core::util::PartialBuffer;
3use core::{
4 pin::Pin,
5 task::{Context, Poll},
6};
7use futures_core::ready;
8use pin_project_lite::pin_project;
9use std::io::{IoSlice, Result};
10use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
11
/// Where an [`Encoder`] is in its input stream.
///
/// It is a plain fieldless state tag, so it derives `Copy`/`Eq` and can be
/// copied and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Pulling bytes from the underlying reader and feeding them to the codec.
    Encoding,
    /// The reader reported end of input; the codec's `finish` is called
    /// repeatedly until it reports that all trailing output has been written.
    Flushing,
    /// The codec has finished; further reads produce no more bytes.
    Done,
}
18
19pin_project! {
20 #[derive(Debug)]
21 pub struct Encoder<R, E> {
22 #[pin]
23 reader: R,
24 encoder: E,
25 state: State,
26 }
27}
28
29impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
30 pub fn new(reader: R, encoder: E) -> Self {
31 Self {
32 reader,
33 encoder,
34 state: State::Encoding,
35 }
36 }
37
38 pub fn with_capacity(reader: R, encoder: E, _cap: usize) -> Self {
39 Self::new(reader, encoder)
40 }
41}
42
43impl<R, E> Encoder<R, E> {
44 pub fn get_ref(&self) -> &R {
45 &self.reader
46 }
47
48 pub fn get_mut(&mut self) -> &mut R {
49 &mut self.reader
50 }
51
52 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
53 self.project().reader
54 }
55
56 pub(crate) fn get_encoder_ref(&self) -> &E {
57 &self.encoder
58 }
59
60 pub fn into_inner(self) -> R {
61 self.reader
62 }
63}
64impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
65 fn do_poll_read(
66 self: Pin<&mut Self>,
67 cx: &mut Context<'_>,
68 output: &mut PartialBuffer<&mut [u8]>,
69 ) -> Poll<Result<()>> {
70 let mut this = self.project();
71
72 loop {
73 *this.state = match this.state {
74 State::Encoding => {
75 let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
76 if input.is_empty() {
77 State::Flushing
78 } else {
79 let mut input = PartialBuffer::new(input);
80 this.encoder.encode(&mut input, output)?;
81 let len = input.written().len();
82 this.reader.as_mut().consume(len);
83 State::Encoding
84 }
85 }
86
87 State::Flushing => {
88 if this.encoder.finish(output)? {
89 State::Done
90 } else {
91 State::Flushing
92 }
93 }
94
95 State::Done => State::Done,
96 };
97
98 if let State::Done = *this.state {
99 return Poll::Ready(Ok(()));
100 }
101 if output.unwritten().is_empty() {
102 return Poll::Ready(Ok(()));
103 }
104 }
105 }
106}
107
108impl<R: AsyncBufRead, E: Encode> AsyncRead for Encoder<R, E> {
109 fn poll_read(
110 self: Pin<&mut Self>,
111 cx: &mut Context<'_>,
112 buf: &mut ReadBuf<'_>,
113 ) -> Poll<Result<()>> {
114 if buf.remaining() == 0 {
115 return Poll::Ready(Ok(()));
116 }
117
118 let mut output = PartialBuffer::new(buf.initialize_unfilled());
119 match self.do_poll_read(cx, &mut output)? {
120 Poll::Pending if output.written().is_empty() => Poll::Pending,
121 _ => {
122 let len = output.written().len();
123 buf.advance(len);
124 Poll::Ready(Ok(()))
125 }
126 }
127 }
128}
129
130impl<R: AsyncWrite, E> AsyncWrite for Encoder<R, E> {
131 fn poll_write(
132 mut self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 buf: &[u8],
135 ) -> Poll<Result<usize>> {
136 self.get_pin_mut().poll_write(cx, buf)
137 }
138
139 fn poll_write_vectored(
140 mut self: Pin<&mut Self>,
141 cx: &mut Context<'_>,
142 mut bufs: &[IoSlice<'_>],
143 ) -> Poll<Result<usize>> {
144 self.get_pin_mut().poll_write_vectored(cx, bufs)
145 }
146
147 fn is_write_vectored(&self) -> bool {
148 self.get_ref().is_write_vectored()
149 }
150
151 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
152 self.get_pin_mut().poll_flush(cx)
153 }
154
155 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
156 self.get_pin_mut().poll_shutdown(cx)
157 }
158}