net/
decoder.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5//! Adapted from an implementation in reqwest.
6
7/*!
8A non-blocking response decoder.
9
10The decoder wraps a stream of bytes and produces a new stream of decompressed bytes.
11The decompressed bytes aren't guaranteed to align to the compressed ones.
12
13If the response is plaintext then no additional work is carried out.
14Bytes are just passed along.
15
If the response is gzip, deflate, brotli or zstd then the bytes are decompressed.
17*/
18
19use std::error::Error;
20use std::fmt;
21use std::io::{self};
22use std::pin::Pin;
23
24use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder, ZstdDecoder};
25use bytes::Bytes;
26use futures::stream::Peekable;
27use futures::task::{Context, Poll};
28use futures::{Future, Stream};
29use futures_util::StreamExt;
30use headers::{ContentLength, HeaderMapExt};
31use http_body_util::BodyExt;
32use hyper::Response;
33use hyper::body::Body;
34use hyper::header::{CONTENT_ENCODING, HeaderValue, TRANSFER_ENCODING};
35use tokio_util::codec::{BytesCodec, FramedRead};
36use tokio_util::io::StreamReader;
37
38use crate::connector::BoxedBody;
39
/// Capacity handed to `FramedRead::with_capacity` for every decompressing
/// decoder constructed in this module.
pub const DECODER_BUFFER_SIZE: usize = 8192;
41
/// A response decompressor over a non-blocking stream of bytes.
///
/// The inner decoder may be constructed asynchronously: a decoder for a
/// compressed response starts out in the `Inner::Pending` state and is
/// replaced by the concrete decoder once polling the stream resolves it.
pub struct Decoder {
    /// Current decoding state; overwritten in place when a pending decoder resolves.
    inner: Inner,
}
48
/// The compression scheme selected for a response body.
///
/// Each variant names the decompressor that the `Pending` future builds
/// around the body.
#[derive(PartialEq)]
enum DecoderType {
    /// The body is decoded with a `GzipDecoder`.
    Gzip,
    /// The body is decoded with a `BrotliDecoder`.
    Brotli,
    /// The body is decoded with a `ZlibDecoder`.
    Deflate,
    /// The body is decoded with a `ZstdDecoder`.
    Zstd,
}
56
/// The decoding state held by a `Decoder`.
///
/// Every decompressing variant wraps the body in a `StreamReader`, feeds it
/// through the matching `async_compression` decoder and frames the output
/// with a `BytesCodec`.
enum Inner {
    /// A `PlainText` decoder just returns the response content as is.
    PlainText(BodyStream),
    /// A `Gzip` decoder will uncompress the gzipped response content before returning it.
    Gzip(FramedRead<GzipDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A `Deflate` decoder will uncompress the deflated response content before returning it.
    Deflate(FramedRead<ZlibDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A `Brotli` decoder will uncompress the brotli-encoded response content before returning it.
    Brotli(FramedRead<BrotliDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A `Zstd` decoder will uncompress the zstd-encoded response content before returning it.
    Zstd(FramedRead<ZstdDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A decoder that doesn't have a value yet.
    Pending(Pending),
}
71
/// A future that peeks at the response body to decide which decoder to build.
///
/// If the body ends without yielding anything, the result is an empty
/// plain-text stream. If the first item is an error, that error is returned.
/// Otherwise the decompressor selected by `type_` is constructed around the body.
struct Pending {
    /// The response body, made peekable so its first item can be inspected
    /// without being consumed.
    body: Peekable<BodyStream>,
    /// Which decompressor to construct once the body has yielded data.
    type_: DecoderType,
}
77
78impl fmt::Debug for Decoder {
79    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80        f.debug_struct("Decoder").finish()
81    }
82}
83
84impl Decoder {
85    /// A plain text decoder.
86    ///
87    /// This decoder will emit the underlying bytes as-is.
88    #[inline]
89    fn plain_text(
90        body: BoxedBody,
91        is_secure_scheme: bool,
92        content_length: Option<ContentLength>,
93    ) -> Decoder {
94        Decoder {
95            inner: Inner::PlainText(BodyStream::new(body, is_secure_scheme, content_length)),
96        }
97    }
98
99    /// A pending decoder.
100    ///
101    /// This decoder will buffer and decompress bytes that are encoded in the expected format.
102    #[inline]
103    fn pending(
104        body: BoxedBody,
105        type_: DecoderType,
106        is_secure_scheme: bool,
107        content_length: Option<ContentLength>,
108    ) -> Decoder {
109        Decoder {
110            inner: Inner::Pending(Pending {
111                body: BodyStream::new(body, is_secure_scheme, content_length).peekable(),
112                type_,
113            }),
114        }
115    }
116
117    /// Constructs a Decoder from a hyper response.
118    ///
119    /// A decoder is just a wrapper around the hyper response that knows
120    /// how to decode the content body of the response.
121    ///
122    /// Uses the correct variant by inspecting the Content-Encoding header.
123    pub fn detect(response: Response<BoxedBody>, is_secure_scheme: bool) -> Response<Decoder> {
124        let values = response
125            .headers()
126            .get_all(CONTENT_ENCODING)
127            .iter()
128            .chain(response.headers().get_all(TRANSFER_ENCODING).iter());
129        let decoder = values.fold(None, |acc, enc| {
130            acc.or_else(|| {
131                if enc == HeaderValue::from_static("gzip") {
132                    Some(DecoderType::Gzip)
133                } else if enc == HeaderValue::from_static("br") {
134                    Some(DecoderType::Brotli)
135                } else if enc == HeaderValue::from_static("deflate") {
136                    Some(DecoderType::Deflate)
137                } else if enc == HeaderValue::from_static("zstd") {
138                    Some(DecoderType::Zstd)
139                } else {
140                    None
141                }
142            })
143        });
144        let content_length = response.headers().typed_get::<ContentLength>();
145        match decoder {
146            Some(type_) => {
147                response.map(|r| Decoder::pending(r, type_, is_secure_scheme, content_length))
148            },
149            None => response.map(|r| Decoder::plain_text(r, is_secure_scheme, content_length)),
150        }
151    }
152}
153
154impl Stream for Decoder {
155    type Item = Result<Bytes, io::Error>;
156
157    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
158        // Do a read or poll for a pending decoder value.
159        match self.inner {
160            Inner::Pending(ref mut future) => match futures_core::ready!(Pin::new(future).poll(cx))
161            {
162                Ok(inner) => {
163                    self.inner = inner;
164                    self.poll_next(cx)
165                },
166                Err(e) => Poll::Ready(Some(Err(e))),
167            },
168            Inner::PlainText(ref mut body) => Pin::new(body).poll_next(cx),
169            Inner::Gzip(ref mut decoder) => {
170                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
171                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
172                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
173                    None => Poll::Ready(None),
174                }
175            },
176            Inner::Brotli(ref mut decoder) => {
177                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
178                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
179                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
180                    None => Poll::Ready(None),
181                }
182            },
183            Inner::Deflate(ref mut decoder) => {
184                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
185                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
186                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
187                    None => Poll::Ready(None),
188                }
189            },
190            Inner::Zstd(ref mut decoder) => {
191                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
192                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
193                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
194                    None => Poll::Ready(None),
195                }
196            },
197        }
198    }
199}
200
201impl Future for Pending {
202    type Output = Result<Inner, io::Error>;
203
204    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205        match futures_core::ready!(Pin::new(&mut self.body).poll_peek(cx)) {
206            Some(Ok(_)) => {
207                // fallthrough
208            },
209            Some(Err(_e)) => {
210                // error was just a ref, so we need to really poll to move it
211                return Poll::Ready(Err(futures_core::ready!(
212                    Pin::new(&mut self.body).poll_next(cx)
213                )
214                .expect("just peeked Some")
215                .unwrap_err()));
216            },
217            None => return Poll::Ready(Ok(Inner::PlainText(BodyStream::empty()))),
218        };
219
220        let body = std::mem::replace(&mut self.body, BodyStream::empty().peekable());
221
222        match self.type_ {
223            DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(FramedRead::with_capacity(
224                BrotliDecoder::new(StreamReader::new(body)),
225                BytesCodec::new(),
226                DECODER_BUFFER_SIZE,
227            )))),
228            DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(FramedRead::with_capacity(
229                GzipDecoder::new(StreamReader::new(body)),
230                BytesCodec::new(),
231                DECODER_BUFFER_SIZE,
232            )))),
233            DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(FramedRead::with_capacity(
234                ZlibDecoder::new(StreamReader::new(body)),
235                BytesCodec::new(),
236                DECODER_BUFFER_SIZE,
237            )))),
238            DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(FramedRead::with_capacity(
239                ZstdDecoder::new(StreamReader::new(body)),
240                BytesCodec::new(),
241                DECODER_BUFFER_SIZE,
242            )))),
243        }
244    }
245}
246
/// Adapts a `BoxedBody` into a `Stream` of byte chunks, keeping count of how
/// much data has been yielded.
struct BodyStream {
    /// The underlying response body whose frames are polled.
    body: BoxedBody,
    /// Flag supplied by the caller; an unexpected-EOF error from the body is
    /// only tolerated when it is set.
    is_secure_scheme: bool,
    /// The content length supplied at construction, if any; compared against
    /// `total_read` when the body reports an error.
    content_length: Option<ContentLength>,
    /// Running total of the lengths of the data chunks yielded so far.
    total_read: u64,
}
253
254impl BodyStream {
255    fn empty() -> Self {
256        BodyStream {
257            body: http_body_util::Empty::new()
258                .map_err(|_| unreachable!())
259                .boxed(),
260            is_secure_scheme: false,
261            content_length: None,
262            total_read: 0,
263        }
264    }
265
266    fn new(body: BoxedBody, is_secure_scheme: bool, content_length: Option<ContentLength>) -> Self {
267        BodyStream {
268            body,
269            is_secure_scheme,
270            content_length,
271            total_read: 0,
272        }
273    }
274}
275
276impl Stream for BodyStream {
277    type Item = Result<Bytes, io::Error>;
278
279    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
280        match futures_core::ready!(Pin::new(&mut self.body).poll_frame(cx)) {
281            Some(Ok(bytes)) => {
282                let Ok(bytes) = bytes.into_data() else {
283                    return Poll::Ready(None);
284                };
285                self.total_read += bytes.len() as u64;
286                Poll::Ready(Some(Ok(bytes)))
287            },
288            Some(Err(err)) => {
289                // To prevent truncation attacks rustls treats close connection without a close_notify as
290                // an error of type std::io::Error with ErrorKind::UnexpectedEof.
291                // https://docs.rs/rustls/latest/rustls/manual/_03_howto/index.html#unexpected-eof
292                //
293                // The error can be safely ignored if we known that all content was received or is explicitly
294                // set in preferences.
295                let all_content_read = self.content_length.is_some_and(|c| c.0 == self.total_read);
296                if self.is_secure_scheme && all_content_read {
297                    let source = err.source();
298                    let is_unexpected_eof = source
299                        .and_then(|e| e.downcast_ref::<io::Error>())
300                        .is_some_and(|e| e.kind() == io::ErrorKind::UnexpectedEof);
301                    if is_unexpected_eof {
302                        return Poll::Ready(None);
303                    }
304                }
305                Poll::Ready(Some(Err(io::Error::other(err))))
306            },
307            None => Poll::Ready(None),
308        }
309    }
310}