1use std::error::Error as StdError;
2use std::fmt;
3use std::io;
4use std::task::{Context, Poll};
5
6use bytes::{BufMut, Bytes, BytesMut};
7use futures_core::ready;
8use http::{HeaderMap, HeaderName, HeaderValue};
9use http_body::Frame;
10
11use super::io::MemRead;
12use super::role::DEFAULT_MAX_HEADERS;
13use super::DecodedLength;
14
15use self::Kind::{Chunked, Eof, Length};
16
/// Maximum number of chunk-extension bytes accepted while decoding one chunked body.
///
/// The counter checked against this limit (`extensions_cnt` in `Kind::Chunked`) is not
/// reset between chunks, so the limit applies to the whole body rather than per chunk.
/// Reaching the limit (`>=`) is already an error.
const CHUNKED_EXTENSIONS_LIMIT: u64 = 1024 * 16;

/// Default cap on the total number of trailer-section bytes buffered for a chunked body.
///
/// Used when the decoder was built without an explicit `h1_max_header_size`.
const TRAILER_LIMIT: usize = 1024 * 16;
26
/// Decoder for an HTTP/1 message body.
///
/// The framing is chosen at construction time: a known length, `Transfer-Encoding: chunked`,
/// or read-until-EOF (see `Decoder::new`). `Debug` is implemented by hand so that it prints
/// only the inner `Kind`.
#[derive(Clone, PartialEq)]
pub(crate) struct Decoder {
    kind: Kind,
}
35
/// Framing strategy of a [`Decoder`], together with the state that strategy needs.
#[derive(Debug, Clone, PartialEq)]
enum Kind {
    /// Body of a known length; the value is the number of bytes still to be read.
    /// `0` means the body is complete.
    Length(u64),
    /// Body using the chunked transfer coding.
    Chunked {
        /// Current position in the chunked-coding state machine.
        state: ChunkedState,
        /// Size of the current chunk as parsed from its size line; decremented while the
        /// chunk data is read.
        chunk_len: u64,
        /// Chunk-extension bytes seen so far across the whole body
        /// (checked against `CHUNKED_EXTENSIONS_LIMIT`).
        extensions_cnt: u64,
        /// Raw trailer-section bytes; allocated lazily when the first trailer byte arrives.
        trailers_buf: Option<BytesMut>,
        /// Number of complete (CRLF-terminated) trailer lines seen so far.
        trailers_cnt: usize,
        /// Configured trailer-count limit; `None` falls back to `DEFAULT_MAX_HEADERS`.
        h1_max_headers: Option<usize>,
        /// Configured trailer-bytes limit; `None` falls back to `TRAILER_LIMIT`.
        h1_max_header_size: Option<usize>,
    },
    /// Body delimited by the end of the transport stream.
    ///
    /// The flag is set once a read returns no bytes, i.e. EOF has been observed.
    Eof(bool),
}
68
/// States of the byte-at-a-time chunked transfer-coding parser.
#[derive(Debug, PartialEq, Clone, Copy)]
enum ChunkedState {
    /// Expecting the first hex digit of a chunk-size line.
    Start,
    /// Reading further hex digits of the chunk size.
    Size,
    /// After whitespace following the size: only more whitespace, `;`, or CR is accepted.
    SizeLws,
    /// Inside a chunk extension; its bytes are counted and otherwise ignored until CR.
    Extension,
    /// Expecting the LF that ends the chunk-size line.
    SizeLf,
    /// Reading chunk data.
    Body,
    /// Expecting the CR that follows chunk data.
    BodyCr,
    /// Expecting the LF that follows chunk data.
    BodyLf,
    /// Reading the bytes of a trailer line.
    Trailer,
    /// Expecting the LF that ends a trailer line.
    TrailerLf,
    /// After the last-chunk line or a trailer line: CR starts the final CRLF, any other
    /// byte starts a trailer line.
    EndCr,
    /// Expecting the final LF of the body.
    EndLf,
    /// The chunked body has been fully consumed.
    End,
}
85
86impl Decoder {
87 pub(crate) fn length(x: u64) -> Decoder {
90 Decoder {
91 kind: Kind::Length(x),
92 }
93 }
94
95 pub(crate) fn chunked(
96 h1_max_headers: Option<usize>,
97 h1_max_header_size: Option<usize>,
98 ) -> Decoder {
99 Decoder {
100 kind: Kind::Chunked {
101 state: ChunkedState::new(),
102 chunk_len: 0,
103 extensions_cnt: 0,
104 trailers_buf: None,
105 trailers_cnt: 0,
106 h1_max_headers,
107 h1_max_header_size,
108 },
109 }
110 }
111
112 pub(crate) fn eof() -> Decoder {
113 Decoder {
114 kind: Kind::Eof(false),
115 }
116 }
117
118 pub(super) fn new(
119 len: DecodedLength,
120 h1_max_headers: Option<usize>,
121 h1_max_header_size: Option<usize>,
122 ) -> Self {
123 match len {
124 DecodedLength::CHUNKED => Decoder::chunked(h1_max_headers, h1_max_header_size),
125 DecodedLength::CLOSE_DELIMITED => Decoder::eof(),
126 length => Decoder::length(length.danger_len()),
127 }
128 }
129
130 pub(crate) fn is_eof(&self) -> bool {
133 matches!(
134 self.kind,
135 Length(0)
136 | Chunked {
137 state: ChunkedState::End,
138 ..
139 }
140 | Eof(true)
141 )
142 }
143
144 pub(crate) fn decode<R: MemRead>(
145 &mut self,
146 cx: &mut Context<'_>,
147 body: &mut R,
148 ) -> Poll<Result<Frame<Bytes>, io::Error>> {
149 trace!("decode; state={:?}", self.kind);
150 match self.kind {
151 Length(ref mut remaining) => {
152 if *remaining == 0 {
153 Poll::Ready(Ok(Frame::data(Bytes::new())))
154 } else {
155 let to_read = usize::try_from(*remaining).unwrap_or(usize::MAX);
156 let buf = ready!(body.read_mem(cx, to_read))?;
157 let num = buf.as_ref().len() as u64;
158 if num > *remaining {
159 *remaining = 0;
160 } else if num == 0 {
161 return Poll::Ready(Err(io::Error::new(
162 io::ErrorKind::UnexpectedEof,
163 IncompleteBody,
164 )));
165 } else {
166 *remaining -= num;
167 }
168 Poll::Ready(Ok(Frame::data(buf)))
169 }
170 }
171 Chunked {
172 ref mut state,
173 ref mut chunk_len,
174 ref mut extensions_cnt,
175 ref mut trailers_buf,
176 ref mut trailers_cnt,
177 ref h1_max_headers,
178 ref h1_max_header_size,
179 } => {
180 let h1_max_headers = h1_max_headers.unwrap_or(DEFAULT_MAX_HEADERS);
181 let h1_max_header_size = h1_max_header_size.unwrap_or(TRAILER_LIMIT);
182 loop {
183 let mut buf = None;
184 *state = ready!(state.step(
186 cx,
187 body,
188 StepArgs {
189 chunk_size: chunk_len,
190 extensions_cnt,
191 chunk_buf: &mut buf,
192 trailers_buf,
193 trailers_cnt,
194 max_headers_cnt: h1_max_headers,
195 max_headers_bytes: h1_max_header_size,
196 }
197 ))?;
198 if *state == ChunkedState::End {
199 trace!("end of chunked");
200
201 if trailers_buf.is_some() {
202 trace!("found possible trailers");
203
204 if *trailers_cnt >= h1_max_headers {
206 return Poll::Ready(Err(io::Error::new(
207 io::ErrorKind::InvalidData,
208 "chunk trailers count overflow",
209 )));
210 }
211 match decode_trailers(
212 &mut trailers_buf.take().expect("Trailer is None"),
213 *trailers_cnt,
214 ) {
215 Ok(headers) => {
216 return Poll::Ready(Ok(Frame::trailers(headers)));
217 }
218 Err(e) => {
219 return Poll::Ready(Err(e));
220 }
221 }
222 }
223
224 return Poll::Ready(Ok(Frame::data(Bytes::new())));
225 }
226 if let Some(buf) = buf {
227 return Poll::Ready(Ok(Frame::data(buf)));
228 }
229 }
230 }
231 Eof(ref mut is_eof) => {
232 if *is_eof {
233 Poll::Ready(Ok(Frame::data(Bytes::new())))
234 } else {
235 body.read_mem(cx, 8192).map_ok(|slice| {
239 *is_eof = slice.is_empty();
240 Frame::data(slice)
241 })
242 }
243 }
244 }
245 }
246
247 #[cfg(test)]
248 async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Frame<Bytes>, io::Error> {
249 futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await
250 }
251}
252
253impl fmt::Debug for Decoder {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 fmt::Debug::fmt(&self.kind, f)
256 }
257}
258
/// Reads exactly one byte from `$rdr` inside a poll function.
///
/// `Pending` and read errors are returned from the enclosing function. An empty read
/// (EOF) makes the enclosing function return an `UnexpectedEof` error.
macro_rules! byte (
    ($rdr:ident, $cx:expr) => ({
        let chunk = ready!($rdr.read_mem($cx, 1))?;
        let first = chunk.first().copied();
        match first {
            Some(b) => b,
            None => {
                return Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof,
                    "unexpected EOF during chunk size line")));
            }
        }
    })
);
270
/// Unwraps a checked-arithmetic `Option`.
///
/// On `None`, the enclosing poll function returns an `InvalidData` "overflow" error.
macro_rules! or_overflow {
    ($e:expr) => {
        if let Some(val) = $e {
            val
        } else {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid chunk size: overflow",
            )));
        }
    };
}
282
/// Appends `$byte` to the trailer buffer.
///
/// If the buffer length has then reached `$limit`, the enclosing poll function returns an
/// `InvalidData` error. `$trailers_buf` is evaluated once for the push and once for the
/// length check.
macro_rules! put_u8 {
    ($trailers_buf:expr, $byte:expr, $limit:expr) => {
        $trailers_buf.put_u8($byte);

        if $limit <= $trailers_buf.len() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "chunk trailers bytes over limit",
            )));
        }
    };
}
295
/// Mutable decoder state and limits passed to one `ChunkedState::step` call.
struct StepArgs<'a> {
    /// Chunk-size accumulator while parsing the size line; remaining bytes while in `Body`.
    chunk_size: &'a mut u64,
    /// Receives the chunk data read by this step, if any.
    chunk_buf: &'a mut Option<Bytes>,
    /// Running count of chunk-extension bytes for the whole body.
    extensions_cnt: &'a mut u64,
    /// Raw trailer-section bytes, created when the first trailer byte is read.
    trailers_buf: &'a mut Option<BytesMut>,
    /// Number of trailer lines seen so far.
    trailers_cnt: &'a mut usize,
    /// Maximum number of trailer lines allowed.
    max_headers_cnt: usize,
    /// Maximum number of trailer-section bytes allowed.
    max_headers_bytes: usize,
}
305
306impl ChunkedState {
307 fn new() -> ChunkedState {
308 ChunkedState::Start
309 }
310 fn step<R: MemRead>(
311 &self,
312 cx: &mut Context<'_>,
313 body: &mut R,
314 StepArgs {
315 chunk_size,
316 chunk_buf,
317 extensions_cnt,
318 trailers_buf,
319 trailers_cnt,
320 max_headers_cnt,
321 max_headers_bytes,
322 }: StepArgs<'_>,
323 ) -> Poll<Result<ChunkedState, io::Error>> {
324 use self::ChunkedState::*;
325 match *self {
326 Start => ChunkedState::read_start(cx, body, chunk_size),
327 Size => ChunkedState::read_size(cx, body, chunk_size),
328 SizeLws => ChunkedState::read_size_lws(cx, body),
329 Extension => ChunkedState::read_extension(cx, body, extensions_cnt),
330 SizeLf => ChunkedState::read_size_lf(cx, body, *chunk_size),
331 Body => ChunkedState::read_body(cx, body, chunk_size, chunk_buf),
332 BodyCr => ChunkedState::read_body_cr(cx, body),
333 BodyLf => ChunkedState::read_body_lf(cx, body),
334 Trailer => ChunkedState::read_trailer(cx, body, trailers_buf, max_headers_bytes),
335 TrailerLf => ChunkedState::read_trailer_lf(
336 cx,
337 body,
338 trailers_buf,
339 trailers_cnt,
340 max_headers_cnt,
341 max_headers_bytes,
342 ),
343 EndCr => ChunkedState::read_end_cr(cx, body, trailers_buf, max_headers_bytes),
344 EndLf => ChunkedState::read_end_lf(cx, body, trailers_buf, max_headers_bytes),
345 End => Poll::Ready(Ok(ChunkedState::End)),
346 }
347 }
348
349 fn read_start<R: MemRead>(
350 cx: &mut Context<'_>,
351 rdr: &mut R,
352 size: &mut u64,
353 ) -> Poll<Result<ChunkedState, io::Error>> {
354 trace!("Read chunk start");
355
356 let radix = 16;
357 match byte!(rdr, cx) {
358 b @ b'0'..=b'9' => {
359 *size = or_overflow!(size.checked_mul(radix));
360 *size = or_overflow!(size.checked_add((b - b'0') as u64));
361 }
362 b @ b'a'..=b'f' => {
363 *size = or_overflow!(size.checked_mul(radix));
364 *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
365 }
366 b @ b'A'..=b'F' => {
367 *size = or_overflow!(size.checked_mul(radix));
368 *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
369 }
370 _ => {
371 return Poll::Ready(Err(io::Error::new(
372 io::ErrorKind::InvalidInput,
373 "Invalid chunk size line: missing size digit",
374 )));
375 }
376 }
377
378 Poll::Ready(Ok(ChunkedState::Size))
379 }
380
381 fn read_size<R: MemRead>(
382 cx: &mut Context<'_>,
383 rdr: &mut R,
384 size: &mut u64,
385 ) -> Poll<Result<ChunkedState, io::Error>> {
386 trace!("Read chunk hex size");
387
388 let radix = 16;
389 match byte!(rdr, cx) {
390 b @ b'0'..=b'9' => {
391 *size = or_overflow!(size.checked_mul(radix));
392 *size = or_overflow!(size.checked_add((b - b'0') as u64));
393 }
394 b @ b'a'..=b'f' => {
395 *size = or_overflow!(size.checked_mul(radix));
396 *size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
397 }
398 b @ b'A'..=b'F' => {
399 *size = or_overflow!(size.checked_mul(radix));
400 *size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
401 }
402 b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
403 b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
404 b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
405 _ => {
406 return Poll::Ready(Err(io::Error::new(
407 io::ErrorKind::InvalidInput,
408 "Invalid chunk size line: Invalid Size",
409 )));
410 }
411 }
412 Poll::Ready(Ok(ChunkedState::Size))
413 }
414 fn read_size_lws<R: MemRead>(
415 cx: &mut Context<'_>,
416 rdr: &mut R,
417 ) -> Poll<Result<ChunkedState, io::Error>> {
418 trace!("read_size_lws");
419 match byte!(rdr, cx) {
420 b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
422 b';' => Poll::Ready(Ok(ChunkedState::Extension)),
423 b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
424 _ => Poll::Ready(Err(io::Error::new(
425 io::ErrorKind::InvalidInput,
426 "Invalid chunk size linear white space",
427 ))),
428 }
429 }
430 fn read_extension<R: MemRead>(
431 cx: &mut Context<'_>,
432 rdr: &mut R,
433 extensions_cnt: &mut u64,
434 ) -> Poll<Result<ChunkedState, io::Error>> {
435 trace!("read_extension");
436 match byte!(rdr, cx) {
443 b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
444 b'\n' => Poll::Ready(Err(io::Error::new(
445 io::ErrorKind::InvalidData,
446 "invalid chunk extension contains newline",
447 ))),
448 _ => {
449 *extensions_cnt += 1;
450 if *extensions_cnt >= CHUNKED_EXTENSIONS_LIMIT {
451 Poll::Ready(Err(io::Error::new(
452 io::ErrorKind::InvalidData,
453 "chunk extensions over limit",
454 )))
455 } else {
456 Poll::Ready(Ok(ChunkedState::Extension))
457 }
458 } }
460 }
461 fn read_size_lf<R: MemRead>(
462 cx: &mut Context<'_>,
463 rdr: &mut R,
464 size: u64,
465 ) -> Poll<Result<ChunkedState, io::Error>> {
466 trace!("Chunk size is {:?}", size);
467 match byte!(rdr, cx) {
468 b'\n' => {
469 if size == 0 {
470 Poll::Ready(Ok(ChunkedState::EndCr))
471 } else {
472 debug!("incoming chunked header: {0:#X} ({0} bytes)", size);
473 Poll::Ready(Ok(ChunkedState::Body))
474 }
475 }
476 _ => Poll::Ready(Err(io::Error::new(
477 io::ErrorKind::InvalidInput,
478 "Invalid chunk size LF",
479 ))),
480 }
481 }
482
483 fn read_body<R: MemRead>(
484 cx: &mut Context<'_>,
485 rdr: &mut R,
486 rem: &mut u64,
487 buf: &mut Option<Bytes>,
488 ) -> Poll<Result<ChunkedState, io::Error>> {
489 trace!("Chunked read, remaining={:?}", rem);
490
491 let to_read = usize::try_from(*rem).unwrap_or(usize::MAX);
493 let slice = ready!(rdr.read_mem(cx, to_read))?;
494 let count = slice.len();
495
496 if count == 0 {
497 *rem = 0;
498 return Poll::Ready(Err(io::Error::new(
499 io::ErrorKind::UnexpectedEof,
500 IncompleteBody,
501 )));
502 }
503 *buf = Some(slice);
504 *rem -= count as u64;
505
506 if *rem > 0 {
507 Poll::Ready(Ok(ChunkedState::Body))
508 } else {
509 Poll::Ready(Ok(ChunkedState::BodyCr))
510 }
511 }
512 fn read_body_cr<R: MemRead>(
513 cx: &mut Context<'_>,
514 rdr: &mut R,
515 ) -> Poll<Result<ChunkedState, io::Error>> {
516 match byte!(rdr, cx) {
517 b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
518 _ => Poll::Ready(Err(io::Error::new(
519 io::ErrorKind::InvalidInput,
520 "Invalid chunk body CR",
521 ))),
522 }
523 }
524 fn read_body_lf<R: MemRead>(
525 cx: &mut Context<'_>,
526 rdr: &mut R,
527 ) -> Poll<Result<ChunkedState, io::Error>> {
528 match byte!(rdr, cx) {
529 b'\n' => Poll::Ready(Ok(ChunkedState::Start)),
530 _ => Poll::Ready(Err(io::Error::new(
531 io::ErrorKind::InvalidInput,
532 "Invalid chunk body LF",
533 ))),
534 }
535 }
536
537 fn read_trailer<R: MemRead>(
538 cx: &mut Context<'_>,
539 rdr: &mut R,
540 trailers_buf: &mut Option<BytesMut>,
541 h1_max_header_size: usize,
542 ) -> Poll<Result<ChunkedState, io::Error>> {
543 trace!("read_trailer");
544 let byte = byte!(rdr, cx);
545
546 put_u8!(
547 trailers_buf.as_mut().expect("trailers_buf is None"),
548 byte,
549 h1_max_header_size
550 );
551
552 match byte {
553 b'\r' => Poll::Ready(Ok(ChunkedState::TrailerLf)),
554 _ => Poll::Ready(Ok(ChunkedState::Trailer)),
555 }
556 }
557
558 fn read_trailer_lf<R: MemRead>(
559 cx: &mut Context<'_>,
560 rdr: &mut R,
561 trailers_buf: &mut Option<BytesMut>,
562 trailers_cnt: &mut usize,
563 h1_max_headers: usize,
564 h1_max_header_size: usize,
565 ) -> Poll<Result<ChunkedState, io::Error>> {
566 let byte = byte!(rdr, cx);
567 match byte {
568 b'\n' => {
569 if *trailers_cnt >= h1_max_headers {
570 return Poll::Ready(Err(io::Error::new(
571 io::ErrorKind::InvalidData,
572 "chunk trailers count overflow",
573 )));
574 }
575 *trailers_cnt += 1;
576
577 put_u8!(
578 trailers_buf.as_mut().expect("trailers_buf is None"),
579 byte,
580 h1_max_header_size
581 );
582
583 Poll::Ready(Ok(ChunkedState::EndCr))
584 }
585 _ => Poll::Ready(Err(io::Error::new(
586 io::ErrorKind::InvalidInput,
587 "Invalid trailer end LF",
588 ))),
589 }
590 }
591
592 fn read_end_cr<R: MemRead>(
593 cx: &mut Context<'_>,
594 rdr: &mut R,
595 trailers_buf: &mut Option<BytesMut>,
596 h1_max_header_size: usize,
597 ) -> Poll<Result<ChunkedState, io::Error>> {
598 let byte = byte!(rdr, cx);
599 match byte {
600 b'\r' => {
601 if let Some(trailers_buf) = trailers_buf {
602 put_u8!(trailers_buf, byte, h1_max_header_size);
603 }
604 Poll::Ready(Ok(ChunkedState::EndLf))
605 }
606 byte => {
607 match trailers_buf {
608 None => {
609 let mut buf = BytesMut::with_capacity(64);
611 buf.put_u8(byte);
612 *trailers_buf = Some(buf);
613 }
614 Some(ref mut trailers_buf) => {
615 put_u8!(trailers_buf, byte, h1_max_header_size);
616 }
617 }
618
619 Poll::Ready(Ok(ChunkedState::Trailer))
620 }
621 }
622 }
623 fn read_end_lf<R: MemRead>(
624 cx: &mut Context<'_>,
625 rdr: &mut R,
626 trailers_buf: &mut Option<BytesMut>,
627 h1_max_header_size: usize,
628 ) -> Poll<Result<ChunkedState, io::Error>> {
629 let byte = byte!(rdr, cx);
630 match byte {
631 b'\n' => {
632 if let Some(trailers_buf) = trailers_buf {
633 put_u8!(trailers_buf, byte, h1_max_header_size);
634 }
635 Poll::Ready(Ok(ChunkedState::End))
636 }
637 _ => Poll::Ready(Err(io::Error::new(
638 io::ErrorKind::InvalidInput,
639 "Invalid chunk end LF",
640 ))),
641 }
642 }
643}
644
645fn decode_trailers(buf: &mut BytesMut, count: usize) -> Result<HeaderMap, io::Error> {
647 let mut trailers = HeaderMap::new();
648 let mut headers = vec![httparse::EMPTY_HEADER; count];
649 let res = httparse::parse_headers(buf, &mut headers);
650 match res {
651 Ok(httparse::Status::Complete((_, headers))) => {
652 for header in headers.iter() {
653 use std::convert::TryFrom;
654 let name = match HeaderName::try_from(header.name) {
655 Ok(name) => name,
656 Err(_) => {
657 return Err(io::Error::new(
658 io::ErrorKind::InvalidInput,
659 format!("Invalid header name: {:?}", &header),
660 ));
661 }
662 };
663
664 let value = match HeaderValue::from_bytes(header.value) {
665 Ok(value) => value,
666 Err(_) => {
667 return Err(io::Error::new(
668 io::ErrorKind::InvalidInput,
669 format!("Invalid header value: {:?}", &header),
670 ));
671 }
672 };
673
674 trailers.insert(name, value);
675 }
676
677 Ok(trailers)
678 }
679 Ok(httparse::Status::Partial) => Err(io::Error::new(
680 io::ErrorKind::InvalidInput,
681 "Partial header",
682 )),
683 Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
684 }
685}
686
/// Error payload for a body whose transport hit EOF before the expected length was read.
#[derive(Debug)]
struct IncompleteBody;

impl fmt::Display for IncompleteBody {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("end of file before message length reached")
    }
}

// No underlying cause; the default `source()` (None) applies.
impl StdError for IncompleteBody {}
697
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rt::{Read, ReadBuf};
    use std::pin::Pin;
    use std::time::Duration;

    // `MemRead` over a byte slice: hands out up to `len` bytes and advances the slice.
    // An exhausted slice yields an empty `Bytes`, which the decoder treats as EOF.
    impl MemRead for &[u8] {
        fn read_mem(&mut self, _: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
            let n = std::cmp::min(len, self.len());
            if n > 0 {
                let (a, b) = self.split_at(n);
                let buf = Bytes::copy_from_slice(a);
                *self = b;
                Poll::Ready(Ok(buf))
            } else {
                Poll::Ready(Ok(Bytes::new()))
            }
        }
    }

    // `MemRead` over a runtime `Read` object; used by the async tests below to drive a
    // `tokio_test` mock through the decoder.
    impl MemRead for &mut (dyn Read + Unpin) {
        fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
            let mut v = vec![0; len];
            let mut buf = ReadBuf::new(&mut v);
            ready!(Pin::new(self).poll_read(cx, buf.unfilled())?);
            Poll::Ready(Ok(Bytes::copy_from_slice(buf.filled())))
        }
    }

    // `MemRead` over `Bytes`: splits up to `len` bytes off the front.
    impl MemRead for Bytes {
        fn read_mem(&mut self, _: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
            let n = std::cmp::min(len, self.len());
            let ret = self.split_to(n);
            Poll::Ready(Ok(ret))
        }
    }

    // Drives `ChunkedState::step` directly over chunk-size lines, checking parsed sizes
    // and the error kind reported for malformed lines.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunk_size() {
        use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof};

        // Steps the parser until chunk data (`Body`) or the last chunk (`EndCr`) is
        // reached, and returns the parsed size; panics on any error.
        async fn read(s: &str) -> u64 {
            let mut state = ChunkedState::new();
            let rdr = &mut s.as_bytes();
            let mut size = 0;
            let mut ext_cnt = 0;
            let mut trailers_cnt = 0;
            loop {
                let result = futures_util::future::poll_fn(|cx| {
                    state.step(
                        cx,
                        rdr,
                        StepArgs {
                            chunk_size: &mut size,
                            extensions_cnt: &mut ext_cnt,
                            chunk_buf: &mut None,
                            trailers_buf: &mut None,
                            trailers_cnt: &mut trailers_cnt,
                            max_headers_cnt: DEFAULT_MAX_HEADERS,
                            max_headers_bytes: TRAILER_LIMIT,
                        },
                    )
                })
                .await;
                let desc = format!("read_size failed for {:?}", s);
                state = result.expect(&desc);
                if state == ChunkedState::Body || state == ChunkedState::EndCr {
                    break;
                }
            }
            size
        }

        // Steps the parser until it errors and asserts the error kind; panics if the
        // input reaches `Body` or `End` without an error.
        async fn read_err(s: &str, expected_err: io::ErrorKind) {
            let mut state = ChunkedState::new();
            let rdr = &mut s.as_bytes();
            let mut size = 0;
            let mut ext_cnt = 0;
            let mut trailers_cnt = 0;
            loop {
                let result = futures_util::future::poll_fn(|cx| {
                    state.step(
                        cx,
                        rdr,
                        StepArgs {
                            chunk_size: &mut size,
                            extensions_cnt: &mut ext_cnt,
                            chunk_buf: &mut None,
                            trailers_buf: &mut None,
                            trailers_cnt: &mut trailers_cnt,
                            max_headers_cnt: DEFAULT_MAX_HEADERS,
                            max_headers_bytes: TRAILER_LIMIT,
                        },
                    )
                })
                .await;
                state = match result {
                    Ok(s) => s,
                    Err(e) => {
                        assert!(
                            expected_err == e.kind(),
                            "Reading {:?}, expected {:?}, but got {:?}",
                            s,
                            expected_err,
                            e.kind()
                        );
                        return;
                    }
                };
                if state == ChunkedState::Body || state == ChunkedState::End {
                    panic!("Was Ok. Expected Err for {:?}", s);
                }
            }
        }

        // Valid sizes (hex, either case, leading zeros, trailing whitespace).
        assert_eq!(1, read("1\r\n").await);
        assert_eq!(1, read("01\r\n").await);
        assert_eq!(0, read("0\r\n").await);
        assert_eq!(0, read("00\r\n").await);
        assert_eq!(10, read("A\r\n").await);
        assert_eq!(10, read("a\r\n").await);
        assert_eq!(255, read("Ff\r\n").await);
        assert_eq!(255, read("Ff \r\n").await);
        // Missing LF, or input ends mid-line.
        read_err("F\rF", InvalidInput).await;
        read_err("F", UnexpectedEof).await;
        // Missing size digit.
        read_err("\r\n\r\n", InvalidInput).await;
        read_err("\r\n", InvalidInput).await;
        // Invalid hex digits.
        read_err("X\r\n", InvalidInput).await;
        read_err("1X\r\n", InvalidInput).await;
        read_err("-\r\n", InvalidInput).await;
        read_err("-1\r\n", InvalidInput).await;
        // Extensions are skipped and do not affect the parsed size.
        assert_eq!(1, read("1;extension\r\n").await);
        assert_eq!(10, read("a;ext name=value\r\n").await);
        assert_eq!(1, read("1;extension;extension2\r\n").await);
        assert_eq!(1, read("1;;; ;\r\n").await);
        assert_eq!(2, read("2; extension...\r\n").await);
        assert_eq!(3, read("3 ; extension=123\r\n").await);
        assert_eq!(3, read("3 ;\r\n").await);
        assert_eq!(3, read("3 ; \r\n").await);
        // Invalid text after the size, unterminated extensions, and bare LF are rejected.
        read_err("1 invalid extension\r\n", InvalidInput).await;
        read_err("1 A\r\n", InvalidInput).await;
        read_err("1;no CRLF", UnexpectedEof).await;
        read_err("1;reject\nnewlines\r\n", InvalidData).await;
        // A size that does not fit in u64.
        read_err("f0000000000000003\r\n", InvalidData).await;
    }

    // A length-delimited body that ends early returns what it has, then UnexpectedEof.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_sized_early_eof() {
        let mut bytes = &b"foo bar"[..];
        let mut decoder = Decoder::length(10);
        assert_eq!(
            decoder
                .decode_fut(&mut bytes)
                .await
                .unwrap()
                .data_ref()
                .unwrap()
                .len(),
            7
        );
        let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
        assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
    }

    // A chunk (size 9) that ends early returns the partial data, then UnexpectedEof.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_early_eof() {
        let mut bytes = &b"\
            9\r\n\
            foo bar\
        "[..];
        let mut decoder = Decoder::chunked(None, None);
        assert_eq!(
            decoder
                .decode_fut(&mut bytes)
                .await
                .unwrap()
                .data_ref()
                .unwrap()
                .len(),
            7
        );
        let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
        assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
    }

    // A whole 16-byte chunk is returned by a single decode call.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_single_read() {
        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..];
        let buf = Decoder::chunked(None, None)
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode")
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());
        let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
        assert_eq!("1234567890abcdef", &result);
    }

    // After one valid chunk, the terminating zero-size chunk line is missing; the empty
    // size line must be rejected.
    #[tokio::test]
    async fn test_read_chunked_with_missing_zero_digit() {
        let mut mock_buf = &b"1\r\nZ\r\n\r\n\r\n"[..];
        let mut decoder = Decoder::chunked(None, None);
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode")
            .into_data()
            .expect("unknown frame type");
        assert_eq!("Z", buf);

        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("decode 2");
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    // Each chunk's extension stays under CHUNKED_EXTENSIONS_LIMIT, but the total across
    // chunks exceeds it, so the second chunk fails.
    #[tokio::test]
    async fn test_read_chunked_extensions_over_limit() {
        let per_chunk = super::CHUNKED_EXTENSIONS_LIMIT * 2 / 3;
        let mut scratch = vec![];
        for _ in 0..2 {
            scratch.extend(b"1;");
            scratch.extend(b"x".repeat(per_chunk as usize));
            scratch.extend(b"\r\nA\r\n");
        }
        scratch.extend(b"0\r\n\r\n");
        let mut mock_buf = Bytes::from(scratch);

        let mut decoder = Decoder::chunked(None, None);
        let buf1 = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode1")
            .into_data()
            .expect("unknown frame type");
        assert_eq!(&buf1[..], b"A");

        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("decode2");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "chunk extensions over limit");
    }

    // A trailer line ending in CR CR LF is rejected when the expected LF is not found.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_trailer_with_missing_lf() {
        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n"[..];
        let mut decoder = Decoder::chunked(None, None);
        decoder.decode_fut(&mut mock_buf).await.expect("decode");
        let e = decoder.decode_fut(&mut mock_buf).await.unwrap_err();
        assert_eq!(e.kind(), io::ErrorKind::InvalidInput);
    }

    // Once the chunked body ends, every further decode returns an empty data frame.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_after_eof() {
        let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..];
        let mut decoder = Decoder::chunked(None, None);

        // normal read
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .unwrap()
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());
        let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
        assert_eq!("1234567890abcdef", &result);

        // end of body
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode")
            .into_data()
            .expect("unknown frame type");
        assert_eq!(0, buf.len());

        // decoding again after the end still reports the end
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect("decode")
            .into_data()
            .expect("unknown frame type");
        assert_eq!(0, buf.len());
    }

    // Decodes `content` through a mock reader that stalls once at byte offset `block_at`
    // (or before the first byte when `block_at == 0`); returns the concatenated data.
    async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String {
        let mut outs = Vec::new();

        let mut ins = crate::common::io::Compat::new(if block_at == 0 {
            tokio_test::io::Builder::new()
                .wait(Duration::from_millis(10))
                .read(content)
                .build()
        } else {
            tokio_test::io::Builder::new()
                .read(&content[..block_at])
                .wait(Duration::from_millis(10))
                .read(&content[block_at..])
                .build()
        });

        let mut ins = &mut ins as &mut (dyn Read + Unpin);

        loop {
            let buf = decoder
                .decode_fut(&mut ins)
                .await
                .expect("unexpected decode error")
                .into_data()
                .expect("unexpected frame type");
            if buf.is_empty() {
                break; // eof
            }
            outs.extend(buf.as_ref());
        }

        String::from_utf8(outs).expect("decode String")
    }

    // Runs `read_async` once for every possible stall position in `content`.
    async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) {
        let content_len = content.len();
        for block_at in 0..content_len {
            let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await;
            assert_eq!(expected, &actual)
        }
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_length_async() {
        let content = "foobar";
        all_async_cases(content, content, Decoder::length(content.len() as u64)).await;
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_chunked_async() {
        let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n";
        let expected = "foobar";
        all_async_cases(content, expected, Decoder::chunked(None, None)).await;
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn test_read_eof_async() {
        let content = "foobar";
        all_async_cases(content, content, Decoder::eof()).await;
    }

    // Benchmark: decode a single 1 KiB chunk.
    #[cfg(all(feature = "nightly", not(miri)))]
    #[bench]
    fn bench_decode_chunked_1kb(b: &mut test::Bencher) {
        let rt = new_runtime();

        const LEN: usize = 1024;
        let mut vec = Vec::new();
        vec.extend(format!("{:x}\r\n", LEN).as_bytes());
        vec.extend(&[0; LEN][..]);
        vec.extend(b"\r\n");
        let content = Bytes::from(vec);

        b.bytes = LEN as u64;

        b.iter(|| {
            let mut decoder = Decoder::chunked(None, None);
            rt.block_on(async {
                let mut raw = content.clone();
                let chunk = decoder
                    .decode_fut(&mut raw)
                    .await
                    .unwrap()
                    .into_data()
                    .unwrap();
                assert_eq!(chunk.len(), LEN);
            });
        });
    }

    // Benchmark: decode a 1 KiB length-delimited body.
    #[cfg(all(feature = "nightly", not(miri)))]
    #[bench]
    fn bench_decode_length_1kb(b: &mut test::Bencher) {
        let rt = new_runtime();

        const LEN: usize = 1024;
        let content = Bytes::from(&[0; LEN][..]);
        b.bytes = LEN as u64;

        b.iter(|| {
            let mut decoder = Decoder::length(LEN as u64);
            rt.block_on(async {
                let mut raw = content.clone();
                let chunk = decoder
                    .decode_fut(&mut raw)
                    .await
                    .unwrap()
                    .into_data()
                    .unwrap();
                assert_eq!(chunk.len(), LEN);
            });
        });
    }

    // Single-threaded runtime used by the benchmarks.
    #[cfg(feature = "nightly")]
    fn new_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("rt build")
    }

    // `decode_trailers` parses a well-formed two-field trailer section.
    #[test]
    fn test_decode_trailers() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(
            b"Expires: Wed, 21 Oct 2015 07:28:00 GMT\r\nX-Stream-Error: failed to decode\r\n\r\n",
        );
        let headers = decode_trailers(&mut buf, 2).expect("decode_trailers");
        assert_eq!(headers.len(), 2);
        assert_eq!(
            headers.get("Expires").unwrap(),
            "Wed, 21 Oct 2015 07:28:00 GMT"
        );
        assert_eq!(headers.get("X-Stream-Error").unwrap(), "failed to decode");
    }

    // Sending exactly `h1_max_headers` trailer lines is rejected with InvalidData.
    #[tokio::test]
    async fn test_trailer_max_headers_enforced() {
        let h1_max_headers = 10;
        let mut scratch = vec![];
        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");
        for i in 0..h1_max_headers {
            scratch.extend(format!("trailer{}: {}\r\n", i, i).as_bytes());
        }
        scratch.extend(b"\r\n");
        let mut mock_buf = Bytes::from(scratch);

        let mut decoder = Decoder::chunked(Some(h1_max_headers), None);

        // the chunk data decodes normally
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .unwrap()
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());

        // the trailers are rejected
        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("trailer fields over limit");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    // A single trailer longer than `h1_max_header_size` is rejected with InvalidData.
    #[tokio::test]
    async fn test_trailer_max_header_size_huge_trailer() {
        let max_header_size = 1024;
        let mut scratch = vec![];
        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");
        scratch.extend(format!("huge_trailer: {}\r\n", "x".repeat(max_header_size)).as_bytes());
        scratch.extend(b"\r\n");
        let mut mock_buf = Bytes::from(scratch);

        let mut decoder = Decoder::chunked(None, Some(max_header_size));

        // the chunk data decodes normally
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .unwrap()
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());

        // the trailer section is rejected
        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("trailers over limit");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    // Many small trailers whose combined size exceeds `h1_max_header_size` are rejected.
    #[tokio::test]
    async fn test_trailer_max_header_size_many_small_trailers() {
        let max_headers = 10;
        let header_size = 64;
        let mut scratch = vec![];
        scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n");

        for i in 0..max_headers {
            scratch.extend(format!("trailer{}: {}\r\n", i, "x".repeat(header_size)).as_bytes());
        }

        scratch.extend(b"\r\n");
        let mut mock_buf = Bytes::from(scratch);

        let mut decoder = Decoder::chunked(None, Some(max_headers * header_size));

        // the chunk data decodes normally
        let buf = decoder
            .decode_fut(&mut mock_buf)
            .await
            .unwrap()
            .into_data()
            .expect("unknown frame type");
        assert_eq!(16, buf.len());

        // the trailer section is rejected
        let err = decoder
            .decode_fut(&mut mock_buf)
            .await
            .expect_err("trailers over limit");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}