1use super::{header::BytesStr, huffman, Header};
2use crate::frame;
3
4use bytes::{Buf, Bytes, BytesMut};
5use http::header;
6use http::method::{self, Method};
7use http::status::{self, StatusCode};
8
9use std::cmp;
10use std::collections::VecDeque;
11use std::io::Cursor;
12use std::ops::ControlFlow;
13use std::str::Utf8Error;
14
/// Stateful HPACK header decoder.
///
/// Owns the dynamic header table and a scratch buffer that is handed to the
/// Huffman decoder whenever a Huffman-coded string is read.
#[derive(Debug)]
pub struct Decoder {
    /// Size limit queued through `queue_size_update`. `decode` takes this
    /// value and stores it in `last_max_update` at the start of its next call.
    max_size_update: Option<usize>,
    /// Largest table size that `process_size_update` accepts from the stream.
    last_max_update: usize,
    /// Dynamic table holding headers inserted by "literal with indexing"
    /// representations.
    table: Table,
    /// Scratch buffer passed to `huffman::decode`.
    buffer: BytesMut,
}
24
/// Errors reported while decoding a header block.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum DecoderError {
    /// The first byte of a block did not match any known representation
    /// (returned by `Representation::load`).
    InvalidRepresentation,
    /// `decode_int` was asked to use a prefix size outside `1..=8`.
    InvalidIntegerPrefix,
    /// A table index was `0` or pointed past the last dynamic table entry.
    InvalidTableIndex,
    /// Huffman-coded string data was malformed; presumably raised by
    /// `huffman::decode` — confirm against that module.
    InvalidHuffmanCode,
    /// Produced in this file by the `From` conversions for `Utf8Error`,
    /// `InvalidHeaderValue`, `InvalidHeaderName` and `InvalidMethod`.
    InvalidUtf8,
    /// A status code was not valid.
    InvalidStatusCode,
    /// A pseudo-header was not acceptable; presumably raised while building a
    /// `Header` — confirm against that module.
    InvalidPseudoheader,
    /// A table size update either followed a header representation within the
    /// same `decode` call or asked for more than the permitted maximum.
    InvalidMaxDynamicSize,
    /// An encoded integer kept going past the maximum number of bytes that
    /// `decode_int` accepts.
    IntegerOverflow,
    /// The input ended before a complete item could be read; the payload says
    /// which kind of item was being read.
    NeedMore(NeedMore),
}
40
/// Detail carried by `DecoderError::NeedMore`: what was being read when the
/// input ran out.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum NeedMore {
    /// No byte was available where a string's length prefix was expected.
    UnexpectedEndOfStream,
    /// The input ended in the middle of an encoded integer.
    IntegerUnderflow,
    /// A string's declared length exceeds the bytes currently available.
    StringUnderflow,
}
47
/// Kind of block found in the header data, determined from the high bits of
/// the block's first byte by `Representation::load`.
enum Representation {
    /// First byte is `1xxxxxxx`. The whole header is looked up in the table;
    /// the index is read with a 7-bit prefix.
    Indexed,

    /// First byte is `01xxxxxx`. The name index is read with a 6-bit prefix
    /// and the decoded header is inserted into the dynamic table.
    LiteralWithIndexing,

    /// First byte is `0000xxxx`. The name index is read with a 4-bit prefix
    /// and the dynamic table is left untouched.
    LiteralWithoutIndexing,

    /// First byte is `0001xxxx`. Decoded the same way as
    /// `LiteralWithoutIndexing` in this file.
    LiteralNeverIndexed,

    /// First byte is `001xxxxx`. Carries a new maximum dynamic table size,
    /// read with a 5-bit prefix.
    SizeUpdate,
}
138
/// Dynamic header table with a bounded total size.
#[derive(Debug)]
struct Table {
    /// Stored headers. New entries are pushed at the front and evictions pop
    /// from the back.
    entries: VecDeque<Header>,
    /// Running sum of `Header::len()` over `entries`.
    size: usize,
    /// Upper bound on `size`; entries are evicted to stay within it.
    max_size: usize,
}
145
/// Describes a string that has been located in the input (and, for
/// Huffman-coded strings, already decoded) but whose bytes have not yet been
/// removed from the underlying buffer.
struct StringMarker {
    /// Number of bytes occupied by the string's length prefix.
    offset: usize,
    /// Number of encoded payload bytes that follow the length prefix.
    len: usize,
    /// Decoded contents when the string was Huffman-coded; `None` when the
    /// raw payload bytes are the string and must be sliced out of the buffer.
    string: Option<Bytes>,
}
151
152impl Decoder {
155 pub fn new(size: usize) -> Decoder {
157 Decoder {
158 max_size_update: None,
159 last_max_update: size,
160 table: Table::new(size),
161 buffer: BytesMut::with_capacity(4096),
162 }
163 }
164
165 #[allow(dead_code)]
167 pub fn queue_size_update(&mut self, size: usize) {
168 let size = match self.max_size_update {
169 Some(v) => cmp::max(v, size),
170 None => size,
171 };
172
173 self.max_size_update = Some(size);
174 }
175
176 pub fn decode<F>(
178 &mut self,
179 src: &mut Cursor<&mut BytesMut>,
180 mut f: F,
181 ) -> Result<(), DecoderError>
182 where
183 F: FnMut(Header) -> ControlFlow<()>,
184 {
185 use self::Representation::*;
186
187 let mut can_resize = true;
188
189 if let Some(size) = self.max_size_update.take() {
190 self.last_max_update = size;
191 }
192
193 let span = tracing::trace_span!("hpack::decode");
194 let _e = span.enter();
195
196 tracing::trace!("decode");
197
198 while let Some(ty) = peek_u8(src) {
199 match Representation::load(ty)? {
203 Indexed => {
204 tracing::trace!(rem = src.remaining(), kind = %"Indexed");
205 can_resize = false;
206 let entry = self.decode_indexed(src)?;
207 consume(src);
208 if f(entry).is_break() {
209 break;
210 }
211 }
212 LiteralWithIndexing => {
213 tracing::trace!(rem = src.remaining(), kind = %"LiteralWithIndexing");
214 can_resize = false;
215 let entry = self.decode_literal(src, true)?;
216
217 self.table.insert(entry.clone());
219 consume(src);
220
221 if f(entry).is_break() {
222 break;
223 }
224 }
225 LiteralWithoutIndexing => {
226 tracing::trace!(rem = src.remaining(), kind = %"LiteralWithoutIndexing");
227 can_resize = false;
228 let entry = self.decode_literal(src, false)?;
229 consume(src);
230 if f(entry).is_break() {
231 break;
232 }
233 }
234 LiteralNeverIndexed => {
235 tracing::trace!(rem = src.remaining(), kind = %"LiteralNeverIndexed");
236 can_resize = false;
237 let entry = self.decode_literal(src, false)?;
238 consume(src);
239
240 if f(entry).is_break() {
243 break;
244 }
245 }
246 SizeUpdate => {
247 tracing::trace!(rem = src.remaining(), kind = %"SizeUpdate");
248 if !can_resize {
249 return Err(DecoderError::InvalidMaxDynamicSize);
250 }
251
252 self.process_size_update(src)?;
254 consume(src);
255 }
256 }
257 }
258
259 Ok(())
260 }
261
262 fn process_size_update(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result<(), DecoderError> {
263 let new_size = decode_int(buf, 5)?;
264
265 if new_size > self.last_max_update {
266 return Err(DecoderError::InvalidMaxDynamicSize);
267 }
268
269 tracing::debug!(
270 from = self.table.size(),
271 to = new_size,
272 "Decoder changed max table size"
273 );
274
275 self.table.set_max_size(new_size);
276
277 Ok(())
278 }
279
280 fn decode_indexed(&self, buf: &mut Cursor<&mut BytesMut>) -> Result<Header, DecoderError> {
281 let index = decode_int(buf, 7)?;
282 self.table.get(index)
283 }
284
285 fn decode_literal(
286 &mut self,
287 buf: &mut Cursor<&mut BytesMut>,
288 index: bool,
289 ) -> Result<Header, DecoderError> {
290 let prefix = if index { 6 } else { 4 };
291
292 let table_idx = decode_int(buf, prefix)?;
294
295 if table_idx == 0 {
297 let old_pos = buf.position();
298 let name_marker = self.try_decode_string(buf)?;
299 let value_marker = self.try_decode_string(buf)?;
300 buf.set_position(old_pos);
301 let name = name_marker.consume(buf);
303 let value = value_marker.consume(buf);
304 Header::new(name, value)
305 } else {
306 let e = self.table.get(table_idx)?;
307 let value = self.decode_string(buf)?;
308
309 e.name().into_entry(value)
310 }
311 }
312
313 fn try_decode_string(
314 &mut self,
315 buf: &mut Cursor<&mut BytesMut>,
316 ) -> Result<StringMarker, DecoderError> {
317 let old_pos = buf.position();
318 const HUFF_FLAG: u8 = 0b1000_0000;
319
320 let huff = match peek_u8(buf) {
322 Some(hdr) => (hdr & HUFF_FLAG) == HUFF_FLAG,
323 None => return Err(DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)),
324 };
325
326 let len = decode_int(buf, 7)?;
328
329 if len > buf.remaining() {
330 tracing::trace!(len, remaining = buf.remaining(), "decode_string underflow",);
331 return Err(DecoderError::NeedMore(NeedMore::StringUnderflow));
332 }
333
334 let offset = (buf.position() - old_pos) as usize;
335 if huff {
336 let ret = {
337 let raw = &buf.chunk()[..len];
338 huffman::decode(raw, &mut self.buffer).map(|buf| StringMarker {
339 offset,
340 len,
341 string: Some(BytesMut::freeze(buf)),
342 })
343 };
344
345 buf.advance(len);
346 ret
347 } else {
348 buf.advance(len);
349 Ok(StringMarker {
350 offset,
351 len,
352 string: None,
353 })
354 }
355 }
356
357 fn decode_string(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result<Bytes, DecoderError> {
358 let old_pos = buf.position();
359 let marker = self.try_decode_string(buf)?;
360 buf.set_position(old_pos);
361 Ok(marker.consume(buf))
362 }
363}
364
365impl Default for Decoder {
366 fn default() -> Decoder {
367 Decoder::new(4096)
368 }
369}
370
371impl Representation {
374 pub fn load(byte: u8) -> Result<Representation, DecoderError> {
375 const INDEXED: u8 = 0b1000_0000;
376 const LITERAL_WITH_INDEXING: u8 = 0b0100_0000;
377 const LITERAL_WITHOUT_INDEXING: u8 = 0b1111_0000;
378 const LITERAL_NEVER_INDEXED: u8 = 0b0001_0000;
379 const SIZE_UPDATE_MASK: u8 = 0b1110_0000;
380 const SIZE_UPDATE: u8 = 0b0010_0000;
381
382 if byte & INDEXED == INDEXED {
385 Ok(Representation::Indexed)
386 } else if byte & LITERAL_WITH_INDEXING == LITERAL_WITH_INDEXING {
387 Ok(Representation::LiteralWithIndexing)
388 } else if byte & LITERAL_WITHOUT_INDEXING == 0 {
389 Ok(Representation::LiteralWithoutIndexing)
390 } else if byte & LITERAL_WITHOUT_INDEXING == LITERAL_NEVER_INDEXED {
391 Ok(Representation::LiteralNeverIndexed)
392 } else if byte & SIZE_UPDATE_MASK == SIZE_UPDATE {
393 Ok(Representation::SizeUpdate)
394 } else {
395 Err(DecoderError::InvalidRepresentation)
396 }
397 }
398}
399
400fn decode_int<B: Buf>(buf: &mut B, prefix_size: u8) -> Result<usize, DecoderError> {
401 const MAX_BYTES: usize = 5;
405 const VARINT_MASK: u8 = 0b0111_1111;
406 const VARINT_FLAG: u8 = 0b1000_0000;
407
408 if prefix_size < 1 || prefix_size > 8 {
409 return Err(DecoderError::InvalidIntegerPrefix);
410 }
411
412 if !buf.has_remaining() {
413 return Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow));
414 }
415
416 let mask = if prefix_size == 8 {
417 0xFF
418 } else {
419 (1u8 << prefix_size).wrapping_sub(1)
420 };
421
422 let mut ret = (buf.get_u8() & mask) as usize;
423
424 if ret < mask as usize {
425 return Ok(ret);
427 }
428
429 let mut bytes = 1;
434
435 let mut shift = 0;
438
439 while buf.has_remaining() {
440 let b = buf.get_u8();
441
442 bytes += 1;
443 ret += ((b & VARINT_MASK) as usize) << shift;
444 shift += 7;
445
446 if b & VARINT_FLAG == 0 {
447 return Ok(ret);
448 }
449
450 if bytes == MAX_BYTES {
451 return Err(DecoderError::IntegerOverflow);
453 }
454 }
455
456 Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow))
457}
458
459fn peek_u8<B: Buf>(buf: &B) -> Option<u8> {
460 if buf.has_remaining() {
461 Some(buf.chunk()[0])
462 } else {
463 None
464 }
465}
466
467fn take(buf: &mut Cursor<&mut BytesMut>, n: usize) -> Bytes {
468 let pos = buf.position() as usize;
469 let mut head = buf.get_mut().split_to(pos + n);
470 buf.set_position(0);
471 head.advance(pos);
472 head.freeze()
473}
474
475impl StringMarker {
476 fn consume(self, buf: &mut Cursor<&mut BytesMut>) -> Bytes {
477 buf.advance(self.offset);
478 match self.string {
479 Some(string) => {
480 buf.advance(self.len);
481 string
482 }
483 None => take(buf, self.len),
484 }
485 }
486}
487
488fn consume(buf: &mut Cursor<&mut BytesMut>) {
489 take(buf, 0);
493}
494
495impl Table {
498 fn new(max_size: usize) -> Table {
499 Table {
500 entries: VecDeque::new(),
501 size: 0,
502 max_size,
503 }
504 }
505
506 fn size(&self) -> usize {
507 self.size
508 }
509
510 pub fn get(&self, index: usize) -> Result<Header, DecoderError> {
519 if index == 0 {
520 return Err(DecoderError::InvalidTableIndex);
521 }
522
523 if index <= 61 {
524 return Ok(get_static(index));
525 }
526
527 match self.entries.get(index - 62) {
529 Some(e) => Ok(e.clone()),
530 None => Err(DecoderError::InvalidTableIndex),
531 }
532 }
533
534 fn insert(&mut self, entry: Header) {
535 let len = entry.len();
536
537 self.reserve(len);
538
539 if self.size + len <= self.max_size {
540 self.size += len;
541
542 self.entries.push_front(entry);
544 }
545 }
546
547 fn set_max_size(&mut self, size: usize) {
548 self.max_size = size;
549 self.consolidate();
551 }
552
553 fn reserve(&mut self, size: usize) {
554 while self.size + size > self.max_size {
555 match self.entries.pop_back() {
556 Some(last) => {
557 self.size -= last.len();
558 }
559 None => return,
560 }
561 }
562 }
563
564 fn consolidate(&mut self) {
565 while self.size > self.max_size {
566 {
567 let last = match self.entries.back() {
568 Some(x) => x,
569 None => {
570 panic!("Size of table != 0, but no headers left!");
573 }
574 };
575
576 self.size -= last.len();
577 }
578
579 self.entries.pop_back();
580 }
581 }
582}
583
584impl From<Utf8Error> for DecoderError {
587 fn from(_: Utf8Error) -> DecoderError {
588 DecoderError::InvalidUtf8
590 }
591}
592
593impl From<header::InvalidHeaderValue> for DecoderError {
594 fn from(_: header::InvalidHeaderValue) -> DecoderError {
595 DecoderError::InvalidUtf8
597 }
598}
599
600impl From<header::InvalidHeaderName> for DecoderError {
601 fn from(_: header::InvalidHeaderName) -> DecoderError {
602 DecoderError::InvalidUtf8
604 }
605}
606
607impl From<method::InvalidMethod> for DecoderError {
608 fn from(_: method::InvalidMethod) -> DecoderError {
609 DecoderError::InvalidUtf8
611 }
612}
613
614impl From<status::InvalidStatusCode> for DecoderError {
615 fn from(_: status::InvalidStatusCode) -> DecoderError {
616 DecoderError::InvalidUtf8
618 }
619}
620
621impl From<DecoderError> for frame::Error {
622 fn from(src: DecoderError) -> Self {
623 frame::Error::Hpack(src)
624 }
625}
626
627pub fn get_static(idx: usize) -> Header {
629 use http::header::HeaderValue;
630
631 match idx {
632 1 => Header::Authority(BytesStr::from_static("")),
633 2 => Header::Method(Method::GET),
634 3 => Header::Method(Method::POST),
635 4 => Header::Path(BytesStr::from_static("/")),
636 5 => Header::Path(BytesStr::from_static("/index.html")),
637 6 => Header::Scheme(BytesStr::from_static("http")),
638 7 => Header::Scheme(BytesStr::from_static("https")),
639 8 => Header::Status(StatusCode::OK),
640 9 => Header::Status(StatusCode::NO_CONTENT),
641 10 => Header::Status(StatusCode::PARTIAL_CONTENT),
642 11 => Header::Status(StatusCode::NOT_MODIFIED),
643 12 => Header::Status(StatusCode::BAD_REQUEST),
644 13 => Header::Status(StatusCode::NOT_FOUND),
645 14 => Header::Status(StatusCode::INTERNAL_SERVER_ERROR),
646 15 => Header::Field {
647 name: header::ACCEPT_CHARSET,
648 value: HeaderValue::from_static(""),
649 },
650 16 => Header::Field {
651 name: header::ACCEPT_ENCODING,
652 value: HeaderValue::from_static("gzip, deflate"),
653 },
654 17 => Header::Field {
655 name: header::ACCEPT_LANGUAGE,
656 value: HeaderValue::from_static(""),
657 },
658 18 => Header::Field {
659 name: header::ACCEPT_RANGES,
660 value: HeaderValue::from_static(""),
661 },
662 19 => Header::Field {
663 name: header::ACCEPT,
664 value: HeaderValue::from_static(""),
665 },
666 20 => Header::Field {
667 name: header::ACCESS_CONTROL_ALLOW_ORIGIN,
668 value: HeaderValue::from_static(""),
669 },
670 21 => Header::Field {
671 name: header::AGE,
672 value: HeaderValue::from_static(""),
673 },
674 22 => Header::Field {
675 name: header::ALLOW,
676 value: HeaderValue::from_static(""),
677 },
678 23 => Header::Field {
679 name: header::AUTHORIZATION,
680 value: HeaderValue::from_static(""),
681 },
682 24 => Header::Field {
683 name: header::CACHE_CONTROL,
684 value: HeaderValue::from_static(""),
685 },
686 25 => Header::Field {
687 name: header::CONTENT_DISPOSITION,
688 value: HeaderValue::from_static(""),
689 },
690 26 => Header::Field {
691 name: header::CONTENT_ENCODING,
692 value: HeaderValue::from_static(""),
693 },
694 27 => Header::Field {
695 name: header::CONTENT_LANGUAGE,
696 value: HeaderValue::from_static(""),
697 },
698 28 => Header::Field {
699 name: header::CONTENT_LENGTH,
700 value: HeaderValue::from_static(""),
701 },
702 29 => Header::Field {
703 name: header::CONTENT_LOCATION,
704 value: HeaderValue::from_static(""),
705 },
706 30 => Header::Field {
707 name: header::CONTENT_RANGE,
708 value: HeaderValue::from_static(""),
709 },
710 31 => Header::Field {
711 name: header::CONTENT_TYPE,
712 value: HeaderValue::from_static(""),
713 },
714 32 => Header::Field {
715 name: header::COOKIE,
716 value: HeaderValue::from_static(""),
717 },
718 33 => Header::Field {
719 name: header::DATE,
720 value: HeaderValue::from_static(""),
721 },
722 34 => Header::Field {
723 name: header::ETAG,
724 value: HeaderValue::from_static(""),
725 },
726 35 => Header::Field {
727 name: header::EXPECT,
728 value: HeaderValue::from_static(""),
729 },
730 36 => Header::Field {
731 name: header::EXPIRES,
732 value: HeaderValue::from_static(""),
733 },
734 37 => Header::Field {
735 name: header::FROM,
736 value: HeaderValue::from_static(""),
737 },
738 38 => Header::Field {
739 name: header::HOST,
740 value: HeaderValue::from_static(""),
741 },
742 39 => Header::Field {
743 name: header::IF_MATCH,
744 value: HeaderValue::from_static(""),
745 },
746 40 => Header::Field {
747 name: header::IF_MODIFIED_SINCE,
748 value: HeaderValue::from_static(""),
749 },
750 41 => Header::Field {
751 name: header::IF_NONE_MATCH,
752 value: HeaderValue::from_static(""),
753 },
754 42 => Header::Field {
755 name: header::IF_RANGE,
756 value: HeaderValue::from_static(""),
757 },
758 43 => Header::Field {
759 name: header::IF_UNMODIFIED_SINCE,
760 value: HeaderValue::from_static(""),
761 },
762 44 => Header::Field {
763 name: header::LAST_MODIFIED,
764 value: HeaderValue::from_static(""),
765 },
766 45 => Header::Field {
767 name: header::LINK,
768 value: HeaderValue::from_static(""),
769 },
770 46 => Header::Field {
771 name: header::LOCATION,
772 value: HeaderValue::from_static(""),
773 },
774 47 => Header::Field {
775 name: header::MAX_FORWARDS,
776 value: HeaderValue::from_static(""),
777 },
778 48 => Header::Field {
779 name: header::PROXY_AUTHENTICATE,
780 value: HeaderValue::from_static(""),
781 },
782 49 => Header::Field {
783 name: header::PROXY_AUTHORIZATION,
784 value: HeaderValue::from_static(""),
785 },
786 50 => Header::Field {
787 name: header::RANGE,
788 value: HeaderValue::from_static(""),
789 },
790 51 => Header::Field {
791 name: header::REFERER,
792 value: HeaderValue::from_static(""),
793 },
794 52 => Header::Field {
795 name: header::REFRESH,
796 value: HeaderValue::from_static(""),
797 },
798 53 => Header::Field {
799 name: header::RETRY_AFTER,
800 value: HeaderValue::from_static(""),
801 },
802 54 => Header::Field {
803 name: header::SERVER,
804 value: HeaderValue::from_static(""),
805 },
806 55 => Header::Field {
807 name: header::SET_COOKIE,
808 value: HeaderValue::from_static(""),
809 },
810 56 => Header::Field {
811 name: header::STRICT_TRANSPORT_SECURITY,
812 value: HeaderValue::from_static(""),
813 },
814 57 => Header::Field {
815 name: header::TRANSFER_ENCODING,
816 value: HeaderValue::from_static(""),
817 },
818 58 => Header::Field {
819 name: header::USER_AGENT,
820 value: HeaderValue::from_static(""),
821 },
822 59 => Header::Field {
823 name: header::VARY,
824 value: HeaderValue::from_static(""),
825 },
826 60 => Header::Field {
827 name: header::VIA,
828 value: HeaderValue::from_static(""),
829 },
830 61 => Header::Field {
831 name: header::WWW_AUTHENTICATE,
832 value: HeaderValue::from_static(""),
833 },
834 _ => unreachable!(),
835 }
836}
837
// Unit tests for the decoder's helpers and for decoding across partial input.
#[cfg(test)]
mod test {
    use super::*;

    // `peek_u8` must report the next byte without consuming it, and `None`
    // once the buffer is exhausted.
    #[test]
    fn test_peek_u8() {
        let b = 0xff;
        let mut buf = Cursor::new(vec![b]);
        assert_eq!(peek_u8(&buf), Some(b));
        assert_eq!(buf.get_u8(), b);
        assert_eq!(peek_u8(&buf), None);
    }

    // Reading a string from an empty buffer asks for more input.
    #[test]
    fn test_decode_string_empty() {
        let mut de = Decoder::new(0);
        let mut buf = BytesMut::new();
        let err = de.decode_string(&mut Cursor::new(&mut buf)).unwrap_err();
        assert_eq!(err, DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream));
    }

    // Decoding an empty buffer succeeds and yields no headers.
    #[test]
    fn test_decode_empty() {
        let mut de = Decoder::new(0);
        let mut buf = BytesMut::new();
        de.decode(&mut Cursor::new(&mut buf), |_| ControlFlow::Continue(()))
            .unwrap();
    }

    // A "literal with indexing" header is still delivered to the callback
    // when the table (max size 0) is too small to store it.
    #[test]
    fn test_decode_indexed_larger_than_table() {
        let mut de = Decoder::new(0);

        let mut buf = BytesMut::new();
        // Literal with indexing, name given as a string; the name is
        // Huffman-coded with a declared length of 2.
        buf.extend([0b01000000, 0x80 | 2]);
        buf.extend(huff_encode(b"foo"));
        // Huffman-coded value with a declared length of 3.
        buf.extend([0x80 | 3]);
        buf.extend(huff_encode(b"bar"));

        let mut res = vec![];
        de.decode(&mut Cursor::new(&mut buf), |h| {
            res.push(h);
            ControlFlow::Continue(())
        })
        .unwrap();

        assert_eq!(res.len(), 1);
        // Nothing was stored in the zero-sized table.
        assert_eq!(de.table.size(), 0);

        match res[0] {
            Header::Field {
                ref name,
                ref value,
            } => {
                assert_eq!(name, "foo");
                assert_eq!(value, "bar");
            }
            _ => panic!(),
        }
    }

    // Test helper: Huffman-encodes `src` into a fresh buffer.
    fn huff_encode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::encode(src, &mut buf);
        buf
    }

    // A header whose value arrives in two pieces: the first `decode` call
    // fails with `StringUnderflow`, and a second call over the same buffer,
    // once the rest has been appended, yields the complete header.
    #[test]
    fn test_decode_continuation_header_with_non_huff_encoded_name() {
        let mut de = Decoder::new(0);
        let value = huff_encode(b"bar");
        let mut buf = BytesMut::new();
        // The header name is sent as raw (non-Huffman) bytes.
        buf.extend([0b01000000, 3]);
        buf.extend(b"foo");
        // The value declares 3 Huffman-coded bytes but only the first one is
        // present so far.
        buf.extend([0x80 | 3]);
        buf.extend(&value[0..1]);

        let mut res = vec![];
        let e = de
            .decode(&mut Cursor::new(&mut buf), |h| {
                res.push(h);
                ControlFlow::Continue(())
            })
            .unwrap_err();
        // Decoding fails because the header value is incomplete.
        assert_eq!(e, DecoderError::NeedMore(NeedMore::StringUnderflow));

        // Append the rest of the value and decode again.
        buf.extend(&value[1..]);
        de.decode(&mut Cursor::new(&mut buf), |h| {
            res.push(h);
            ControlFlow::Continue(())
        })
        .unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(de.table.size(), 0);

        match res[0] {
            Header::Field {
                ref name,
                ref value,
            } => {
                assert_eq!(name, "foo");
                assert_eq!(value, "bar");
            }
            _ => panic!(),
        }
    }
}