Skip to main content

h2/frame/
headers.rs

1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{Buf, BufMut, Bytes, BytesMut};
10
11use std::fmt;
12use std::io::Cursor;
13use std::ops::ControlFlow;
14
/// Output buffer used by the frame encoders: a `BytesMut` wrapped in a
/// `Limit`, so `remaining_mut()` bounds how many bytes may still be written
/// into the frame currently being encoded.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;

/// Multiple of `max_header_list_size` beyond which `HeaderBlock::load` stops
/// decoding and reports `Error::HeaderListWayTooLarge`, rather than merely
/// marking the block as over-size.
const MAX_HEADER_LIST_ABUSE_MULTIPLIER: usize = 4;
18
/// Header frame
///
/// This could be either a request or a response. It is also used for
/// trailers (see `Headers::trailers`). Decoded header data lives in the
/// embedded `HeaderBlock`; this struct holds the frame-level metadata.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any.
    ///
    /// Only populated by `Headers::load` when the PRIORITY flag is set.
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}
36
/// Flag byte of a HEADERS frame, stored raw. The accessors inspect
/// END_STREAM, END_HEADERS, PADDED and PRIORITY.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
39
/// PUSH_PROMISE frame: announces, on `stream_id`, a server push that will be
/// delivered on `promised_id`, along with the promised request's headers.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}
54
/// Flag byte of a PUSH_PROMISE frame, stored raw. The accessors inspect only
/// END_HEADERS and PADDED.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
57
/// Outbound CONTINUATION frame carrying the not-yet-written remainder of an
/// HPACK-encoded header block.
///
/// Produced by `EncodingHeaderBlock::encode` when the previous frame could
/// not hold the whole block.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// Encoded header block bytes still waiting to be written.
    header_block: EncodingHeaderBlock,
}
65
// TODO: These fields shouldn't be `pub`
/// HTTP/2 pseudo-header fields of a request or response; each is `None`
/// when the corresponding pseudo-header is absent.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    /// `:method`
    pub method: Option<Method>,
    /// `:scheme`
    pub scheme: Option<BytesStr>,
    /// `:authority`
    pub authority: Option<BytesStr>,
    /// `:path`
    pub path: Option<BytesStr>,
    /// `:protocol` (extended CONNECT)
    pub protocol: Option<Protocol>,

    // Response
    /// `:status`
    pub status: Option<StatusCode>,
}
79
/// Iterator over a header block in encoding order: the present
/// pseudo-headers first, then the regular header fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers
    ///
    /// Set to `None` once every pseudo-header has been yielded.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
88
/// A header list (pseudo-headers plus regular fields) together with the
/// size bookkeeping used to enforce the max header list size.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Precomputed size of all of our header fields, for perf reasons
    ///
    /// Summed with `decoded_header_size` (name + value + 32 per field);
    /// pseudo-headers are not included here.
    field_size: usize,

    /// Set to true if decoding went over the max header list size.
    is_over_size: bool,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: Pseudo,
}
104
/// A header block that has already been HPACK-encoded and is waiting to be
/// written into one or more frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes not yet written to an output buffer.
    hpack: Bytes,
}
109
// Frame flag bits used by HEADERS, PUSH_PROMISE and CONTINUATION frames in
// this module.

/// END_STREAM flag bit (checked on HEADERS frames here).
const END_STREAM: u8 = 0x1;
/// END_HEADERS flag bit: the header block is complete in this frame.
const END_HEADERS: u8 = 0x4;
/// PADDED flag bit: a pad-length byte precedes the payload and padding trails it.
const PADDED: u8 = 0x8;
/// PRIORITY flag bit: a 5-byte stream dependency follows the pad-length byte.
const PRIORITY: u8 = 0x20;
/// Every flag bit understood for HEADERS frames.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
115
116// ===== impl Headers =====
117
impl Headers {
    /// Create a new HEADERS frame
    ///
    /// The frame starts with the default flags (`END_HEADERS` set), no stream
    /// dependency, and a `field_size` precomputed from `fields`.
    pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
        Headers {
            stream_id,
            stream_dep: None,
            header_block: HeaderBlock {
                field_size: calculate_headermap_size(&fields),
                fields,
                is_over_size: false,
                pseudo,
            },
            flags: HeadersFlag::default(),
        }
    }

    /// Create a trailers HEADERS frame: no pseudo-headers, with both
    /// `END_HEADERS` (the default) and `END_STREAM` set.
    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
        let mut flags = HeadersFlag::default();
        flags.set_end_stream();

        Headers {
            stream_id,
            stream_dep: None,
            header_block: HeaderBlock {
                field_size: calculate_headermap_size(&fields),
                fields,
                is_over_size: false,
                pseudo: Pseudo::default(),
            },
            flags,
        }
    }

    /// Loads the header frame but doesn't actually do HPACK decoding.
    ///
    /// HPACK decoding is done in the `load_hpack` step.
    ///
    /// Strips the optional pad-length byte, the optional 5-byte stream
    /// dependency (PRIORITY flag) and the trailing padding. Returns the frame
    /// together with the remaining header block fragment.
    ///
    /// # Errors
    ///
    /// - `InvalidStreamId` if the frame is on stream 0.
    /// - `MalformedMessage` if the payload is too short for the pad-length
    ///   byte or the stream dependency.
    /// - `InvalidDependencyId` if the stream depends on itself.
    /// - `TooMuchPadding` if the padding is longer than what remains.
    /// - Any error from `StreamDependency::load`.
    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
        // The raw flag byte is kept as-is (unknown bits are not masked off).
        let flags = HeadersFlag(head.flag());
        let mut pad = 0;

        tracing::trace!("loading headers; flags={:?}", flags);

        if head.stream_id().is_zero() {
            return Err(Error::InvalidStreamId);
        }

        // Read the padding length
        if flags.is_padded() {
            if src.is_empty() {
                return Err(Error::MalformedMessage);
            }
            pad = src[0] as usize;

            // Drop the pad-length byte
            src.advance(1);
        }

        // Read the stream dependency
        let stream_dep = if flags.is_priority() {
            if src.len() < 5 {
                return Err(Error::MalformedMessage);
            }
            let stream_dep = StreamDependency::load(&src[..5])?;

            // A stream may not depend on itself.
            if stream_dep.dependency_id() == head.stream_id() {
                return Err(Error::InvalidDependencyId);
            }

            // Drop the next 5 bytes
            src.advance(5);

            Some(stream_dep)
        } else {
            None
        };

        // Trim the trailing padding, which must fit in what is left.
        if pad > 0 {
            if pad > src.len() {
                return Err(Error::TooMuchPadding);
            }

            let len = src.len() - pad;
            src.truncate(len);
        }

        let headers = Headers {
            stream_id: head.stream_id(),
            stream_dep,
            header_block: HeaderBlock {
                fields: HeaderMap::new(),
                field_size: 0,
                is_over_size: false,
                pseudo: Pseudo::default(),
            },
            flags,
        };

        Ok((headers, src))
    }

    /// HPACK-decodes `src` into this frame's header block.
    ///
    /// Decoded headers accumulate in the block, so this can be called with
    /// successive fragments of the same header block.
    pub fn load_hpack(
        &mut self,
        src: &mut BytesMut,
        max_header_list_size: usize,
        decoder: &mut hpack::Decoder,
    ) -> Result<(), Error> {
        self.header_block.load(src, max_header_list_size, decoder)
    }

    pub fn stream_id(&self) -> StreamId {
        self.stream_id
    }

    pub fn is_end_headers(&self) -> bool {
        self.flags.is_end_headers()
    }

    pub fn set_end_headers(&mut self) {
        self.flags.set_end_headers();
    }

    pub fn is_end_stream(&self) -> bool {
        self.flags.is_end_stream()
    }

    pub fn set_end_stream(&mut self) {
        self.flags.set_end_stream()
    }

    /// Whether decoding exceeded the max header list size; in that case the
    /// excess headers were decoded but not stored.
    pub fn is_over_size(&self) -> bool {
        self.header_block.is_over_size
    }

    /// Consume `self`, returning the pseudo-headers and regular fields.
    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
        (self.header_block.pseudo, self.header_block.fields)
    }

    #[cfg(feature = "unstable")]
    pub fn pseudo_mut(&mut self) -> &mut Pseudo {
        &mut self.header_block.pseudo
    }

    pub(crate) fn pseudo(&self) -> &Pseudo {
        &self.header_block.pseudo
    }

    /// Whether it has status 1xx
    pub(crate) fn is_informational(&self) -> bool {
        self.header_block.pseudo.is_informational()
    }

    pub fn fields(&self) -> &HeaderMap {
        &self.header_block.fields
    }

    pub fn into_fields(self) -> HeaderMap {
        self.header_block.fields
    }

    /// HPACK-encodes this frame into `dst`.
    ///
    /// Returns a `Continuation` holding the rest of the header block when it
    /// did not fit in `dst`'s remaining capacity, otherwise `None`.
    pub fn encode(
        self,
        encoder: &mut hpack::Encoder,
        dst: &mut EncodeBuf<'_>,
    ) -> Option<Continuation> {
        // At this point, the `is_end_headers` flag should always be set
        debug_assert!(self.flags.is_end_headers());

        // Get the HEADERS frame head
        let head = self.head();

        // HEADERS frames write no extra payload prefix, hence the no-op closure.
        self.header_block
            .into_encoding(encoder)
            .encode(&head, dst, |_| {})
    }

    fn head(&self) -> Head {
        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
    }
}
297
298impl<T> From<Headers> for Frame<T> {
299    fn from(src: Headers) -> Self {
300        Frame::Headers(src)
301    }
302}
303
304impl fmt::Debug for Headers {
305    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
306        let mut builder = f.debug_struct("Headers");
307        builder
308            .field("stream_id", &self.stream_id)
309            .field("flags", &self.flags);
310
311        if let Some(ref protocol) = self.header_block.pseudo.protocol {
312            builder.field("protocol", protocol);
313        }
314
315        if let Some(ref dep) = self.stream_dep {
316            builder.field("stream_dep", dep);
317        }
318
319        // `fields` and `pseudo` purposefully not included
320        builder.finish()
321    }
322}
323
324// ===== util =====
325
/// Error returned by [`parse_u64`] when the input is not a valid decimal
/// `u64`.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses an unsigned decimal integer (e.g. a `content-length` value).
///
/// Accepts one or more ASCII digits and nothing else: no sign, no
/// whitespace. Every value up to and including `u64::MAX` is accepted.
///
/// # Errors
///
/// Returns [`ParseU64Error`] if `src`:
/// - is empty (`Content-Length = 1*DIGIT` needs at least one digit);
/// - contains any non-digit byte;
/// - encodes a value larger than `u64::MAX`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    // An empty value used to parse as 0, silently accepting e.g. an empty
    // `content-length` header.
    if src.is_empty() {
        return Err(ParseU64Error);
    }

    // Checked arithmetic detects overflow exactly, instead of capping the
    // input length (which rejected valid 20-digit values).
    src.iter().try_fold(0u64, |acc, &d| {
        if !d.is_ascii_digit() {
            return Err(ParseU64Error);
        }
        acc.checked_mul(10)
            .and_then(|acc| acc.checked_add(u64::from(d - b'0')))
            .ok_or(ParseU64Error)
    })
}
348
349// ===== impl PushPromise =====
350
/// Why a request is not acceptable as a promised (pushed) request; returned
/// by `PushPromise::validate_request`.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present but is not `0`. Carries the
    /// parse result: `Ok(n)` for a nonzero length, `Err` if it did not parse.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is not safe and cacheable (only GET and HEAD are).
    NotSafeAndCacheable,
}
356
357impl PushPromise {
358    pub fn new(
359        stream_id: StreamId,
360        promised_id: StreamId,
361        pseudo: Pseudo,
362        fields: HeaderMap,
363    ) -> Self {
364        PushPromise {
365            flags: PushPromiseFlag::default(),
366            header_block: HeaderBlock {
367                field_size: calculate_headermap_size(&fields),
368                fields,
369                is_over_size: false,
370                pseudo,
371            },
372            promised_id,
373            stream_id,
374        }
375    }
376
377    pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
378        use PushPromiseHeaderError::*;
379        // The spec has some requirements for promised request headers
380        // [https://httpwg.org/specs/rfc7540.html#PushRequests]
381
382        // A promised request "that indicates the presence of a request body
383        // MUST reset the promised stream with a stream error"
384        if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
385            let parsed_length = parse_u64(content_length.as_bytes());
386            if parsed_length != Ok(0) {
387                return Err(InvalidContentLength(parsed_length));
388            }
389        }
390        // "The server MUST include a method in the :method pseudo-header field
391        // that is safe and cacheable"
392        if !Self::safe_and_cacheable(req.method()) {
393            return Err(NotSafeAndCacheable);
394        }
395
396        Ok(())
397    }
398
399    fn safe_and_cacheable(method: &Method) -> bool {
400        // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
401        // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
402        method == Method::GET || method == Method::HEAD
403    }
404
405    pub fn fields(&self) -> &HeaderMap {
406        &self.header_block.fields
407    }
408
409    #[cfg(feature = "unstable")]
410    pub fn into_fields(self) -> HeaderMap {
411        self.header_block.fields
412    }
413
414    /// Loads the push promise frame but doesn't actually do HPACK decoding.
415    ///
416    /// HPACK decoding is done in the `load_hpack` step.
417    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
418        let flags = PushPromiseFlag(head.flag());
419        let mut pad = 0;
420
421        if head.stream_id().is_zero() {
422            return Err(Error::InvalidStreamId);
423        }
424
425        // Read the padding length
426        if flags.is_padded() {
427            if src.is_empty() {
428                return Err(Error::MalformedMessage);
429            }
430
431            // TODO: Ensure payload is sized correctly
432            pad = src[0] as usize;
433
434            // Drop the padding
435            src.advance(1);
436        }
437
438        if src.len() < 5 {
439            return Err(Error::MalformedMessage);
440        }
441
442        let (promised_id, _) = StreamId::parse(&src[..4]);
443        // Drop promised_id bytes
444        src.advance(4);
445
446        if pad > 0 {
447            if pad > src.len() {
448                return Err(Error::TooMuchPadding);
449            }
450
451            let len = src.len() - pad;
452            src.truncate(len);
453        }
454
455        let frame = PushPromise {
456            flags,
457            header_block: HeaderBlock {
458                fields: HeaderMap::new(),
459                field_size: 0,
460                is_over_size: false,
461                pseudo: Pseudo::default(),
462            },
463            promised_id,
464            stream_id: head.stream_id(),
465        };
466        Ok((frame, src))
467    }
468
469    pub fn load_hpack(
470        &mut self,
471        src: &mut BytesMut,
472        max_header_list_size: usize,
473        decoder: &mut hpack::Decoder,
474    ) -> Result<(), Error> {
475        self.header_block.load(src, max_header_list_size, decoder)
476    }
477
478    pub fn stream_id(&self) -> StreamId {
479        self.stream_id
480    }
481
482    pub fn promised_id(&self) -> StreamId {
483        self.promised_id
484    }
485
486    pub fn is_end_headers(&self) -> bool {
487        self.flags.is_end_headers()
488    }
489
490    pub fn set_end_headers(&mut self) {
491        self.flags.set_end_headers();
492    }
493
494    pub fn is_over_size(&self) -> bool {
495        self.header_block.is_over_size
496    }
497
498    pub fn encode(
499        self,
500        encoder: &mut hpack::Encoder,
501        dst: &mut EncodeBuf<'_>,
502    ) -> Option<Continuation> {
503        // At this point, the `is_end_headers` flag should always be set
504        debug_assert!(self.flags.is_end_headers());
505
506        let head = self.head();
507        let promised_id = self.promised_id;
508
509        self.header_block
510            .into_encoding(encoder)
511            .encode(&head, dst, |dst| {
512                dst.put_u32(promised_id.into());
513            })
514    }
515
516    fn head(&self) -> Head {
517        Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
518    }
519
520    /// Consume `self`, returning the parts of the frame
521    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
522        (self.header_block.pseudo, self.header_block.fields)
523    }
524}
525
526impl<T> From<PushPromise> for Frame<T> {
527    fn from(src: PushPromise) -> Self {
528        Frame::PushPromise(src)
529    }
530}
531
532impl fmt::Debug for PushPromise {
533    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
534        f.debug_struct("PushPromise")
535            .field("stream_id", &self.stream_id)
536            .field("promised_id", &self.promised_id)
537            .field("flags", &self.flags)
538            // `fields` and `pseudo` purposefully not included
539            .finish()
540    }
541}
542
543// ===== impl Continuation =====
544
545impl Continuation {
546    fn head(&self) -> Head {
547        Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
548    }
549
550    pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
551        // Get the CONTINUATION frame head
552        let head = self.head();
553
554        self.header_block.encode(&head, dst, |_| {})
555    }
556}
557
558// ===== impl Pseudo =====
559
560impl Pseudo {
561    pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
562        let parts = uri::Parts::from(uri);
563
564        let (scheme, path) = if method == Method::CONNECT && protocol.is_none() {
565            (None, None)
566        } else {
567            let path = parts
568                .path_and_query
569                .map(|v| BytesStr::from(v.as_str()))
570                .unwrap_or(BytesStr::from_static(""));
571
572            let path = if !path.is_empty() {
573                path
574            } else if method == Method::OPTIONS {
575                BytesStr::from_static("*")
576            } else {
577                BytesStr::from_static("/")
578            };
579
580            (parts.scheme, Some(path))
581        };
582
583        let mut pseudo = Pseudo {
584            method: Some(method),
585            scheme: None,
586            authority: None,
587            path,
588            protocol,
589            status: None,
590        };
591
592        // If the URI includes a scheme component, add it to the pseudo headers
593        if let Some(scheme) = scheme {
594            pseudo.set_scheme(scheme);
595        }
596
597        // If the URI includes an authority component, add it to the pseudo
598        // headers
599        if let Some(authority) = parts.authority {
600            pseudo.set_authority(BytesStr::from(authority.as_str()));
601        }
602
603        pseudo
604    }
605
606    pub fn response(status: StatusCode) -> Self {
607        Pseudo {
608            method: None,
609            scheme: None,
610            authority: None,
611            path: None,
612            protocol: None,
613            status: Some(status),
614        }
615    }
616
617    #[cfg(feature = "unstable")]
618    pub fn set_status(&mut self, value: StatusCode) {
619        self.status = Some(value);
620    }
621
622    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
623        let bytes_str = match scheme.as_str() {
624            "http" => BytesStr::from_static("http"),
625            "https" => BytesStr::from_static("https"),
626            s => BytesStr::from(s),
627        };
628        self.scheme = Some(bytes_str);
629    }
630
631    #[cfg(feature = "unstable")]
632    pub fn set_protocol(&mut self, protocol: Protocol) {
633        self.protocol = Some(protocol);
634    }
635
636    pub fn set_authority(&mut self, authority: BytesStr) {
637        self.authority = Some(authority);
638    }
639
640    /// Whether it has status 1xx
641    pub(crate) fn is_informational(&self) -> bool {
642        self.status
643            .map_or(false, |status| status.is_informational())
644    }
645}
646
647// ===== impl EncodingHeaderBlock =====
648
649impl EncodingHeaderBlock {
650    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
651    where
652        F: FnOnce(&mut EncodeBuf<'_>),
653    {
654        let head_pos = dst.get_ref().len();
655
656        // At this point, we don't know how big the h2 frame will be.
657        // So, we write the head with length 0, then write the body, and
658        // finally write the length once we know the size.
659        head.encode(0, dst);
660
661        let payload_pos = dst.get_ref().len();
662
663        f(dst);
664
665        // Now, encode the header payload
666        let continuation = if self.hpack.len() > dst.remaining_mut() {
667            dst.put((&mut self.hpack).take(dst.remaining_mut()));
668
669            Some(Continuation {
670                stream_id: head.stream_id(),
671                header_block: self,
672            })
673        } else {
674            dst.put_slice(&self.hpack);
675
676            None
677        };
678
679        // Compute the header block length
680        let payload_len = (dst.get_ref().len() - payload_pos) as u64;
681
682        // Write the frame length
683        let payload_len_be = payload_len.to_be_bytes();
684        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
685        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
686
687        if continuation.is_some() {
688            // There will be continuation frames, so the `is_end_headers` flag
689            // must be unset
690            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
691
692            dst.get_mut()[head_pos + 4] -= END_HEADERS;
693        }
694
695        continuation
696    }
697}
698
699// ===== impl Iter =====
700
701impl Iterator for Iter {
702    type Item = hpack::Header<Option<HeaderName>>;
703
704    fn next(&mut self) -> Option<Self::Item> {
705        use crate::hpack::Header::*;
706
707        if let Some(ref mut pseudo) = self.pseudo {
708            if let Some(method) = pseudo.method.take() {
709                return Some(Method(method));
710            }
711
712            if let Some(scheme) = pseudo.scheme.take() {
713                return Some(Scheme(scheme));
714            }
715
716            if let Some(authority) = pseudo.authority.take() {
717                return Some(Authority(authority));
718            }
719
720            if let Some(path) = pseudo.path.take() {
721                return Some(Path(path));
722            }
723
724            if let Some(protocol) = pseudo.protocol.take() {
725                return Some(Protocol(protocol));
726            }
727
728            if let Some(status) = pseudo.status.take() {
729                return Some(Status(status));
730            }
731        }
732
733        self.pseudo = None;
734
735        self.fields
736            .next()
737            .map(|(name, value)| Field { name, value })
738    }
739}
740
741// ===== impl HeadersFlag =====
742
743impl HeadersFlag {
744    pub fn empty() -> HeadersFlag {
745        HeadersFlag(0)
746    }
747
748    pub fn load(bits: u8) -> HeadersFlag {
749        HeadersFlag(bits & ALL)
750    }
751
752    pub fn is_end_stream(&self) -> bool {
753        self.0 & END_STREAM == END_STREAM
754    }
755
756    pub fn set_end_stream(&mut self) {
757        self.0 |= END_STREAM;
758    }
759
760    pub fn is_end_headers(&self) -> bool {
761        self.0 & END_HEADERS == END_HEADERS
762    }
763
764    pub fn set_end_headers(&mut self) {
765        self.0 |= END_HEADERS;
766    }
767
768    pub fn is_padded(&self) -> bool {
769        self.0 & PADDED == PADDED
770    }
771
772    pub fn is_priority(&self) -> bool {
773        self.0 & PRIORITY == PRIORITY
774    }
775}
776
777impl Default for HeadersFlag {
778    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
779    fn default() -> Self {
780        HeadersFlag(END_HEADERS)
781    }
782}
783
784impl From<HeadersFlag> for u8 {
785    fn from(src: HeadersFlag) -> u8 {
786        src.0
787    }
788}
789
790impl fmt::Debug for HeadersFlag {
791    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
792        util::debug_flags(fmt, self.0)
793            .flag_if(self.is_end_headers(), "END_HEADERS")
794            .flag_if(self.is_end_stream(), "END_STREAM")
795            .flag_if(self.is_padded(), "PADDED")
796            .flag_if(self.is_priority(), "PRIORITY")
797            .finish()
798    }
799}
800
801// ===== impl PushPromiseFlag =====
802
803impl PushPromiseFlag {
804    pub fn empty() -> PushPromiseFlag {
805        PushPromiseFlag(0)
806    }
807
808    pub fn load(bits: u8) -> PushPromiseFlag {
809        PushPromiseFlag(bits & ALL)
810    }
811
812    pub fn is_end_headers(&self) -> bool {
813        self.0 & END_HEADERS == END_HEADERS
814    }
815
816    pub fn set_end_headers(&mut self) {
817        self.0 |= END_HEADERS;
818    }
819
820    pub fn is_padded(&self) -> bool {
821        self.0 & PADDED == PADDED
822    }
823}
824
825impl Default for PushPromiseFlag {
826    /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
827    fn default() -> Self {
828        PushPromiseFlag(END_HEADERS)
829    }
830}
831
832impl From<PushPromiseFlag> for u8 {
833    fn from(src: PushPromiseFlag) -> u8 {
834        src.0
835    }
836}
837
838impl fmt::Debug for PushPromiseFlag {
839    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
840        util::debug_flags(fmt, self.0)
841            .flag_if(self.is_end_headers(), "END_HEADERS")
842            .flag_if(self.is_padded(), "PADDED")
843            .finish()
844    }
845}
846
847// ===== HeaderBlock =====
848
849impl HeaderBlock {
850    fn load(
851        &mut self,
852        src: &mut BytesMut,
853        max_header_list_size: usize,
854        decoder: &mut hpack::Decoder,
855    ) -> Result<(), Error> {
856        let mut reg = !self.fields.is_empty();
857        let mut malformed = false;
858        let mut header_list_way_too_large = false;
859        let mut headers_size = self.calculate_header_list_size();
860        let max_header_list_abuse_size =
861            max_header_list_size.saturating_mul(MAX_HEADER_LIST_ABUSE_MULTIPLIER);
862
863        macro_rules! check_size {
864            () => {{
865                if headers_size > max_header_list_abuse_size {
866                    tracing::trace!("load_hpack; header list size over abuse max");
867                    header_list_way_too_large = true;
868                    ControlFlow::Break(())
869                } else {
870                    if headers_size >= max_header_list_size && !self.is_over_size {
871                        tracing::trace!("load_hpack; header list size over max");
872                        self.is_over_size = true;
873                    }
874                    ControlFlow::Continue(())
875                }
876            }};
877        }
878
879        macro_rules! set_pseudo {
880            ($field:ident, $val:expr) => {{
881                if reg {
882                    tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
883                    malformed = true;
884                } else if self.pseudo.$field.is_some() {
885                    tracing::trace!("load_hpack; header malformed -- repeated pseudo");
886                    malformed = true;
887                } else {
888                    let __val = $val;
889                    headers_size +=
890                        decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
891                    if check_size!().is_break() {
892                        return ControlFlow::Break(());
893                    }
894                    if !self.is_over_size {
895                        self.pseudo.$field = Some(__val);
896                    }
897                }
898            }};
899        }
900
901        let mut cursor = Cursor::new(src);
902
903        // If the header frame is malformed, we still have to continue decoding
904        // the headers. A malformed header frame is a stream level error, but
905        // the hpack state is connection level. In order to maintain correct
906        // state for other streams, the hpack decoding process must complete.
907        let res = decoder.decode(&mut cursor, |header| {
908            use crate::hpack::Header::*;
909
910            match header {
911                Field { name, value } => {
912                    // Connection level header fields are not supported and must
913                    // result in a protocol error.
914
915                    if name == header::CONNECTION
916                        || name == header::TRANSFER_ENCODING
917                        || name == header::UPGRADE
918                        || name == "keep-alive"
919                        || name == "proxy-connection"
920                    {
921                        tracing::trace!("load_hpack; connection level header");
922                        malformed = true;
923                    } else if name == header::TE && value != "trailers" {
924                        tracing::trace!(
925                            "load_hpack; TE header not set to trailers; val={:?}",
926                            value
927                        );
928                        malformed = true;
929                    } else {
930                        reg = true;
931
932                        let header_size = decoded_header_size(name.as_str().len(), value.len());
933                        headers_size += header_size;
934                        if check_size!().is_break() {
935                            return ControlFlow::Break(());
936                        }
937                        if !self.is_over_size {
938                            self.field_size += header_size;
939                            if let Err(_) = self.fields.try_append(name, value) {
940                                // HeaderMap capacity exceeded — treat as over-size
941                                // so the stream is rejected downstream (RST_STREAM / 431)
942                                // instead of panicking on the 24,577th unique header.
943                                self.is_over_size = true;
944                            }
945                        }
946                    }
947                }
948                Authority(v) => set_pseudo!(authority, v),
949                Method(v) => set_pseudo!(method, v),
950                Scheme(v) => set_pseudo!(scheme, v),
951                Path(v) => set_pseudo!(path, v),
952                Protocol(v) => set_pseudo!(protocol, v),
953                Status(v) => set_pseudo!(status, v),
954            }
955
956            ControlFlow::Continue(())
957        });
958
959        match res {
960            Ok(()) => {}
961            Err(e) => {
962                tracing::trace!("hpack decoding error; err={:?}", e);
963                return Err(e.into());
964            }
965        }
966
967        if header_list_way_too_large {
968            tracing::trace!("header list way too large; aborting connection");
969            return Err(Error::HeaderListWayTooLarge);
970        }
971
972        if malformed {
973            tracing::trace!("malformed message");
974            return Err(Error::MalformedMessage);
975        }
976
977        Ok(())
978    }
979
980    fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
981        let mut hpack = BytesMut::new();
982        let headers = Iter {
983            pseudo: Some(self.pseudo),
984            fields: self.fields.into_iter(),
985        };
986
987        encoder.encode(headers, &mut hpack);
988
989        EncodingHeaderBlock {
990            hpack: hpack.freeze(),
991        }
992    }
993
994    /// Calculates the size of the currently decoded header list.
995    ///
996    /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
997    ///
998    /// > The value is based on the uncompressed size of header fields,
999    /// > including the length of the name and value in octets plus an
1000    /// > overhead of 32 octets for each header field.
1001    fn calculate_header_list_size(&self) -> usize {
1002        macro_rules! pseudo_size {
1003            ($name:ident) => {{
1004                self.pseudo
1005                    .$name
1006                    .as_ref()
1007                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
1008                    .unwrap_or(0)
1009            }};
1010        }
1011
1012        pseudo_size!(method)
1013            + pseudo_size!(scheme)
1014            + pseudo_size!(status)
1015            + pseudo_size!(authority)
1016            + pseudo_size!(path)
1017            + self.field_size
1018    }
1019}
1020
1021fn calculate_headermap_size(map: &HeaderMap) -> usize {
1022    map.iter()
1023        .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
1024        .sum::<usize>()
1025}
1026
/// Size of one header field for header-list-size accounting: name length
/// plus value length plus a fixed 32-octet per-entry overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_ENTRY_OVERHEAD: usize = 32;
    PER_ENTRY_OVERHEAD + name + value
}
1030
#[cfg(test)]
mod test {
    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Leave room for only 8 octets of payload, so the block does not fit
        // and a continuation is handed back.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first byte of the 4-byte Huffman-coded "world" fit in the
        // HEADERS frame; the other 3 bytes arrive in the CONTINUATION frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // The remaining values are literals without indexing. Their name
        // refers to dynamic table index 62 (4-bit prefix 15 + 47), which is
        // the "hello" entry inserted by the first field above.
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }

    #[test]
    fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
        // CONNECT requests MUST NOT include :scheme & :path pseudo-header fields
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com/test"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(Method::CONNECT, Uri::from_static("example.com:8443"), None),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
        // On requests that contain the :protocol pseudo-header field, the
        // :scheme and :path pseudo-header fields of the target URI (see
        // Section 5) MUST also be included.
        // See: https://datatracker.ietf.org/doc/html/rfc8441#section-4

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443/test"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/test").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("http://example.com/a/b/c"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                scheme: BytesStr::from_static("http").into(),
                path: BytesStr::from_static("/a/b/c").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
        // an OPTIONS request for an "http" or "https" URI that does not include a path component;
        // these MUST include a ":path" pseudo-header field with a value of '*' (see Section 7.1 of [HTTP]).
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1
        assert_eq!(
            Pseudo::request(Method::OPTIONS, Uri::from_static("example.com:8080"), None),
            Pseudo {
                method: Method::OPTIONS.into(),
                authority: BytesStr::from_static("example.com:8080").into(),
                path: BytesStr::from_static("*").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_try_append_prevents_panic_on_max_size_reached() {
        // Verify that decoding more than 24,576 unique headers sets `is_over_size`
        // instead of panicking via HeaderMap::append().
        //
        // HeaderMap::MAX_SIZE = 32,768. With 75% load factor, max entries = 24,576.
        // try_append returns Err(MaxSizeReached) at entry 24,577.
        // Before the fix (using append), this panicked.
        //
        // We manually construct HPACK bytes for 25,000 unique headers because
        // creating a HeaderMap with that many entries also panics on construction.

        // Build HPACK-encoded block:
        // Pseudo-headers (indexed refs to static table):
        //   :method GET         → 0x82 (static index 2)
        //   :scheme http        → 0x86 (static index 6)
        //   :path /             → 0x84 (static index 4)
        //   :authority "localhost" → literal with incremental indexing (name index 1)
        //
        // Then 25,000 unique headers: "literal without indexing, new name"
        //   0x00 → literal without indexing, name index 0
        //   <name_len> <name_bytes>
        //   <value_len> <value_bytes>

        let num_headers = 25_000;

        // Build the HPACK block
        let mut hpack = Vec::new();

        // Pseudo-headers
        hpack.push(0x82u8); // :method GET (static index 2)
        hpack.push(0x86); // :scheme http (static index 6)
        hpack.push(0x84); // :path / (static index 4)

        // :authority "localhost" — literal with incremental indexing
        hpack.push(0x41); // literal with indexing, name index 1 (= ":authority")
        hpack.push(0x09); // value length 9
        hpack.extend_from_slice(b"localhost");

        // 25,000 unique headers: "literal without indexing, new name"
        // Format: 0x00 + name_len + name + value_len + value
        for i in 0..num_headers {
            let name = format!("x-h-{i}");
            hpack.push(0x00u8); // literal without indexing, name index 0
            hpack.push(name.len() as u8);
            hpack.extend_from_slice(name.as_bytes());
            hpack.push(1u8); // value length 1
            hpack.push(b'v');
        }

        // Build the HTTP/2 HEADERS frame: 9-byte header + HPACK payload
        let payload_len = hpack.len();
        let mut frame = BytesMut::with_capacity(9 + payload_len);

        // Frame header: 3 bytes length, 1 byte type (0x01=HEADERS), 1 byte flags, 4 bytes stream_id
        frame.put_u8(((payload_len >> 16) & 0xFF) as u8);
        frame.put_u8(((payload_len >> 8) & 0xFF) as u8);
        frame.put_u8((payload_len & 0xFF) as u8);
        frame.put_u8(0x01); // type: HEADERS
        frame.put_u8(0x04); // flags: END_HEADERS
        frame.put_u32(1); // stream_id: 1

        frame.extend_from_slice(&hpack);

        // Parse the HEADERS frame
        let head = Head::parse(&frame[..9]);
        let payload = BytesMut::from(&frame[9..]);
        let (mut headers, mut hpack_data) = Headers::load(head, payload).unwrap();
        // hpack_data contains the HPACK payload (no padding/priority in our frame)

        // Decode the HPACK block — this should NOT panic
        let mut decoder = hpack::Decoder::new(4096);
        const DEFAULT_MAX_HEADER_LIST_SIZE: usize = 16 << 20; // 16 MB
        headers
            .load_hpack(&mut hpack_data, DEFAULT_MAX_HEADER_LIST_SIZE, &mut decoder)
            .expect("load_hpack should return Ok");

        // Verify that is_over_size was set (try_append returned Err)
        assert!(
            headers.is_over_size(),
            "is_over_size should be true when HeaderMap capacity is exceeded"
        );
    }

    #[test]
    fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
        let methods = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::PATCH,
            Method::TRACE,
        ];

        for method in methods {
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("http://example.com:8080"),
                    None,
                ),
                Pseudo {
                    method: method.clone().into(),
                    authority: BytesStr::from_static("example.com:8080").into(),
                    scheme: BytesStr::from_static("http").into(),
                    path: BytesStr::from_static("/").into(),
                    ..Default::default()
                }
            );
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("https://example.com/a/b/c"),
                    None,
                ),
                Pseudo {
                    method: method.into(),
                    authority: BytesStr::from_static("example.com").into(),
                    scheme: BytesStr::from_static("https").into(),
                    path: BytesStr::from_static("/a/b/c").into(),
                    ..Default::default()
                }
            );
        }
    }

    #[test]
    fn test_calculate_headermap_size_counts_every_value() {
        // Each value of a repeated header contributes its own
        // name + value + 32 octets to the header list size.
        let map = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("hello"),
                HeaderValue::from_static("world"),
            ),
            (
                HeaderName::from_static("hello"),
                HeaderValue::from_static("zomg"),
            ),
        ]);
        assert_eq!((5 + 5 + 32) + (5 + 4 + 32), calculate_headermap_size(&map));

        // An empty map has no fields and therefore no size.
        assert_eq!(0, calculate_headermap_size(&HeaderMap::new()));
    }
}