net/
decoder.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5//! Adapted from an implementation in reqwest.
6
7/*!
8A non-blocking response decoder.
9
10The decoder wraps a stream of bytes and produces a new stream of decompressed bytes.
11The decompressed bytes aren't guaranteed to align to the compressed ones.
12
13If the response is plaintext then no additional work is carried out.
14Bytes are just passed along.
15
16If the response is gzip, deflate or brotli then the bytes are decompressed.
17*/
18
19use std::error::Error;
20use std::fmt;
21use std::io::{self};
22use std::pin::Pin;
23
24use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder};
25use bytes::Bytes;
26use futures::stream::Peekable;
27use futures::task::{Context, Poll};
28use futures::{Future, Stream};
29use futures_util::StreamExt;
30use headers::{ContentLength, HeaderMapExt};
31use http_body_util::BodyExt;
32use hyper::Response;
33use hyper::body::Body;
34use hyper::header::{CONTENT_ENCODING, HeaderValue, TRANSFER_ENCODING};
35use tokio_util::codec::{BytesCodec, FramedRead};
36use tokio_util::io::StreamReader;
37
38use crate::connector::BoxedBody;
39
/// Capacity, in bytes, of the read buffer handed to each decompressing `FramedRead`.
pub const DECODER_BUFFER_SIZE: usize = 8 * 1024;
41
42/// A response decompressor over a non-blocking stream of bytes.
43///
44/// The inner decoder may be constructed asynchronously.
45pub struct Decoder {
46    inner: Inner,
47}
48
/// Compression scheme selected from the response headers.
#[derive(PartialEq)]
enum DecoderType {
    /// `gzip` content coding.
    Gzip,
    /// `deflate` content coding (decoded as zlib data).
    Deflate,
    /// `br` (Brotli) content coding.
    Brotli,
}
55
56enum Inner {
57    /// A `PlainText` decoder just returns the response content as is.
58    PlainText(BodyStream),
59    /// A `Gzip` decoder will uncompress the gzipped response content before returning it.
60    Gzip(FramedRead<GzipDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
61    /// A `Delfate` decoder will uncompress the inflated response content before returning it.
62    Deflate(FramedRead<ZlibDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
63    /// A `Brotli` decoder will uncompress the brotli-encoded response content before returning it.
64    Brotli(FramedRead<BrotliDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
65    /// A decoder that doesn't have a value yet.
66    Pending(Pending),
67}
68
69/// A future attempt to poll the response body for EOF so we know whether to use gzip or not.
70struct Pending {
71    body: Peekable<BodyStream>,
72    type_: DecoderType,
73}
74
75impl fmt::Debug for Decoder {
76    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77        f.debug_struct("Decoder").finish()
78    }
79}
80
81impl Decoder {
82    /// A plain text decoder.
83    ///
84    /// This decoder will emit the underlying bytes as-is.
85    #[inline]
86    fn plain_text(
87        body: BoxedBody,
88        is_secure_scheme: bool,
89        content_length: Option<ContentLength>,
90    ) -> Decoder {
91        Decoder {
92            inner: Inner::PlainText(BodyStream::new(body, is_secure_scheme, content_length)),
93        }
94    }
95
96    /// A pending decoder.
97    ///
98    /// This decoder will buffer and decompress bytes that are encoded in the expected format.
99    #[inline]
100    fn pending(
101        body: BoxedBody,
102        type_: DecoderType,
103        is_secure_scheme: bool,
104        content_length: Option<ContentLength>,
105    ) -> Decoder {
106        Decoder {
107            inner: Inner::Pending(Pending {
108                body: BodyStream::new(body, is_secure_scheme, content_length).peekable(),
109                type_,
110            }),
111        }
112    }
113
114    /// Constructs a Decoder from a hyper response.
115    ///
116    /// A decoder is just a wrapper around the hyper response that knows
117    /// how to decode the content body of the response.
118    ///
119    /// Uses the correct variant by inspecting the Content-Encoding header.
120    pub fn detect(response: Response<BoxedBody>, is_secure_scheme: bool) -> Response<Decoder> {
121        let values = response
122            .headers()
123            .get_all(CONTENT_ENCODING)
124            .iter()
125            .chain(response.headers().get_all(TRANSFER_ENCODING).iter());
126        let decoder = values.fold(None, |acc, enc| {
127            acc.or_else(|| {
128                if enc == HeaderValue::from_static("gzip") {
129                    Some(DecoderType::Gzip)
130                } else if enc == HeaderValue::from_static("br") {
131                    Some(DecoderType::Brotli)
132                } else if enc == HeaderValue::from_static("deflate") {
133                    Some(DecoderType::Deflate)
134                } else {
135                    None
136                }
137            })
138        });
139        let content_length = response.headers().typed_get::<ContentLength>();
140        match decoder {
141            Some(type_) => {
142                response.map(|r| Decoder::pending(r, type_, is_secure_scheme, content_length))
143            },
144            None => response.map(|r| Decoder::plain_text(r, is_secure_scheme, content_length)),
145        }
146    }
147}
148
149impl Stream for Decoder {
150    type Item = Result<Bytes, io::Error>;
151
152    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
153        // Do a read or poll for a pending decoder value.
154        match self.inner {
155            Inner::Pending(ref mut future) => match futures_core::ready!(Pin::new(future).poll(cx))
156            {
157                Ok(inner) => {
158                    self.inner = inner;
159                    self.poll_next(cx)
160                },
161                Err(e) => Poll::Ready(Some(Err(e))),
162            },
163            Inner::PlainText(ref mut body) => Pin::new(body).poll_next(cx),
164            Inner::Gzip(ref mut decoder) => {
165                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
166                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
167                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
168                    None => Poll::Ready(None),
169                }
170            },
171            Inner::Brotli(ref mut decoder) => {
172                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
173                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
174                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
175                    None => Poll::Ready(None),
176                }
177            },
178            Inner::Deflate(ref mut decoder) => {
179                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
180                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
181                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
182                    None => Poll::Ready(None),
183                }
184            },
185        }
186    }
187}
188
189impl Future for Pending {
190    type Output = Result<Inner, io::Error>;
191
192    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
193        match futures_core::ready!(Pin::new(&mut self.body).poll_peek(cx)) {
194            Some(Ok(_)) => {
195                // fallthrough
196            },
197            Some(Err(_e)) => {
198                // error was just a ref, so we need to really poll to move it
199                return Poll::Ready(Err(futures_core::ready!(
200                    Pin::new(&mut self.body).poll_next(cx)
201                )
202                .expect("just peeked Some")
203                .unwrap_err()));
204            },
205            None => return Poll::Ready(Ok(Inner::PlainText(BodyStream::empty()))),
206        };
207
208        let body = std::mem::replace(&mut self.body, BodyStream::empty().peekable());
209
210        match self.type_ {
211            DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(FramedRead::with_capacity(
212                BrotliDecoder::new(StreamReader::new(body)),
213                BytesCodec::new(),
214                DECODER_BUFFER_SIZE,
215            )))),
216            DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(FramedRead::with_capacity(
217                GzipDecoder::new(StreamReader::new(body)),
218                BytesCodec::new(),
219                DECODER_BUFFER_SIZE,
220            )))),
221            DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(FramedRead::with_capacity(
222                ZlibDecoder::new(StreamReader::new(body)),
223                BytesCodec::new(),
224                DECODER_BUFFER_SIZE,
225            )))),
226        }
227    }
228}
229
230struct BodyStream {
231    body: BoxedBody,
232    is_secure_scheme: bool,
233    content_length: Option<ContentLength>,
234    total_read: u64,
235}
236
237impl BodyStream {
238    fn empty() -> Self {
239        BodyStream {
240            body: http_body_util::Empty::new()
241                .map_err(|_| unreachable!())
242                .boxed(),
243            is_secure_scheme: false,
244            content_length: None,
245            total_read: 0,
246        }
247    }
248
249    fn new(body: BoxedBody, is_secure_scheme: bool, content_length: Option<ContentLength>) -> Self {
250        BodyStream {
251            body,
252            is_secure_scheme,
253            content_length,
254            total_read: 0,
255        }
256    }
257}
258
259impl Stream for BodyStream {
260    type Item = Result<Bytes, io::Error>;
261
262    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
263        match futures_core::ready!(Pin::new(&mut self.body).poll_frame(cx)) {
264            Some(Ok(bytes)) => {
265                let Ok(bytes) = bytes.into_data() else {
266                    return Poll::Ready(None);
267                };
268                self.total_read += bytes.len() as u64;
269                Poll::Ready(Some(Ok(bytes)))
270            },
271            Some(Err(err)) => {
272                // To prevent truncation attacks rustls treats close connection without a close_notify as
273                // an error of type std::io::Error with ErrorKind::UnexpectedEof.
274                // https://docs.rs/rustls/latest/rustls/manual/_03_howto/index.html#unexpected-eof
275                //
276                // The error can be safely ignored if we known that all content was received or is explicitly
277                // set in preferences.
278                let all_content_read = self.content_length.is_some_and(|c| c.0 == self.total_read);
279                if self.is_secure_scheme && all_content_read {
280                    let source = err.source();
281                    let is_unexpected_eof = source
282                        .and_then(|e| e.downcast_ref::<io::Error>())
283                        .is_some_and(|e| e.kind() == io::ErrorKind::UnexpectedEof);
284                    if is_unexpected_eof {
285                        return Poll::Ready(None);
286                    }
287                }
288                Poll::Ready(Some(Err(io::Error::other(err))))
289            },
290            None => Poll::Ready(None),
291        }
292    }
293}