1use std::error::Error;
20use std::fmt;
21use std::io::{self};
22use std::pin::Pin;
23
24use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder, ZstdDecoder};
25use bytes::Bytes;
26use futures::stream::Peekable;
27use futures::task::{Context, Poll};
28use futures::{Future, Stream};
29use futures_util::StreamExt;
30use headers::{ContentLength, HeaderMapExt};
31use http_body_util::BodyExt;
32use hyper::Response;
33use hyper::body::Body;
34use hyper::header::{CONTENT_ENCODING, HeaderValue, TRANSFER_ENCODING};
35use tokio_util::codec::{BytesCodec, FramedRead};
36use tokio_util::io::StreamReader;
37
38use crate::connector::BoxedBody;
39
/// Initial capacity, in bytes, of the read buffer given to each decompressing
/// `FramedRead` (8 KiB).
pub const DECODER_BUFFER_SIZE: usize = 8 * 1024;
41
/// A stream of response-body bytes.
///
/// When [`Decoder::detect`] recognises a supported coding (gzip, br, deflate,
/// zstd), the body is decompressed. Otherwise it is passed through unchanged.
pub struct Decoder {
    /// Current decoding state (plain, pending, or a concrete decompressor).
    inner: Inner,
}
48
/// The decompression algorithm selected for a response body.
#[derive(PartialEq)]
enum DecoderType {
    /// Decoded with `GzipDecoder`.
    Gzip,
    /// Decoded with `BrotliDecoder`.
    Brotli,
    /// Decoded with `ZlibDecoder` (zlib-wrapped deflate).
    Deflate,
    /// Decoded with `ZstdDecoder`.
    Zstd,
}
56
/// Decoding state of a [`Decoder`].
enum Inner {
    /// Body bytes are yielded unchanged.
    PlainText(BodyStream),
    /// Body is being decompressed as gzip.
    Gzip(FramedRead<GzipDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// Body is being decompressed as zlib-wrapped deflate.
    Deflate(FramedRead<ZlibDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// Body is being decompressed as Brotli.
    Brotli(FramedRead<BrotliDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// Body is being decompressed as Zstandard.
    Zstd(FramedRead<ZstdDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A coding was detected, but the decompressor is only built once the
    /// first body item (data, error, or end of stream) has been observed.
    Pending(Pending),
}
71
/// Future that waits for the first body item before choosing the concrete
/// [`Inner`] decoder.
struct Pending {
    /// The raw body, peekable so the first item can be inspected without consuming it.
    body: Peekable<BodyStream>,
    /// Which decompressor to construct once the body is known to contain data.
    type_: DecoderType,
}
77
78impl fmt::Debug for Decoder {
79 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80 f.debug_struct("Decoder").finish()
81 }
82}
83
84impl Decoder {
85 #[inline]
89 fn plain_text(
90 body: BoxedBody,
91 is_secure_scheme: bool,
92 content_length: Option<ContentLength>,
93 ) -> Decoder {
94 Decoder {
95 inner: Inner::PlainText(BodyStream::new(body, is_secure_scheme, content_length)),
96 }
97 }
98
99 #[inline]
103 fn pending(
104 body: BoxedBody,
105 type_: DecoderType,
106 is_secure_scheme: bool,
107 content_length: Option<ContentLength>,
108 ) -> Decoder {
109 Decoder {
110 inner: Inner::Pending(Pending {
111 body: BodyStream::new(body, is_secure_scheme, content_length).peekable(),
112 type_,
113 }),
114 }
115 }
116
117 pub fn detect(response: Response<BoxedBody>, is_secure_scheme: bool) -> Response<Decoder> {
124 let values = response
125 .headers()
126 .get_all(CONTENT_ENCODING)
127 .iter()
128 .chain(response.headers().get_all(TRANSFER_ENCODING).iter());
129 let decoder = values.fold(None, |acc, enc| {
130 acc.or_else(|| {
131 if enc == HeaderValue::from_static("gzip") {
132 Some(DecoderType::Gzip)
133 } else if enc == HeaderValue::from_static("br") {
134 Some(DecoderType::Brotli)
135 } else if enc == HeaderValue::from_static("deflate") {
136 Some(DecoderType::Deflate)
137 } else if enc == HeaderValue::from_static("zstd") {
138 Some(DecoderType::Zstd)
139 } else {
140 None
141 }
142 })
143 });
144 let content_length = response.headers().typed_get::<ContentLength>();
145 match decoder {
146 Some(type_) => {
147 response.map(|r| Decoder::pending(r, type_, is_secure_scheme, content_length))
148 },
149 None => response.map(|r| Decoder::plain_text(r, is_secure_scheme, content_length)),
150 }
151 }
152}
153
154impl Stream for Decoder {
155 type Item = Result<Bytes, io::Error>;
156
157 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
158 match self.inner {
160 Inner::Pending(ref mut future) => match futures_core::ready!(Pin::new(future).poll(cx))
161 {
162 Ok(inner) => {
163 self.inner = inner;
164 self.poll_next(cx)
165 },
166 Err(e) => Poll::Ready(Some(Err(e))),
167 },
168 Inner::PlainText(ref mut body) => Pin::new(body).poll_next(cx),
169 Inner::Gzip(ref mut decoder) => {
170 match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
171 Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
172 Some(Err(err)) => Poll::Ready(Some(Err(err))),
173 None => Poll::Ready(None),
174 }
175 },
176 Inner::Brotli(ref mut decoder) => {
177 match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
178 Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
179 Some(Err(err)) => Poll::Ready(Some(Err(err))),
180 None => Poll::Ready(None),
181 }
182 },
183 Inner::Deflate(ref mut decoder) => {
184 match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
185 Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
186 Some(Err(err)) => Poll::Ready(Some(Err(err))),
187 None => Poll::Ready(None),
188 }
189 },
190 Inner::Zstd(ref mut decoder) => {
191 match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
192 Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
193 Some(Err(err)) => Poll::Ready(Some(Err(err))),
194 None => Poll::Ready(None),
195 }
196 },
197 }
198 }
199}
200
201impl Future for Pending {
202 type Output = Result<Inner, io::Error>;
203
204 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205 match futures_core::ready!(Pin::new(&mut self.body).poll_peek(cx)) {
206 Some(Ok(_)) => {
207 },
209 Some(Err(_e)) => {
210 return Poll::Ready(Err(futures_core::ready!(
212 Pin::new(&mut self.body).poll_next(cx)
213 )
214 .expect("just peeked Some")
215 .unwrap_err()));
216 },
217 None => return Poll::Ready(Ok(Inner::PlainText(BodyStream::empty()))),
218 };
219
220 let body = std::mem::replace(&mut self.body, BodyStream::empty().peekable());
221
222 match self.type_ {
223 DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(FramedRead::with_capacity(
224 BrotliDecoder::new(StreamReader::new(body)),
225 BytesCodec::new(),
226 DECODER_BUFFER_SIZE,
227 )))),
228 DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(FramedRead::with_capacity(
229 GzipDecoder::new(StreamReader::new(body)),
230 BytesCodec::new(),
231 DECODER_BUFFER_SIZE,
232 )))),
233 DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(FramedRead::with_capacity(
234 ZlibDecoder::new(StreamReader::new(body)),
235 BytesCodec::new(),
236 DECODER_BUFFER_SIZE,
237 )))),
238 DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(FramedRead::with_capacity(
239 ZstdDecoder::new(StreamReader::new(body)),
240 BytesCodec::new(),
241 DECODER_BUFFER_SIZE,
242 )))),
243 }
244 }
245}
246
/// Adapts a [`BoxedBody`] into a stream of data chunks.
///
/// It counts the data bytes yielded so far. On a secure scheme, an
/// `UnexpectedEof` error arriving after exactly `Content-Length` bytes can
/// then be treated as a clean end of stream.
struct BodyStream {
    /// Underlying HTTP body.
    body: BoxedBody,
    /// Whether the response came over a secure scheme; the EOF leniency only applies then.
    is_secure_scheme: bool,
    /// The response's `Content-Length`, if it had one.
    content_length: Option<ContentLength>,
    /// Number of data bytes yielded so far.
    total_read: u64,
}
253
254impl BodyStream {
255 fn empty() -> Self {
256 BodyStream {
257 body: http_body_util::Empty::new()
258 .map_err(|_| unreachable!())
259 .boxed(),
260 is_secure_scheme: false,
261 content_length: None,
262 total_read: 0,
263 }
264 }
265
266 fn new(body: BoxedBody, is_secure_scheme: bool, content_length: Option<ContentLength>) -> Self {
267 BodyStream {
268 body,
269 is_secure_scheme,
270 content_length,
271 total_read: 0,
272 }
273 }
274}
275
276impl Stream for BodyStream {
277 type Item = Result<Bytes, io::Error>;
278
279 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
280 match futures_core::ready!(Pin::new(&mut self.body).poll_frame(cx)) {
281 Some(Ok(bytes)) => {
282 let Ok(bytes) = bytes.into_data() else {
283 return Poll::Ready(None);
284 };
285 self.total_read += bytes.len() as u64;
286 Poll::Ready(Some(Ok(bytes)))
287 },
288 Some(Err(err)) => {
289 let all_content_read = self.content_length.is_some_and(|c| c.0 == self.total_read);
296 if self.is_secure_scheme && all_content_read {
297 let source = err.source();
298 let is_unexpected_eof = source
299 .and_then(|e| e.downcast_ref::<io::Error>())
300 .is_some_and(|e| e.kind() == io::ErrorKind::UnexpectedEof);
301 if is_unexpected_eof {
302 return Poll::Ready(None);
303 }
304 }
305 Poll::Ready(Some(Err(io::Error::other(err))))
306 },
307 None => Poll::Ready(None),
308 }
309 }
310}