net/
decoder.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */

//! Adapted from an implementation in reqwest.

/*!
A non-blocking response decoder.

The decoder wraps a stream of bytes and produces a new stream of decompressed bytes.
The decompressed bytes aren't guaranteed to align to the compressed ones.

If the response is plaintext then no additional work is carried out.
Bytes are just passed along.

If the response is gzip, deflate or brotli then the bytes are decompressed.
*/

use std::error::Error;
use std::fmt;
use std::io::{self};
use std::pin::Pin;

use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder};
use bytes::Bytes;
use futures::stream::Peekable;
use futures::task::{Context, Poll};
use futures::{Future, Stream};
use futures_util::StreamExt;
use headers::{ContentLength, HeaderMapExt};
use hyper::header::{HeaderValue, CONTENT_ENCODING, TRANSFER_ENCODING};
use hyper::{Body, Response};
use servo_config::pref;
use tokio_util::codec::{BytesCodec, FramedRead};
use tokio_util::io::StreamReader;

/// Capacity, in bytes, of the buffer each decompressing `FramedRead` is created with
/// (passed to `FramedRead::with_capacity` when a pending decoder resolves).
pub const DECODER_BUFFER_SIZE: usize = 8192;

/// A response decompressor over a non-blocking stream of bytes.
///
/// The inner decoder may be constructed asynchronously: a body with a recognised encoding
/// starts out in a pending state and is switched to the matching decompressor the first
/// time the decoder is polled.
///
/// Build one with `Decoder::detect`, then consume it as a `Stream` of `Bytes` chunks.
pub struct Decoder {
    /// Current decoding state; replaced in place once a pending decoder resolves.
    inner: Inner,
}

/// The decompression scheme selected from a response's encoding headers.
///
/// This is a plain fieldless selector, so it is `Copy` and can be compared and debug-printed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DecoderType {
    /// The `gzip` coding.
    Gzip,
    /// The `br` (Brotli) coding.
    Brotli,
    /// The `deflate` coding.
    Deflate,
}

/// The current state of a `Decoder`.
enum Inner {
    /// A `PlainText` decoder just returns the response content as is.
    PlainText(BodyStream),
    /// A `Gzip` decoder decompresses gzip-encoded response content before returning it.
    Gzip(FramedRead<GzipDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A `Deflate` decoder decompresses "deflate"-encoded response content before returning
    /// it. The body is decoded as zlib-wrapped data (`ZlibDecoder`), which is the format the
    /// HTTP "deflate" content coding denotes.
    Deflate(FramedRead<ZlibDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A `Brotli` decoder decompresses brotli-encoded response content before returning it.
    Brotli(FramedRead<BrotliDecoder<StreamReader<Peekable<BodyStream>, Bytes>>, BytesCodec>),
    /// A body whose decoder has not been constructed yet; once the `Pending` future
    /// resolves, this variant is replaced by one of the others.
    Pending(Pending),
}

/// A future that peeks at the first chunk of an encoded response body and then resolves to
/// the decompressor selected by `type_`, or to an empty plain-text stream if the body turns
/// out to contain no data.
struct Pending {
    /// The not-yet-decoded body; peekable so its first chunk can be inspected without losing it.
    body: Peekable<BodyStream>,
    /// Which decompressor to build once the body is known to be non-empty.
    type_: DecoderType,
}

impl fmt::Debug for Decoder {
    /// Writes the type name only; no fields of the wrapped stream state are included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Decoder");
        builder.finish()
    }
}

impl Decoder {
    /// A plain text decoder.
    ///
    /// This decoder will emit the underlying bytes as-is.
    #[inline]
    fn plain_text(
        body: Body,
        is_secure_scheme: bool,
        content_length: Option<ContentLength>,
    ) -> Decoder {
        Decoder {
            inner: Inner::PlainText(BodyStream::new(body, is_secure_scheme, content_length)),
        }
    }

    /// A pending decoder.
    ///
    /// This decoder will buffer and decompress bytes that are encoded in the expected format.
    #[inline]
    fn pending(
        body: Body,
        type_: DecoderType,
        is_secure_scheme: bool,
        content_length: Option<ContentLength>,
    ) -> Decoder {
        Decoder {
            inner: Inner::Pending(Pending {
                body: BodyStream::new(body, is_secure_scheme, content_length).peekable(),
                type_,
            }),
        }
    }

    /// Constructs a Decoder from a hyper response.
    ///
    /// A decoder is just a wrapper around the hyper response that knows
    /// how to decode the content body of the response.
    ///
    /// Uses the correct variant by inspecting the Content-Encoding header.
    pub fn detect(response: Response<Body>, is_secure_scheme: bool) -> Response<Decoder> {
        let values = response
            .headers()
            .get_all(CONTENT_ENCODING)
            .iter()
            .chain(response.headers().get_all(TRANSFER_ENCODING).iter());
        let decoder = values.fold(None, |acc, enc| {
            acc.or_else(|| {
                if enc == HeaderValue::from_static("gzip") {
                    Some(DecoderType::Gzip)
                } else if enc == HeaderValue::from_static("br") {
                    Some(DecoderType::Brotli)
                } else if enc == HeaderValue::from_static("deflate") {
                    Some(DecoderType::Deflate)
                } else {
                    None
                }
            })
        });
        let content_length = response.headers().typed_get::<ContentLength>();
        match decoder {
            Some(type_) => {
                response.map(|r| Decoder::pending(r, type_, is_secure_scheme, content_length))
            },
            None => response.map(|r| Decoder::plain_text(r, is_secure_scheme, content_length)),
        }
    }
}

impl Stream for Decoder {
    type Item = Result<Bytes, io::Error>;

    /// Produces the next chunk of response body bytes.
    ///
    /// A `Pending` decoder is driven first. When it resolves successfully, the resolved inner
    /// decoder is stored in place of it and polled straight away; if it fails, the error is
    /// yielded as a stream item. Chunks from the decompressing decoders are frozen into
    /// `Bytes` before being returned.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let resolved = match &mut self.inner {
            Inner::PlainText(body) => return Pin::new(body).poll_next(cx),
            Inner::Gzip(decoder) => {
                return Pin::new(decoder).poll_next(cx).map_ok(|chunk| chunk.freeze());
            },
            Inner::Brotli(decoder) => {
                return Pin::new(decoder).poll_next(cx).map_ok(|chunk| chunk.freeze());
            },
            Inner::Deflate(decoder) => {
                return Pin::new(decoder).poll_next(cx).map_ok(|chunk| chunk.freeze());
            },
            Inner::Pending(pending) => futures_core::ready!(Pin::new(pending).poll(cx)),
        };

        match resolved {
            Ok(inner) => {
                self.inner = inner;
                self.poll_next(cx)
            },
            Err(err) => Poll::Ready(Some(Err(err))),
        }
    }
}

impl Future for Pending {
    type Output = Result<Inner, io::Error>;

    /// Peeks at the first chunk of the body to decide what the decoder becomes.
    ///
    /// * An empty body resolves to an empty `PlainText` stream, so no decompressor is built.
    /// * If the first chunk is an error, that error is taken out of the stream and returned.
    /// * Otherwise the body is moved out of `self` and wrapped in the decompressor selected by
    ///   `type_`, using a `DECODER_BUFFER_SIZE`-byte buffer.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let first_is_error = match futures_core::ready!(Pin::new(&mut self.body).poll_peek(cx)) {
            None => return Poll::Ready(Ok(Inner::PlainText(BodyStream::empty()))),
            Some(first) => first.is_err(),
        };

        if first_is_error {
            // `poll_peek` only lends a reference to the error. Polling the stream hands back
            // the owned value, which the peek has already buffered.
            let next = futures_core::ready!(Pin::new(&mut self.body).poll_next(cx));
            let err = next.expect("just peeked Some").unwrap_err();
            return Poll::Ready(Err(err));
        }

        let body = std::mem::replace(&mut self.body, BodyStream::empty().peekable());
        let reader = StreamReader::new(body);
        let inner = match self.type_ {
            DecoderType::Gzip => Inner::Gzip(FramedRead::with_capacity(
                GzipDecoder::new(reader),
                BytesCodec::new(),
                DECODER_BUFFER_SIZE,
            )),
            DecoderType::Brotli => Inner::Brotli(FramedRead::with_capacity(
                BrotliDecoder::new(reader),
                BytesCodec::new(),
                DECODER_BUFFER_SIZE,
            )),
            DecoderType::Deflate => Inner::Deflate(FramedRead::with_capacity(
                ZlibDecoder::new(reader),
                BytesCodec::new(),
                DECODER_BUFFER_SIZE,
            )),
        };
        Poll::Ready(Ok(inner))
    }
}

/// A `Stream` of raw response-body chunks read from a hyper `Body`, counting the bytes it
/// yields.
///
/// The byte count and the `Content-Length` let a TLS unexpected-EOF error that arrives after
/// the whole body has been received be treated as a normal end of stream. A preference can
/// also allow this.
struct BodyStream {
    /// The underlying hyper response body.
    body: Body,
    /// Whether the response came over a secure scheme; unexpected-EOF errors are only ever
    /// ignored when this is set.
    is_secure_scheme: bool,
    /// The response's `Content-Length`, when known.
    content_length: Option<ContentLength>,
    /// Total number of bytes yielded from `body` so far.
    total_read: u64,
}

impl BodyStream {
    /// A stream over an empty body: non-secure scheme, no known length, nothing read yet.
    fn empty() -> Self {
        Self::new(Body::empty(), false, None)
    }

    /// Wraps `body`, recording the scheme security and expected length; the read counter
    /// starts at zero.
    fn new(body: Body, is_secure_scheme: bool, content_length: Option<ContentLength>) -> Self {
        Self {
            total_read: 0,
            content_length,
            is_secure_scheme,
            body,
        }
    }
}

impl Stream for BodyStream {
    type Item = Result<Bytes, io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        match futures_core::ready!(Pin::new(&mut self.body).poll_next(cx)) {
            Some(Ok(bytes)) => {
                self.total_read += bytes.len() as u64;
                Poll::Ready(Some(Ok(bytes)))
            },
            Some(Err(err)) => {
                // To prevent truncation attacks rustls treats close connection without a close_notify as
                // an error of type std::io::Error with ErrorKind::UnexpectedEof.
                // https://docs.rs/rustls/latest/rustls/manual/_03_howto/index.html#unexpected-eof
                //
                // The error can be safely ignored if we known that all content was received or is explicitly
                // set in preferences.
                let all_content_read = self
                    .content_length
                    .map_or(false, |c| c.0 == self.total_read);
                if self.is_secure_scheme &&
                    (all_content_read || pref!(network.tls.ignore_unexpected_eof))
                {
                    let source = err.source();
                    let is_unexpected_eof = source
                        .and_then(|e| e.downcast_ref::<io::Error>())
                        .map_or(false, |e| e.kind() == io::ErrorKind::UnexpectedEof);
                    if is_unexpected_eof {
                        return Poll::Ready(None);
                    }
                }
                Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, err))))
            },
            None => Poll::Ready(None),
        }
    }
}