use std::error::Error;
use std::fmt;
use std::io::{self};
use std::pin::Pin;
use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder};
use bytes::Bytes;
use futures::stream::Peekable;
use futures::task::{Context, Poll};
use futures::{Future, Stream};
use futures_util::StreamExt;
use headers::{ContentLength, HeaderMapExt};
use hyper::header::{HeaderValue, CONTENT_ENCODING, TRANSFER_ENCODING};
use hyper::{Body, Response};
use servo_config::pref;
use tokio_util::codec::{BytesCodec, FramedRead};
use tokio_util::io::StreamReader;
/// Initial capacity, in bytes, of the read buffer each decompressing `FramedRead`
/// is created with.
pub const DECODER_BUFFER_SIZE: usize = 8 * 1024;
/// A stream of response-body chunks that undoes at most one gzip, brotli or deflate
/// coding, or passes the body through unchanged.
///
/// Built with `Decoder::detect`; yields `Result<Bytes, io::Error>` items.
pub struct Decoder {
    /// Current decoding state: pass-through, waiting for the first chunk, or a
    /// concrete decompressor.
    inner: Inner,
}
/// Which decompressor a [`Decoder`] should build once its body starts arriving.
///
/// This is a small field-less tag, so it derives `Copy`/`Eq`/`Debug` alongside
/// `PartialEq`. That makes it cheap to pass around and printable in diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DecoderType {
    /// `gzip` coding, undone with `GzipDecoder`.
    Gzip,
    /// `br` coding, undone with `BrotliDecoder`.
    Brotli,
    /// `deflate` coding, undone with `ZlibDecoder`.
    Deflate,
}
/// Reader over the raw, still-encoded body that every decompressor pulls bytes from.
type EncodedBodyReader = StreamReader<Peekable<BodyStream>, Bytes>;

/// Decoding state held by a [`Decoder`].
enum Inner {
    /// A coding was detected, but the first body chunk has not been seen yet.
    Pending(Pending),
    /// No supported coding: body chunks are forwarded unchanged.
    PlainText(BodyStream),
    /// Body decompressed through `GzipDecoder`.
    Gzip(FramedRead<GzipDecoder<EncodedBodyReader>, BytesCodec>),
    /// Body decompressed through `BrotliDecoder`.
    Brotli(FramedRead<BrotliDecoder<EncodedBodyReader>, BytesCodec>),
    /// Body decompressed through `ZlibDecoder`.
    Deflate(FramedRead<ZlibDecoder<EncodedBodyReader>, BytesCodec>),
}
/// A decoder that knows which coding to undo but has not yet seen any body data.
///
/// As a `Future` it resolves into the concrete [`Inner`] state once the first item of
/// `body` can be peeked.
struct Pending {
    /// Decompressor to build once data is available.
    type_: DecoderType,
    /// The still-encoded body. It is peekable so the first item can be inspected
    /// without consuming it.
    body: Peekable<BodyStream>,
}
impl fmt::Debug for Decoder {
    /// Writes just the type name, `Decoder`; the wrapped body and decoding state
    /// are not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Decoder")
    }
}
impl Decoder {
    /// Wraps `body` so that its chunks are forwarded unchanged.
    #[inline]
    fn plain_text(
        body: Body,
        is_secure_scheme: bool,
        content_length: Option<ContentLength>,
    ) -> Decoder {
        Decoder {
            inner: Inner::PlainText(BodyStream::new(body, is_secure_scheme, content_length)),
        }
    }

    /// Wraps `body` in a decoder that builds the `type_` decompressor once the first
    /// body item has been peeked.
    #[inline]
    fn pending(
        body: Body,
        type_: DecoderType,
        is_secure_scheme: bool,
        content_length: Option<ContentLength>,
    ) -> Decoder {
        Decoder {
            inner: Inner::Pending(Pending {
                body: BodyStream::new(body, is_secure_scheme, content_length).peekable(),
                type_,
            }),
        }
    }

    /// Maps a single content-coding / transfer-coding token to the decoder that undoes it.
    ///
    /// Surrounding whitespace is ignored and names are compared ASCII case-insensitively,
    /// because coding names are case-insensitive (RFC 9110 §8.4.1). `x-gzip` is accepted
    /// as an alias for `gzip` (RFC 9110 §8.4.1.3). Every other token, e.g. `identity` or
    /// `chunked`, yields `None`.
    fn decoder_type_for_coding(coding: &str) -> Option<DecoderType> {
        let coding = coding.trim();
        if coding.eq_ignore_ascii_case("gzip") || coding.eq_ignore_ascii_case("x-gzip") {
            Some(DecoderType::Gzip)
        } else if coding.eq_ignore_ascii_case("br") {
            Some(DecoderType::Brotli)
        } else if coding.eq_ignore_ascii_case("deflate") {
            Some(DecoderType::Deflate)
        } else {
            None
        }
    }

    /// Wraps the body of `response` in a [`Decoder`] chosen from its headers.
    ///
    /// The function walks every `Content-Encoding` value and then every
    /// `Transfer-Encoding` value, splitting each on commas so that lists such as
    /// `gzip, chunked` are understood. The first coding this module can undo (gzip,
    /// brotli or deflate) selects the decompressor. If none is found, the body is
    /// passed through unchanged.
    ///
    /// Only one coding layer is ever removed.
    /// NOTE(review): multi-layer codings (e.g. `deflate, gzip`) are not fully decoded.
    ///
    /// `is_secure_scheme` and the response's `Content-Length` are forwarded to the
    /// underlying body stream. It uses them to decide whether an unexpected EOF on a
    /// secure connection may be treated as a normal end of body.
    pub fn detect(response: Response<Body>, is_secure_scheme: bool) -> Response<Decoder> {
        let headers = response.headers();
        let decoder = headers
            .get_all(CONTENT_ENCODING)
            .iter()
            .chain(headers.get_all(TRANSFER_ENCODING).iter())
            // A value that is not visible ASCII cannot name a coding we support.
            .map(HeaderValue::to_str)
            .filter_map(Result::ok)
            // One header line may carry a comma-separated list of codings.
            .flat_map(|list| list.split(','))
            .find_map(Self::decoder_type_for_coding);
        let content_length = headers.typed_get::<ContentLength>();
        match decoder {
            Some(type_) => {
                response.map(|r| Decoder::pending(r, type_, is_secure_scheme, content_length))
            },
            None => response.map(|r| Decoder::plain_text(r, is_secure_scheme, content_length)),
        }
    }
}
impl Stream for Decoder {
    type Item = Result<Bytes, io::Error>;

    /// Yields the next chunk of decoded body data.
    ///
    /// While `Pending`, this first drives the pending future. On success, the resolved
    /// state replaces `self.inner` and is polled straight away. On failure, that error
    /// is yielded. The decompressing variants produce `BytesMut` frames, which are
    /// frozen into `Bytes`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let frame = match self.inner {
            Inner::PlainText(ref mut body) => return Pin::new(body).poll_next(cx),
            Inner::Pending(ref mut pending) => {
                return match futures_core::ready!(Pin::new(pending).poll(cx)) {
                    Ok(resolved) => {
                        self.inner = resolved;
                        self.poll_next(cx)
                    },
                    Err(error) => Poll::Ready(Some(Err(error))),
                };
            },
            Inner::Gzip(ref mut framed) => Pin::new(framed).poll_next(cx),
            Inner::Brotli(ref mut framed) => Pin::new(framed).poll_next(cx),
            Inner::Deflate(ref mut framed) => Pin::new(framed).poll_next(cx),
        };
        frame.map(|item| item.map(|chunk| chunk.map(|buf| buf.freeze())))
    }
}
impl Future for Pending {
    type Output = Result<Inner, io::Error>;

    /// Waits for the first body item, then picks the final decoder state.
    ///
    /// - The body ends without any item: resolves to an empty plain-text stream, so no
    ///   decompressor is built for an empty body.
    /// - The first item is an error: that error is taken out of the stream and returned.
    /// - Otherwise: the whole body, including the peeked item (peeking does not consume
    ///   it), is moved into the decompressor selected by `type_`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let first_is_error = match futures_core::ready!(Pin::new(&mut self.body).poll_peek(cx))
        {
            None => return Poll::Ready(Ok(Inner::PlainText(BodyStream::empty()))),
            Some(first) => first.is_err(),
        };

        if first_is_error {
            // `poll_peek` only lends the error out; `poll_next` hands back the
            // already-buffered item so the error can be returned by value.
            let taken = futures_core::ready!(Pin::new(&mut self.body).poll_next(cx));
            return Poll::Ready(Err(taken.expect("just peeked Some").unwrap_err()));
        }

        // Move the real body out, leaving an empty placeholder stream behind.
        let encoded = StreamReader::new(std::mem::replace(
            &mut self.body,
            BodyStream::empty().peekable(),
        ));
        let inner = match self.type_ {
            DecoderType::Gzip => Inner::Gzip(FramedRead::with_capacity(
                GzipDecoder::new(encoded),
                BytesCodec::new(),
                DECODER_BUFFER_SIZE,
            )),
            DecoderType::Brotli => Inner::Brotli(FramedRead::with_capacity(
                BrotliDecoder::new(encoded),
                BytesCodec::new(),
                DECODER_BUFFER_SIZE,
            )),
            DecoderType::Deflate => Inner::Deflate(FramedRead::with_capacity(
                ZlibDecoder::new(encoded),
                BytesCodec::new(),
                DECODER_BUFFER_SIZE,
            )),
        };
        Poll::Ready(Ok(inner))
    }
}
/// Adapter over a hyper `Body` that counts received bytes and reports body errors
/// as `io::Error`.
struct BodyStream {
    /// The raw response body being read.
    body: Body,
    /// Number of body bytes yielded so far.
    total_read: u64,
    /// `Content-Length` taken from the response headers, if present.
    content_length: Option<ContentLength>,
    /// Whether the response was fetched over a secure scheme.
    is_secure_scheme: bool,
}
impl BodyStream {
    /// A stream over `Body::empty()`, with no secure-scheme flag and no length.
    fn empty() -> Self {
        Self::new(Body::empty(), false, None)
    }

    /// Wraps `body`, starting the received-byte counter at zero.
    fn new(body: Body, is_secure_scheme: bool, content_length: Option<ContentLength>) -> Self {
        Self {
            total_read: 0,
            content_length,
            is_secure_scheme,
            body,
        }
    }
}
impl Stream for BodyStream {
type Item = Result<Bytes, io::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match futures_core::ready!(Pin::new(&mut self.body).poll_next(cx)) {
Some(Ok(bytes)) => {
self.total_read += bytes.len() as u64;
Poll::Ready(Some(Ok(bytes)))
},
Some(Err(err)) => {
let all_content_read = self
.content_length
.map_or(false, |c| c.0 == self.total_read);
if self.is_secure_scheme &&
(all_content_read || pref!(network.tls.ignore_unexpected_eof))
{
let source = err.source();
let is_unexpected_eof = source
.and_then(|e| e.downcast_ref::<io::Error>())
.map_or(false, |e| e.kind() == io::ErrorKind::UnexpectedEof);
if is_unexpected_eof {
return Poll::Ready(None);
}
}
Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, err))))
},
None => Poll::Ready(None),
}
}
}