1use std::error::Error;
20use std::fmt;
21use std::io::{self};
22use std::pin::Pin;
23
24use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder, ZstdDecoder};
25use bytes::Bytes;
26use futures::stream::Peekable;
27use futures::task::{Context, Poll};
28use futures::{Future, Stream};
29use futures_util::StreamExt;
30use headers::{ContentLength, HeaderMapExt};
31use http_body_util::BodyExt;
32use hyper::Response;
33use hyper::body::Body;
34use hyper::header::{CONTENT_ENCODING, HeaderValue, TRANSFER_ENCODING};
35use tokio_util::codec::{BytesCodec, FramedRead};
36use tokio_util::io::StreamReader;
37
38use crate::connector::BoxedBody;
39
/// Capacity, in bytes, of the buffer given to each decompressing
/// `FramedRead` (8 KiB).
pub const DECODER_BUFFER_SIZE: usize = 8 * 1024;
41
/// Stream adapter over an HTTP response body.
///
/// Built by [`Decoder::detect`]. The body is either passed through unchanged
/// or decompressed as gzip, deflate (read with a zlib decoder), brotli or zstd.
pub struct Decoder {
    /// Current decoding state. It starts as `PlainText` or `Pending`; a
    /// `Pending` state is replaced by a concrete decoder on first poll.
    inner: Inner,
}
48
/// Compression scheme to undo, as selected from the response's
/// `Content-Encoding` / `Transfer-Encoding` headers.
///
/// A small fieldless enum, so it is cheap to copy and useful to print.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DecoderType {
    /// `gzip` coding.
    Gzip,
    /// `br` coding.
    Brotli,
    /// `deflate` coding (decoded as a zlib stream).
    Deflate,
    /// `zstd` coding.
    Zstd,
}
56
/// Internal state of a [`Decoder`].
enum Inner {
    /// Body chunks are forwarded unchanged.
    PlainText(BodyStream),
    /// Body read through a gzip decompressor, re-chunked by `FramedRead`.
    Gzip(FramedRead<GzipDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// Body read through a zlib decompressor (used for the `deflate` coding).
    Deflate(FramedRead<ZlibDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// Body read through a brotli decompressor.
    Brotli(FramedRead<BrotliDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// Body read through a zstd decompressor.
    Zstd(FramedRead<ZstdDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A compressed coding was detected, but the concrete decoder is only
    /// built once the body yields its first item. Polling this future
    /// resolves it into one of the other variants.
    Pending(Pending),
}
71
/// A body whose decompressor has not been constructed yet.
///
/// Polled as a `Future`, it waits for the body's first item and then
/// produces the [`Inner`] state to use from then on.
struct Pending {
    /// The raw body. It is peekable so the first item can be inspected
    /// without being consumed.
    body: Peekable<BodyStream>,
    /// Which decompressor to build once the body yields data.
    type_: DecoderType,
}
77
78impl fmt::Debug for Decoder {
79    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80        f.debug_struct("Decoder").finish()
81    }
82}
83
84impl Decoder {
85    #[inline]
89    fn plain_text(
90        body: BoxedBody,
91        is_secure_scheme: bool,
92        content_length: Option<ContentLength>,
93    ) -> Decoder {
94        Decoder {
95            inner: Inner::PlainText(BodyStream::new(body, is_secure_scheme, content_length)),
96        }
97    }
98
99    #[inline]
103    fn pending(
104        body: BoxedBody,
105        type_: DecoderType,
106        is_secure_scheme: bool,
107        content_length: Option<ContentLength>,
108    ) -> Decoder {
109        Decoder {
110            inner: Inner::Pending(Pending {
111                body: BodyStream::new(body, is_secure_scheme, content_length).peekable(),
112                type_,
113            }),
114        }
115    }
116
117    pub fn detect(response: Response<BoxedBody>, is_secure_scheme: bool) -> Response<Decoder> {
124        let values = response
125            .headers()
126            .get_all(CONTENT_ENCODING)
127            .iter()
128            .chain(response.headers().get_all(TRANSFER_ENCODING).iter());
129        let decoder = values.fold(None, |acc, enc| {
130            acc.or_else(|| {
131                if enc == HeaderValue::from_static("gzip") {
132                    Some(DecoderType::Gzip)
133                } else if enc == HeaderValue::from_static("br") {
134                    Some(DecoderType::Brotli)
135                } else if enc == HeaderValue::from_static("deflate") {
136                    Some(DecoderType::Deflate)
137                } else if enc == HeaderValue::from_static("zstd") {
138                    Some(DecoderType::Zstd)
139                } else {
140                    None
141                }
142            })
143        });
144        let content_length = response.headers().typed_get::<ContentLength>();
145        match decoder {
146            Some(type_) => {
147                response.map(|r| Decoder::pending(r, type_, is_secure_scheme, content_length))
148            },
149            None => response.map(|r| Decoder::plain_text(r, is_secure_scheme, content_length)),
150        }
151    }
152}
153
154impl Stream for Decoder {
155    type Item = Result<Bytes, io::Error>;
156
157    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
158        match self.inner {
160            Inner::Pending(ref mut future) => match futures_core::ready!(Pin::new(future).poll(cx))
161            {
162                Ok(inner) => {
163                    self.inner = inner;
164                    self.poll_next(cx)
165                },
166                Err(e) => Poll::Ready(Some(Err(e))),
167            },
168            Inner::PlainText(ref mut body) => Pin::new(body).poll_next(cx),
169            Inner::Gzip(ref mut decoder) => {
170                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
171                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
172                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
173                    None => Poll::Ready(None),
174                }
175            },
176            Inner::Brotli(ref mut decoder) => {
177                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
178                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
179                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
180                    None => Poll::Ready(None),
181                }
182            },
183            Inner::Deflate(ref mut decoder) => {
184                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
185                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
186                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
187                    None => Poll::Ready(None),
188                }
189            },
190            Inner::Zstd(ref mut decoder) => {
191                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
192                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
193                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
194                    None => Poll::Ready(None),
195                }
196            },
197        }
198    }
199}
200
201impl Future for Pending {
202    type Output = Result<Inner, io::Error>;
203
204    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205        match futures_core::ready!(Pin::new(&mut self.body).poll_peek(cx)) {
206            Some(Ok(_)) => {
207                },
209            Some(Err(_e)) => {
210                return Poll::Ready(Err(futures_core::ready!(
212                    Pin::new(&mut self.body).poll_next(cx)
213                )
214                .expect("just peeked Some")
215                .unwrap_err()));
216            },
217            None => return Poll::Ready(Ok(Inner::PlainText(BodyStream::empty()))),
218        };
219
220        let body = std::mem::replace(&mut self.body, BodyStream::empty().peekable());
221
222        match self.type_ {
223            DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(FramedRead::with_capacity(
224                BrotliDecoder::new(StreamReader::new(body)),
225                BytesCodec::new(),
226                DECODER_BUFFER_SIZE,
227            )))),
228            DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(FramedRead::with_capacity(
229                GzipDecoder::new(StreamReader::new(body)),
230                BytesCodec::new(),
231                DECODER_BUFFER_SIZE,
232            )))),
233            DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(FramedRead::with_capacity(
234                ZlibDecoder::new(StreamReader::new(body)),
235                BytesCodec::new(),
236                DECODER_BUFFER_SIZE,
237            )))),
238            DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(FramedRead::with_capacity(
239                ZstdDecoder::new(StreamReader::new(body)),
240                BytesCodec::new(),
241                DECODER_BUFFER_SIZE,
242            )))),
243        }
244    }
245}
246
/// Adapts a `BoxedBody` into a `Stream` of data chunks, and counts how many
/// data bytes it has yielded.
struct BodyStream {
    /// The underlying HTTP body.
    body: BoxedBody,
    /// Whether the response came over a secure scheme. Together with
    /// `content_length`, this decides whether an `UnexpectedEof` error may be
    /// treated as a normal end of stream.
    is_secure_scheme: bool,
    /// The response's declared `Content-Length`, if any.
    content_length: Option<ContentLength>,
    /// Number of data bytes yielded so far.
    total_read: u64,
}
253
254impl BodyStream {
255    fn empty() -> Self {
256        BodyStream {
257            body: http_body_util::Empty::new()
258                .map_err(|_| unreachable!())
259                .boxed(),
260            is_secure_scheme: false,
261            content_length: None,
262            total_read: 0,
263        }
264    }
265
266    fn new(body: BoxedBody, is_secure_scheme: bool, content_length: Option<ContentLength>) -> Self {
267        BodyStream {
268            body,
269            is_secure_scheme,
270            content_length,
271            total_read: 0,
272        }
273    }
274}
275
276impl Stream for BodyStream {
277    type Item = Result<Bytes, io::Error>;
278
279    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
280        match futures_core::ready!(Pin::new(&mut self.body).poll_frame(cx)) {
281            Some(Ok(bytes)) => {
282                let Ok(bytes) = bytes.into_data() else {
283                    return Poll::Ready(None);
284                };
285                self.total_read += bytes.len() as u64;
286                Poll::Ready(Some(Ok(bytes)))
287            },
288            Some(Err(err)) => {
289                let all_content_read = self.content_length.is_some_and(|c| c.0 == self.total_read);
296                if self.is_secure_scheme && all_content_read {
297                    let source = err.source();
298                    let is_unexpected_eof = source
299                        .and_then(|e| e.downcast_ref::<io::Error>())
300                        .is_some_and(|e| e.kind() == io::ErrorKind::UnexpectedEof);
301                    if is_unexpected_eof {
302                        return Poll::Ready(None);
303                    }
304                }
305                Poll::Ready(Some(Err(io::Error::other(err))))
306            },
307            None => Poll::Ready(None),
308        }
309    }
310}