1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{Buf, BufMut, Bytes, BytesMut};
10
11use std::fmt;
12use std::io::Cursor;
13use std::ops::ControlFlow;
14
15type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
16
17const MAX_HEADER_LIST_ABUSE_MULTIPLIER: usize = 4;
18
/// A HEADERS frame.
///
/// Holds the stream the frame belongs to, an optional stream dependency
/// (read from the frame when the PRIORITY flag is set), the header block and
/// the frame flags. The same type is used for request headers, response
/// headers and trailers (see `Headers::new` and `Headers::trailers`).
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if the PRIORITY flag was set on load.
    stream_dep: Option<StreamDependency>,

    /// The header block: pseudo-headers and regular header fields.
    header_block: HeaderBlock,

    /// The frame flags.
    flags: HeadersFlag,
}
36
/// Flag bits of a HEADERS frame (`END_STREAM`, `END_HEADERS`, `PADDED`, `PRIORITY`).
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
39
/// A PUSH_PROMISE frame.
///
/// Carries the stream the promise is sent on, the ID of the promised stream,
/// the header block describing the promised request, and the frame flags.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being promised.
    promised_id: StreamId,

    /// The header block for the promised request.
    header_block: HeaderBlock,

    /// The frame flags.
    flags: PushPromiseFlag,
}
54
/// Flag bits of a PUSH_PROMISE frame; only `END_HEADERS` and `PADDED` are queried.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
57
/// The not-yet-written remainder of an encoded header block, to be emitted
/// as one or more CONTINUATION frames.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of the frame whose header block this continues.
    stream_id: StreamId,

    /// HPACK bytes that did not fit into the previous frame.
    header_block: EncodingHeaderBlock,
}
65
/// HTTP/2 pseudo-header fields (`:method`, `:scheme`, `:authority`, `:path`,
/// `:protocol`, `:status`). Each is `None` when absent.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request pseudo-headers
    pub method: Option<Method>,
    pub scheme: Option<BytesStr>,
    pub authority: Option<BytesStr>,
    pub path: Option<BytesStr>,
    /// `:protocol`, used by extended CONNECT requests.
    pub protocol: Option<Protocol>,

    // Response pseudo-headers
    pub status: Option<StatusCode>,
}
79
/// Iterator over a header block in encoding order: pseudo-headers first,
/// then the regular fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo-headers not yet yielded; set to `None` once all are exhausted.
    pseudo: Option<Pseudo>,

    /// The regular header fields.
    fields: header::IntoIter<HeaderValue>,
}
88
/// A decoded (or to-be-encoded) header block.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The regular (non-pseudo) header fields.
    fields: HeaderMap,

    /// Sum of `decoded_header_size` over the stored `fields`. It is read by
    /// `calculate_header_list_size` when `load` is called again for a later fragment.
    field_size: usize,

    /// Set once the header list reached the configured maximum (or the map
    /// could not grow). After that, further headers are decoded but not stored.
    is_over_size: bool,

    /// The pseudo-header fields.
    pseudo: Pseudo,
}
104
/// A header block that has already been HPACK-encoded and is waiting to be
/// written out, possibly across several frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// HPACK-encoded bytes that have not been written yet.
    hpack: Bytes,
}
109
// Frame flag bits used by the HEADERS, PUSH_PROMISE and CONTINUATION frames in
// this module (values from RFC 7540 §6.2).
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
const PRIORITY: u8 = 0x20;
/// Mask of every flag bit understood here; `load` on the flag types drops other bits.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
115
116impl Headers {
119 pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
121 Headers {
122 stream_id,
123 stream_dep: None,
124 header_block: HeaderBlock {
125 field_size: calculate_headermap_size(&fields),
126 fields,
127 is_over_size: false,
128 pseudo,
129 },
130 flags: HeadersFlag::default(),
131 }
132 }
133
134 pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
135 let mut flags = HeadersFlag::default();
136 flags.set_end_stream();
137
138 Headers {
139 stream_id,
140 stream_dep: None,
141 header_block: HeaderBlock {
142 field_size: calculate_headermap_size(&fields),
143 fields,
144 is_over_size: false,
145 pseudo: Pseudo::default(),
146 },
147 flags,
148 }
149 }
150
151 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
155 let flags = HeadersFlag(head.flag());
156 let mut pad = 0;
157
158 tracing::trace!("loading headers; flags={:?}", flags);
159
160 if head.stream_id().is_zero() {
161 return Err(Error::InvalidStreamId);
162 }
163
164 if flags.is_padded() {
166 if src.is_empty() {
167 return Err(Error::MalformedMessage);
168 }
169 pad = src[0] as usize;
170
171 src.advance(1);
173 }
174
175 let stream_dep = if flags.is_priority() {
177 if src.len() < 5 {
178 return Err(Error::MalformedMessage);
179 }
180 let stream_dep = StreamDependency::load(&src[..5])?;
181
182 if stream_dep.dependency_id() == head.stream_id() {
183 return Err(Error::InvalidDependencyId);
184 }
185
186 src.advance(5);
188
189 Some(stream_dep)
190 } else {
191 None
192 };
193
194 if pad > 0 {
195 if pad > src.len() {
196 return Err(Error::TooMuchPadding);
197 }
198
199 let len = src.len() - pad;
200 src.truncate(len);
201 }
202
203 let headers = Headers {
204 stream_id: head.stream_id(),
205 stream_dep,
206 header_block: HeaderBlock {
207 fields: HeaderMap::new(),
208 field_size: 0,
209 is_over_size: false,
210 pseudo: Pseudo::default(),
211 },
212 flags,
213 };
214
215 Ok((headers, src))
216 }
217
218 pub fn load_hpack(
219 &mut self,
220 src: &mut BytesMut,
221 max_header_list_size: usize,
222 decoder: &mut hpack::Decoder,
223 ) -> Result<(), Error> {
224 self.header_block.load(src, max_header_list_size, decoder)
225 }
226
227 pub fn stream_id(&self) -> StreamId {
228 self.stream_id
229 }
230
231 pub fn is_end_headers(&self) -> bool {
232 self.flags.is_end_headers()
233 }
234
235 pub fn set_end_headers(&mut self) {
236 self.flags.set_end_headers();
237 }
238
239 pub fn is_end_stream(&self) -> bool {
240 self.flags.is_end_stream()
241 }
242
243 pub fn set_end_stream(&mut self) {
244 self.flags.set_end_stream()
245 }
246
247 pub fn is_over_size(&self) -> bool {
248 self.header_block.is_over_size
249 }
250
251 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
252 (self.header_block.pseudo, self.header_block.fields)
253 }
254
255 #[cfg(feature = "unstable")]
256 pub fn pseudo_mut(&mut self) -> &mut Pseudo {
257 &mut self.header_block.pseudo
258 }
259
260 pub(crate) fn pseudo(&self) -> &Pseudo {
261 &self.header_block.pseudo
262 }
263
264 pub(crate) fn is_informational(&self) -> bool {
266 self.header_block.pseudo.is_informational()
267 }
268
269 pub fn fields(&self) -> &HeaderMap {
270 &self.header_block.fields
271 }
272
273 pub fn into_fields(self) -> HeaderMap {
274 self.header_block.fields
275 }
276
277 pub fn encode(
278 self,
279 encoder: &mut hpack::Encoder,
280 dst: &mut EncodeBuf<'_>,
281 ) -> Option<Continuation> {
282 debug_assert!(self.flags.is_end_headers());
284
285 let head = self.head();
287
288 self.header_block
289 .into_encoding(encoder)
290 .encode(&head, dst, |_| {})
291 }
292
293 fn head(&self) -> Head {
294 Head::new(Kind::Headers, self.flags.into(), self.stream_id)
295 }
296}
297
298impl<T> From<Headers> for Frame<T> {
299 fn from(src: Headers) -> Self {
300 Frame::Headers(src)
301 }
302}
303
304impl fmt::Debug for Headers {
305 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
306 let mut builder = f.debug_struct("Headers");
307 builder
308 .field("stream_id", &self.stream_id)
309 .field("flags", &self.flags);
310
311 if let Some(ref protocol) = self.header_block.pseudo.protocol {
312 builder.field("protocol", protocol);
313 }
314
315 if let Some(ref dep) = self.stream_dep {
316 builder.field("stream_dep", dep);
317 }
318
319 builder.finish()
321 }
322}
323
/// Error returned by [`parse_u64`] when the input is not an acceptable
/// decimal `u64` (see that function for the exact rules).
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parse a byte string of ASCII decimal digits into a `u64`.
///
/// Inputs longer than 19 digits are rejected outright, which guarantees the
/// accumulation cannot overflow: every 19-digit number fits in a `u64`.
/// Any non-digit byte, including signs and whitespace, is an error. An empty
/// input parses as `Ok(0)`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    const MAX_DIGITS: usize = 19;

    if src.len() > MAX_DIGITS {
        return Err(ParseU64Error);
    }

    src.iter().try_fold(0u64, |acc, &byte| {
        if byte.is_ascii_digit() {
            Ok(acc * 10 + u64::from(byte - b'0'))
        } else {
            Err(ParseU64Error)
        }
    })
}
348
/// Reasons a request cannot be used as the promised request of a
/// PUSH_PROMISE (see `PushPromise::validate_request`).
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present and does not parse as `0`;
    /// carries the result of parsing it.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is neither GET nor HEAD.
    NotSafeAndCacheable,
}
356
357impl PushPromise {
358 pub fn new(
359 stream_id: StreamId,
360 promised_id: StreamId,
361 pseudo: Pseudo,
362 fields: HeaderMap,
363 ) -> Self {
364 PushPromise {
365 flags: PushPromiseFlag::default(),
366 header_block: HeaderBlock {
367 field_size: calculate_headermap_size(&fields),
368 fields,
369 is_over_size: false,
370 pseudo,
371 },
372 promised_id,
373 stream_id,
374 }
375 }
376
377 pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
378 use PushPromiseHeaderError::*;
379 if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
385 let parsed_length = parse_u64(content_length.as_bytes());
386 if parsed_length != Ok(0) {
387 return Err(InvalidContentLength(parsed_length));
388 }
389 }
390 if !Self::safe_and_cacheable(req.method()) {
393 return Err(NotSafeAndCacheable);
394 }
395
396 Ok(())
397 }
398
399 fn safe_and_cacheable(method: &Method) -> bool {
400 method == Method::GET || method == Method::HEAD
403 }
404
405 pub fn fields(&self) -> &HeaderMap {
406 &self.header_block.fields
407 }
408
409 #[cfg(feature = "unstable")]
410 pub fn into_fields(self) -> HeaderMap {
411 self.header_block.fields
412 }
413
414 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
418 let flags = PushPromiseFlag(head.flag());
419 let mut pad = 0;
420
421 if head.stream_id().is_zero() {
422 return Err(Error::InvalidStreamId);
423 }
424
425 if flags.is_padded() {
427 if src.is_empty() {
428 return Err(Error::MalformedMessage);
429 }
430
431 pad = src[0] as usize;
433
434 src.advance(1);
436 }
437
438 if src.len() < 5 {
439 return Err(Error::MalformedMessage);
440 }
441
442 let (promised_id, _) = StreamId::parse(&src[..4]);
443 src.advance(4);
445
446 if pad > 0 {
447 if pad > src.len() {
448 return Err(Error::TooMuchPadding);
449 }
450
451 let len = src.len() - pad;
452 src.truncate(len);
453 }
454
455 let frame = PushPromise {
456 flags,
457 header_block: HeaderBlock {
458 fields: HeaderMap::new(),
459 field_size: 0,
460 is_over_size: false,
461 pseudo: Pseudo::default(),
462 },
463 promised_id,
464 stream_id: head.stream_id(),
465 };
466 Ok((frame, src))
467 }
468
469 pub fn load_hpack(
470 &mut self,
471 src: &mut BytesMut,
472 max_header_list_size: usize,
473 decoder: &mut hpack::Decoder,
474 ) -> Result<(), Error> {
475 self.header_block.load(src, max_header_list_size, decoder)
476 }
477
478 pub fn stream_id(&self) -> StreamId {
479 self.stream_id
480 }
481
482 pub fn promised_id(&self) -> StreamId {
483 self.promised_id
484 }
485
486 pub fn is_end_headers(&self) -> bool {
487 self.flags.is_end_headers()
488 }
489
490 pub fn set_end_headers(&mut self) {
491 self.flags.set_end_headers();
492 }
493
494 pub fn is_over_size(&self) -> bool {
495 self.header_block.is_over_size
496 }
497
498 pub fn encode(
499 self,
500 encoder: &mut hpack::Encoder,
501 dst: &mut EncodeBuf<'_>,
502 ) -> Option<Continuation> {
503 debug_assert!(self.flags.is_end_headers());
505
506 let head = self.head();
507 let promised_id = self.promised_id;
508
509 self.header_block
510 .into_encoding(encoder)
511 .encode(&head, dst, |dst| {
512 dst.put_u32(promised_id.into());
513 })
514 }
515
516 fn head(&self) -> Head {
517 Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
518 }
519
520 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
522 (self.header_block.pseudo, self.header_block.fields)
523 }
524}
525
526impl<T> From<PushPromise> for Frame<T> {
527 fn from(src: PushPromise) -> Self {
528 Frame::PushPromise(src)
529 }
530}
531
532impl fmt::Debug for PushPromise {
533 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
534 f.debug_struct("PushPromise")
535 .field("stream_id", &self.stream_id)
536 .field("promised_id", &self.promised_id)
537 .field("flags", &self.flags)
538 .finish()
540 }
541}
542
543impl Continuation {
546 fn head(&self) -> Head {
547 Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
548 }
549
550 pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
551 let head = self.head();
553
554 self.header_block.encode(&head, dst, |_| {})
555 }
556}
557
558impl Pseudo {
561 pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
562 let parts = uri::Parts::from(uri);
563
564 let (scheme, path) = if method == Method::CONNECT && protocol.is_none() {
565 (None, None)
566 } else {
567 let path = parts
568 .path_and_query
569 .map(|v| BytesStr::from(v.as_str()))
570 .unwrap_or(BytesStr::from_static(""));
571
572 let path = if !path.is_empty() {
573 path
574 } else if method == Method::OPTIONS {
575 BytesStr::from_static("*")
576 } else {
577 BytesStr::from_static("/")
578 };
579
580 (parts.scheme, Some(path))
581 };
582
583 let mut pseudo = Pseudo {
584 method: Some(method),
585 scheme: None,
586 authority: None,
587 path,
588 protocol,
589 status: None,
590 };
591
592 if let Some(scheme) = scheme {
594 pseudo.set_scheme(scheme);
595 }
596
597 if let Some(authority) = parts.authority {
600 pseudo.set_authority(BytesStr::from(authority.as_str()));
601 }
602
603 pseudo
604 }
605
606 pub fn response(status: StatusCode) -> Self {
607 Pseudo {
608 method: None,
609 scheme: None,
610 authority: None,
611 path: None,
612 protocol: None,
613 status: Some(status),
614 }
615 }
616
617 #[cfg(feature = "unstable")]
618 pub fn set_status(&mut self, value: StatusCode) {
619 self.status = Some(value);
620 }
621
622 pub fn set_scheme(&mut self, scheme: uri::Scheme) {
623 let bytes_str = match scheme.as_str() {
624 "http" => BytesStr::from_static("http"),
625 "https" => BytesStr::from_static("https"),
626 s => BytesStr::from(s),
627 };
628 self.scheme = Some(bytes_str);
629 }
630
631 #[cfg(feature = "unstable")]
632 pub fn set_protocol(&mut self, protocol: Protocol) {
633 self.protocol = Some(protocol);
634 }
635
636 pub fn set_authority(&mut self, authority: BytesStr) {
637 self.authority = Some(authority);
638 }
639
640 pub(crate) fn is_informational(&self) -> bool {
642 self.status
643 .map_or(false, |status| status.is_informational())
644 }
645}
646
impl EncodingHeaderBlock {
    /// Write one frame into `dst`, in three parts:
    /// 1. the frame head (`head`);
    /// 2. whatever `f` writes, e.g. the Promised Stream ID;
    /// 3. as much of the HPACK block as `dst`'s remaining limit allows.
    ///
    /// Returns `None` if the whole block fit. Otherwise the written frame has
    /// its END_HEADERS flag cleared, and the unwritten remainder is returned as
    /// a `Continuation` for the same stream.
    ///
    /// # Panics
    ///
    /// Panics if the payload length does not fit in the 3-byte frame length field.
    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
    where
        F: FnOnce(&mut EncodeBuf<'_>),
    {
        let head_pos = dst.get_ref().len();

        // The payload length is not known yet, so the head is written with a
        // length of 0. The real length is patched in once the payload is written.
        head.encode(0, dst);

        let payload_pos = dst.get_ref().len();

        f(dst);

        // Write as much of the HPACK block as fits. `take` over `&mut self.hpack`
        // consumes the written bytes, leaving the remainder for the continuation.
        let continuation = if self.hpack.len() > dst.remaining_mut() {
            dst.put((&mut self.hpack).take(dst.remaining_mut()));

            Some(Continuation {
                stream_id: head.stream_id(),
                header_block: self,
            })
        } else {
            dst.put_slice(&self.hpack);

            None
        };

        let payload_len = (dst.get_ref().len() - payload_pos) as u64;

        // Patch the 24-bit big-endian length into the first 3 bytes of the head.
        // The upper 5 bytes of the u64 must be zero.
        let payload_len_be = payload_len.to_be_bytes();
        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);

        if continuation.is_some() {
            // More frames follow, so clear END_HEADERS in the flags byte
            // (offset 4 of the head).
            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);

            dst.get_mut()[head_pos + 4] -= END_HEADERS;
        }

        continuation
    }
}
698
699impl Iterator for Iter {
702 type Item = hpack::Header<Option<HeaderName>>;
703
704 fn next(&mut self) -> Option<Self::Item> {
705 use crate::hpack::Header::*;
706
707 if let Some(ref mut pseudo) = self.pseudo {
708 if let Some(method) = pseudo.method.take() {
709 return Some(Method(method));
710 }
711
712 if let Some(scheme) = pseudo.scheme.take() {
713 return Some(Scheme(scheme));
714 }
715
716 if let Some(authority) = pseudo.authority.take() {
717 return Some(Authority(authority));
718 }
719
720 if let Some(path) = pseudo.path.take() {
721 return Some(Path(path));
722 }
723
724 if let Some(protocol) = pseudo.protocol.take() {
725 return Some(Protocol(protocol));
726 }
727
728 if let Some(status) = pseudo.status.take() {
729 return Some(Status(status));
730 }
731 }
732
733 self.pseudo = None;
734
735 self.fields
736 .next()
737 .map(|(name, value)| Field { name, value })
738 }
739}
740
741impl HeadersFlag {
744 pub fn empty() -> HeadersFlag {
745 HeadersFlag(0)
746 }
747
748 pub fn load(bits: u8) -> HeadersFlag {
749 HeadersFlag(bits & ALL)
750 }
751
752 pub fn is_end_stream(&self) -> bool {
753 self.0 & END_STREAM == END_STREAM
754 }
755
756 pub fn set_end_stream(&mut self) {
757 self.0 |= END_STREAM;
758 }
759
760 pub fn is_end_headers(&self) -> bool {
761 self.0 & END_HEADERS == END_HEADERS
762 }
763
764 pub fn set_end_headers(&mut self) {
765 self.0 |= END_HEADERS;
766 }
767
768 pub fn is_padded(&self) -> bool {
769 self.0 & PADDED == PADDED
770 }
771
772 pub fn is_priority(&self) -> bool {
773 self.0 & PRIORITY == PRIORITY
774 }
775}
776
777impl Default for HeadersFlag {
778 fn default() -> Self {
780 HeadersFlag(END_HEADERS)
781 }
782}
783
784impl From<HeadersFlag> for u8 {
785 fn from(src: HeadersFlag) -> u8 {
786 src.0
787 }
788}
789
790impl fmt::Debug for HeadersFlag {
791 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
792 util::debug_flags(fmt, self.0)
793 .flag_if(self.is_end_headers(), "END_HEADERS")
794 .flag_if(self.is_end_stream(), "END_STREAM")
795 .flag_if(self.is_padded(), "PADDED")
796 .flag_if(self.is_priority(), "PRIORITY")
797 .finish()
798 }
799}
800
801impl PushPromiseFlag {
804 pub fn empty() -> PushPromiseFlag {
805 PushPromiseFlag(0)
806 }
807
808 pub fn load(bits: u8) -> PushPromiseFlag {
809 PushPromiseFlag(bits & ALL)
810 }
811
812 pub fn is_end_headers(&self) -> bool {
813 self.0 & END_HEADERS == END_HEADERS
814 }
815
816 pub fn set_end_headers(&mut self) {
817 self.0 |= END_HEADERS;
818 }
819
820 pub fn is_padded(&self) -> bool {
821 self.0 & PADDED == PADDED
822 }
823}
824
825impl Default for PushPromiseFlag {
826 fn default() -> Self {
828 PushPromiseFlag(END_HEADERS)
829 }
830}
831
832impl From<PushPromiseFlag> for u8 {
833 fn from(src: PushPromiseFlag) -> u8 {
834 src.0
835 }
836}
837
838impl fmt::Debug for PushPromiseFlag {
839 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
840 util::debug_flags(fmt, self.0)
841 .flag_if(self.is_end_headers(), "END_HEADERS")
842 .flag_if(self.is_padded(), "PADDED")
843 .finish()
844 }
845}
846
847impl HeaderBlock {
850 fn load(
851 &mut self,
852 src: &mut BytesMut,
853 max_header_list_size: usize,
854 decoder: &mut hpack::Decoder,
855 ) -> Result<(), Error> {
856 let mut reg = !self.fields.is_empty();
857 let mut malformed = false;
858 let mut header_list_way_too_large = false;
859 let mut headers_size = self.calculate_header_list_size();
860 let max_header_list_abuse_size =
861 max_header_list_size.saturating_mul(MAX_HEADER_LIST_ABUSE_MULTIPLIER);
862
863 macro_rules! check_size {
864 () => {{
865 if headers_size > max_header_list_abuse_size {
866 tracing::trace!("load_hpack; header list size over abuse max");
867 header_list_way_too_large = true;
868 ControlFlow::Break(())
869 } else {
870 if headers_size >= max_header_list_size && !self.is_over_size {
871 tracing::trace!("load_hpack; header list size over max");
872 self.is_over_size = true;
873 }
874 ControlFlow::Continue(())
875 }
876 }};
877 }
878
879 macro_rules! set_pseudo {
880 ($field:ident, $val:expr) => {{
881 if reg {
882 tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
883 malformed = true;
884 } else if self.pseudo.$field.is_some() {
885 tracing::trace!("load_hpack; header malformed -- repeated pseudo");
886 malformed = true;
887 } else {
888 let __val = $val;
889 headers_size +=
890 decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
891 if check_size!().is_break() {
892 return ControlFlow::Break(());
893 }
894 if !self.is_over_size {
895 self.pseudo.$field = Some(__val);
896 }
897 }
898 }};
899 }
900
901 let mut cursor = Cursor::new(src);
902
903 let res = decoder.decode(&mut cursor, |header| {
908 use crate::hpack::Header::*;
909
910 match header {
911 Field { name, value } => {
912 if name == header::CONNECTION
916 || name == header::TRANSFER_ENCODING
917 || name == header::UPGRADE
918 || name == "keep-alive"
919 || name == "proxy-connection"
920 {
921 tracing::trace!("load_hpack; connection level header");
922 malformed = true;
923 } else if name == header::TE && value != "trailers" {
924 tracing::trace!(
925 "load_hpack; TE header not set to trailers; val={:?}",
926 value
927 );
928 malformed = true;
929 } else {
930 reg = true;
931
932 let header_size = decoded_header_size(name.as_str().len(), value.len());
933 headers_size += header_size;
934 if check_size!().is_break() {
935 return ControlFlow::Break(());
936 }
937 if !self.is_over_size {
938 self.field_size += header_size;
939 if let Err(_) = self.fields.try_append(name, value) {
940 self.is_over_size = true;
944 }
945 }
946 }
947 }
948 Authority(v) => set_pseudo!(authority, v),
949 Method(v) => set_pseudo!(method, v),
950 Scheme(v) => set_pseudo!(scheme, v),
951 Path(v) => set_pseudo!(path, v),
952 Protocol(v) => set_pseudo!(protocol, v),
953 Status(v) => set_pseudo!(status, v),
954 }
955
956 ControlFlow::Continue(())
957 });
958
959 match res {
960 Ok(()) => {}
961 Err(e) => {
962 tracing::trace!("hpack decoding error; err={:?}", e);
963 return Err(e.into());
964 }
965 }
966
967 if header_list_way_too_large {
968 tracing::trace!("header list way too large; aborting connection");
969 return Err(Error::HeaderListWayTooLarge);
970 }
971
972 if malformed {
973 tracing::trace!("malformed message");
974 return Err(Error::MalformedMessage);
975 }
976
977 Ok(())
978 }
979
980 fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
981 let mut hpack = BytesMut::new();
982 let headers = Iter {
983 pseudo: Some(self.pseudo),
984 fields: self.fields.into_iter(),
985 };
986
987 encoder.encode(headers, &mut hpack);
988
989 EncodingHeaderBlock {
990 hpack: hpack.freeze(),
991 }
992 }
993
994 fn calculate_header_list_size(&self) -> usize {
1002 macro_rules! pseudo_size {
1003 ($name:ident) => {{
1004 self.pseudo
1005 .$name
1006 .as_ref()
1007 .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
1008 .unwrap_or(0)
1009 }};
1010 }
1011
1012 pseudo_size!(method)
1013 + pseudo_size!(scheme)
1014 + pseudo_size!(status)
1015 + pseudo_size!(authority)
1016 + pseudo_size!(path)
1017 + self.field_size
1018 }
1019}
1020
1021fn calculate_headermap_size(map: &HeaderMap) -> usize {
1022 map.iter()
1023 .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
1024 .sum::<usize>()
1025}
1026
/// Size of one header field as counted toward the header list size limit.
///
/// The size is the name and value lengths in octets plus a fixed 32-octet
/// overhead per entry, which is the entry size defined by RFC 7541 §4.1.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const ENTRY_OVERHEAD: usize = 32;
    ENTRY_OVERHEAD + name + value
}
1030
#[cfg(test)]
mod test {
    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three values of one header name into a buffer limited to
    /// `HEADER_LEN + 8` bytes, forcing the block to be split. It then checks
    /// the bytes of the HEADERS frame and of the following CONTINUATION frame.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // HEADERS frame: length 8, type 1, flags 0 (END_HEADERS cleared because
        // a continuation follows), stream 0.
        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first byte of the 4-byte Huffman-coded "world" fit.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // CONTINUATION frame: length 15, type 9, flags END_HEADERS (4).
        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // The remaining values refer to the name by table index (15 + 47 = 62).
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src` into a new buffer, panicking on invalid input.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }

    /// A plain CONNECT (no `:protocol`) carries only `:method` and `:authority`.
    #[test]
    fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com/test"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(Method::CONNECT, Uri::from_static("example.com:8443"), None),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );
    }

    /// An extended CONNECT (with `:protocol`) includes `:scheme` and `:path`;
    /// an empty path becomes `/`.
    #[test]
    fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443/test"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/test").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("http://example.com/a/b/c"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                scheme: BytesStr::from_static("http").into(),
                path: BytesStr::from_static("/a/b/c").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );
    }

    /// An OPTIONS request whose URI has no path gets `*` as `:path`.
    #[test]
    fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
        assert_eq!(
            Pseudo::request(Method::OPTIONS, Uri::from_static("example.com:8080"), None,),
            Pseudo {
                method: Method::OPTIONS.into(),
                authority: BytesStr::from_static("example.com:8080").into(),
                path: BytesStr::from_static("*").into(),
                ..Default::default()
            }
        );
    }

    /// Decoding a block with more fields than a `HeaderMap` can hold must not
    /// panic; instead the block is flagged as over size.
    #[test]
    fn test_try_append_prevents_panic_on_max_size_reached() {
        let num_headers = 25_000;

        let mut hpack = Vec::new();

        // Pseudo-headers via the HPACK static table (RFC 7541 Appendix A):
        // - 0x82 = `:method: GET`
        // - 0x86 = `:scheme: http`
        // - 0x84 = `:path: /`
        // - 0x41 = literal with incremental indexing, name index 1 (`:authority`),
        //   followed by value length 9 and "localhost".
        hpack.push(0x82u8);
        hpack.push(0x86);
        hpack.push(0x84);
        hpack.push(0x41);
        hpack.push(0x09);
        hpack.extend_from_slice(b"localhost");

        for i in 0..num_headers {
            // Literal field without indexing, new name: `x-h-{i}: v`.
            let name = format!("x-h-{i}");
            hpack.push(0x00u8);
            hpack.push(name.len() as u8);
            hpack.extend_from_slice(name.as_bytes());
            hpack.push(1u8);
            hpack.push(b'v');
        }

        let payload_len = hpack.len();
        let mut frame = BytesMut::with_capacity(9 + payload_len);

        // Frame head: 24-bit length, type 0x01 (HEADERS), flags 0x04
        // (END_HEADERS), stream ID 1.
        frame.put_u8(((payload_len >> 16) & 0xFF) as u8);
        frame.put_u8(((payload_len >> 8) & 0xFF) as u8);
        frame.put_u8((payload_len & 0xFF) as u8);
        frame.put_u8(0x01);
        frame.put_u8(0x04);
        frame.put_u32(1);
        frame.extend_from_slice(&hpack);

        let head = Head::parse(&frame[..9]);
        let payload = BytesMut::from(&frame[9..]);
        let (mut headers, mut hpack_data) = Headers::load(head, payload).unwrap();
        let mut decoder = hpack::Decoder::new(4096);
        // 16 MiB: well above the roughly 1 MB this header list adds up to, so
        // the size limit itself is not what trips.
        const DEFAULT_MAX_HEADER_LIST_SIZE: usize = 16 << 20;
        headers
            .load_hpack(&mut hpack_data, DEFAULT_MAX_HEADER_LIST_SIZE, &mut decoder)
            .expect("load_hpack should return Ok");

        assert!(
            headers.is_over_size(),
            "is_over_size should be true when HeaderMap capacity is exceeded"
        );
    }

    /// Methods other than OPTIONS and CONNECT always get `:scheme` and
    /// `:path`; an empty path becomes `/`.
    #[test]
    fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
        let methods = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::PATCH,
            Method::TRACE,
        ];

        for method in methods {
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("http://example.com:8080"),
                    None,
                ),
                Pseudo {
                    method: method.clone().into(),
                    authority: BytesStr::from_static("example.com:8080").into(),
                    scheme: BytesStr::from_static("http").into(),
                    path: BytesStr::from_static("/").into(),
                    ..Default::default()
                }
            );
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("https://example.com/a/b/c"),
                    None,
                ),
                Pseudo {
                    method: method.into(),
                    authority: BytesStr::from_static("example.com").into(),
                    scheme: BytesStr::from_static("https").into(),
                    path: BytesStr::from_static("/a/b/c").into(),
                    ..Default::default()
                }
            );
        }
    }
}