hyper/proto/h1/encode.rs

1use std::collections::HashSet;
2use std::fmt;
3use std::io::IoSlice;
4
5use bytes::buf::{Chain, Take};
6use bytes::{Buf, Bytes};
7use http::{
8    header::{
9        AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
10        CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
11    },
12    HeaderMap, HeaderName,
13};
14
15use super::io::WriteBuf;
16use super::role::{write_headers, write_headers_title_case};
17
/// A byte slice borrowed for the whole program, used for the fixed framing
/// bytes (such as `\r\n` and `0\r\n\r\n`) chained around body chunks.
type StaticBuf = &'static [u8];
19
20/// Encoders to handle different Transfer-Encodings.
21#[derive(Debug, Clone, PartialEq)]
22pub(crate) struct Encoder {
23    kind: Kind,
24    is_last: bool,
25}
26
27#[derive(Debug)]
28pub(crate) struct EncodedBuf<B> {
29    kind: BufKind<B>,
30}
31
/// Error produced by `Encoder::end` when a `Content-Length` body is finished
/// early. The wrapped value is the number of bytes still owed.
#[derive(Debug)]
pub(crate) struct NotEof(u64);
34
35#[derive(Debug, PartialEq, Clone)]
36enum Kind {
37    /// An Encoder for when Transfer-Encoding includes `chunked`.
38    Chunked(Option<Vec<HeaderName>>),
39    /// An Encoder for when Content-Length is set.
40    ///
41    /// Enforces that the body is not longer than the Content-Length header.
42    Length(u64),
43    /// An Encoder for when neither Content-Length nor Chunked encoding is set.
44    ///
45    /// This is mostly only used with HTTP/1.0 with a length. This kind requires
46    /// the connection to be closed when the body is finished.
47    #[cfg(feature = "server")]
48    CloseDelimited,
49}
50
51#[derive(Debug)]
52enum BufKind<B> {
53    Exact(B),
54    Limited(Take<B>),
55    Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
56    ChunkedEnd(StaticBuf),
57    Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
58}
59
60impl Encoder {
61    fn new(kind: Kind) -> Encoder {
62        Encoder {
63            kind,
64            is_last: false,
65        }
66    }
67    pub(crate) fn chunked() -> Encoder {
68        Encoder::new(Kind::Chunked(None))
69    }
70
71    pub(crate) fn length(len: u64) -> Encoder {
72        Encoder::new(Kind::Length(len))
73    }
74
75    #[cfg(feature = "server")]
76    pub(crate) fn close_delimited() -> Encoder {
77        Encoder::new(Kind::CloseDelimited)
78    }
79
80    pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderName>) -> Encoder {
81        match self.kind {
82            Kind::Chunked(_) => Encoder {
83                kind: Kind::Chunked(Some(trailers)),
84                is_last: self.is_last,
85            },
86            _ => self,
87        }
88    }
89
90    pub(crate) fn is_eof(&self) -> bool {
91        matches!(self.kind, Kind::Length(0))
92    }
93
94    #[cfg(feature = "server")]
95    pub(crate) fn set_last(mut self, is_last: bool) -> Self {
96        self.is_last = is_last;
97        self
98    }
99
100    pub(crate) fn is_last(&self) -> bool {
101        self.is_last
102    }
103
104    pub(crate) fn is_close_delimited(&self) -> bool {
105        match self.kind {
106            #[cfg(feature = "server")]
107            Kind::CloseDelimited => true,
108            _ => false,
109        }
110    }
111
112    pub(crate) fn is_chunked(&self) -> bool {
113        matches!(self.kind, Kind::Chunked(_))
114    }
115
116    pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
117        match self.kind {
118            Kind::Length(0) => Ok(None),
119            Kind::Chunked(_) => Ok(Some(EncodedBuf {
120                kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
121            })),
122            #[cfg(feature = "server")]
123            Kind::CloseDelimited => Ok(None),
124            Kind::Length(n) => Err(NotEof(n)),
125        }
126    }
127
128    pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
129    where
130        B: Buf,
131    {
132        let len = msg.remaining();
133        debug_assert!(len > 0, "encode() called with empty buf");
134
135        let kind = match self.kind {
136            Kind::Chunked(_) => {
137                trace!("encoding chunked {}B", len);
138                let buf = ChunkSize::new(len)
139                    .chain(msg)
140                    .chain(b"\r\n" as &'static [u8]);
141                BufKind::Chunked(buf)
142            }
143            Kind::Length(ref mut remaining) => {
144                trace!("sized write, len = {}", len);
145                if len as u64 > *remaining {
146                    let limit = *remaining as usize;
147                    *remaining = 0;
148                    BufKind::Limited(msg.take(limit))
149                } else {
150                    *remaining -= len as u64;
151                    BufKind::Exact(msg)
152                }
153            }
154            #[cfg(feature = "server")]
155            Kind::CloseDelimited => {
156                trace!("close delimited write {}B", len);
157                BufKind::Exact(msg)
158            }
159        };
160        EncodedBuf { kind }
161    }
162
163    pub(crate) fn encode_trailers<B>(
164        &self,
165        trailers: HeaderMap,
166        title_case_headers: bool,
167    ) -> Option<EncodedBuf<B>> {
168        trace!("encoding trailers");
169        match &self.kind {
170            Kind::Chunked(Some(allowed_trailer_fields)) => {
171                let allowed_set: HashSet<&HeaderName> = allowed_trailer_fields.iter().collect();
172
173                let mut cur_name = None;
174                let mut allowed_trailers = HeaderMap::new();
175
176                for (opt_name, value) in trailers {
177                    if let Some(n) = opt_name {
178                        cur_name = Some(n);
179                    }
180                    let name = cur_name.as_ref().expect("current header name");
181
182                    if allowed_set.contains(name) {
183                        if is_valid_trailer_field(name) {
184                            allowed_trailers.insert(name, value);
185                        } else {
186                            debug!("trailer field is not valid: {}", &name);
187                        }
188                    } else {
189                        debug!("trailer header name not found in trailer header: {}", &name);
190                    }
191                }
192
193                let mut buf = Vec::new();
194                if title_case_headers {
195                    write_headers_title_case(&allowed_trailers, &mut buf);
196                } else {
197                    write_headers(&allowed_trailers, &mut buf);
198                }
199
200                if buf.is_empty() {
201                    return None;
202                }
203
204                Some(EncodedBuf {
205                    kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
206                })
207            }
208            Kind::Chunked(None) => {
209                debug!("attempted to encode trailers, but the trailer header is not set");
210                None
211            }
212            _ => {
213                debug!("attempted to encode trailers for non-chunked response");
214                None
215            }
216        }
217    }
218
219    pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
220    where
221        B: Buf,
222    {
223        let len = msg.remaining();
224        debug_assert!(len > 0, "encode() called with empty buf");
225
226        match self.kind {
227            Kind::Chunked(_) => {
228                trace!("encoding chunked {}B", len);
229                let buf = ChunkSize::new(len)
230                    .chain(msg)
231                    .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
232                dst.buffer(buf);
233                !self.is_last
234            }
235            Kind::Length(remaining) => {
236                use std::cmp::Ordering;
237
238                trace!("sized write, len = {}", len);
239                match (len as u64).cmp(&remaining) {
240                    Ordering::Equal => {
241                        dst.buffer(msg);
242                        !self.is_last
243                    }
244                    Ordering::Greater => {
245                        dst.buffer(msg.take(remaining as usize));
246                        !self.is_last
247                    }
248                    Ordering::Less => {
249                        dst.buffer(msg);
250                        false
251                    }
252                }
253            }
254            #[cfg(feature = "server")]
255            Kind::CloseDelimited => {
256                trace!("close delimited write {}B", len);
257                dst.buffer(msg);
258                false
259            }
260        }
261    }
262}
263
264fn is_valid_trailer_field(name: &HeaderName) -> bool {
265    !matches!(
266        *name,
267        AUTHORIZATION
268            | CACHE_CONTROL
269            | CONTENT_ENCODING
270            | CONTENT_LENGTH
271            | CONTENT_RANGE
272            | CONTENT_TYPE
273            | HOST
274            | MAX_FORWARDS
275            | SET_COOKIE
276            | TRAILER
277            | TRANSFER_ENCODING
278            | TE
279    )
280}
281
282impl<B> Buf for EncodedBuf<B>
283where
284    B: Buf,
285{
286    #[inline]
287    fn remaining(&self) -> usize {
288        match self.kind {
289            BufKind::Exact(ref b) => b.remaining(),
290            BufKind::Limited(ref b) => b.remaining(),
291            BufKind::Chunked(ref b) => b.remaining(),
292            BufKind::ChunkedEnd(ref b) => b.remaining(),
293            BufKind::Trailers(ref b) => b.remaining(),
294        }
295    }
296
297    #[inline]
298    fn chunk(&self) -> &[u8] {
299        match self.kind {
300            BufKind::Exact(ref b) => b.chunk(),
301            BufKind::Limited(ref b) => b.chunk(),
302            BufKind::Chunked(ref b) => b.chunk(),
303            BufKind::ChunkedEnd(ref b) => b.chunk(),
304            BufKind::Trailers(ref b) => b.chunk(),
305        }
306    }
307
308    #[inline]
309    fn advance(&mut self, cnt: usize) {
310        match self.kind {
311            BufKind::Exact(ref mut b) => b.advance(cnt),
312            BufKind::Limited(ref mut b) => b.advance(cnt),
313            BufKind::Chunked(ref mut b) => b.advance(cnt),
314            BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
315            BufKind::Trailers(ref mut b) => b.advance(cnt),
316        }
317    }
318
319    #[inline]
320    fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
321        match self.kind {
322            BufKind::Exact(ref b) => b.chunks_vectored(dst),
323            BufKind::Limited(ref b) => b.chunks_vectored(dst),
324            BufKind::Chunked(ref b) => b.chunks_vectored(dst),
325            BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
326            BufKind::Trailers(ref b) => b.chunks_vectored(dst),
327        }
328    }
329}
330
/// Width of `usize` in bytes on the compilation target.
///
/// Derived from `size_of` so the chunk-size buffer is sized correctly for
/// every pointer width, not only 32- and 64-bit targets.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

// Each byte becomes two hex digits, so this is the length of the longest
// `{:X}` rendering of any `usize`.
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
339
340#[derive(Clone, Copy)]
341struct ChunkSize {
342    bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
343    pos: u8,
344    len: u8,
345}
346
347impl ChunkSize {
348    fn new(len: usize) -> ChunkSize {
349        use std::fmt::Write;
350        let mut size = ChunkSize {
351            bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
352            pos: 0,
353            len: 0,
354        };
355        write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
356        size
357    }
358}
359
360impl Buf for ChunkSize {
361    #[inline]
362    fn remaining(&self) -> usize {
363        (self.len - self.pos).into()
364    }
365
366    #[inline]
367    fn chunk(&self) -> &[u8] {
368        &self.bytes[self.pos.into()..self.len.into()]
369    }
370
371    #[inline]
372    fn advance(&mut self, cnt: usize) {
373        assert!(cnt <= self.remaining());
374        self.pos += cnt as u8; // just asserted cnt fits in u8
375    }
376}
377
378impl fmt::Debug for ChunkSize {
379    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380        f.debug_struct("ChunkSize")
381            .field("bytes", &&self.bytes[..self.len.into()])
382            .field("pos", &self.pos)
383            .finish()
384    }
385}
386
387impl fmt::Write for ChunkSize {
388    fn write_str(&mut self, num: &str) -> fmt::Result {
389        use std::io::Write;
390        (&mut self.bytes[self.len.into()..])
391            .write_all(num.as_bytes())
392            .expect("&mut [u8].write() cannot error");
393        self.len += num.len() as u8; // safe because bytes is never bigger than 256
394        Ok(())
395    }
396}
397
398impl<B: Buf> From<B> for EncodedBuf<B> {
399    fn from(buf: B) -> Self {
400        EncodedBuf {
401            kind: BufKind::Exact(buf),
402        }
403    }
404}
405
406impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
407    fn from(buf: Take<B>) -> Self {
408        EncodedBuf {
409            kind: BufKind::Limited(buf),
410        }
411    }
412}
413
414impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
415    fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
416        EncodedBuf {
417            kind: BufKind::Chunked(buf),
418        }
419    }
420}
421
422impl fmt::Display for NotEof {
423    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424        write!(f, "early end, expected {} more bytes", self.0)
425    }
426}
427
428impl std::error::Error for NotEof {}
429
#[cfg(test)]
mod tests {
    //! Unit tests for body framing: chunked, Content-Length, close-delimited,
    //! and trailer encoding.

    use bytes::BufMut;
    use http::{
        header::{
            AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
            CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
        },
        HeaderMap, HeaderName, HeaderValue,
    };

    use super::super::io::Cursor;
    use super::{ChunkSize, Encoder};

    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    #[test]
    fn chunk_size_fits_usize_max() {
        // The largest possible chunk-size line must fit the inline buffer.
        let mut dst = Vec::new();
        dst.put(ChunkSize::new(usize::MAX));
        assert_eq!(dst, format!("{:X}\r\n", usize::MAX).into_bytes());
    }

    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    #[test]
    fn length_not_eof_reports_remaining() {
        let mut encoder = Encoder::length(10);
        let mut dst = Vec::new();
        dst.put(encoder.encode(b"abc".as_ref()));

        let err = encoder.end::<()>().unwrap_err();
        assert_eq!(err.to_string(), "early end, expected 7 more bytes");
    }

    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }

    #[test]
    fn chunked_with_valid_trailers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderName::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            ),
            (
                HeaderName::from_static("should-not-be-included"),
                HeaderValue::from_static("oops"),
            ),
        ]);

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
    }

    #[test]
    fn chunked_with_multiple_trailer_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![
            HeaderName::from_static("chunky-trailer"),
            HeaderName::from_static("chunky-trailer-2"),
        ];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            ),
            (
                HeaderName::from_static("chunky-trailer-2"),
                HeaderValue::from_static("more header data"),
            ),
        ]);

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(
            dst,
            b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n"
        );
    }

    #[test]
    fn chunked_with_no_trailer_header() {
        let encoder = Encoder::chunked();

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("header data"),
        )]);

        assert!(encoder
            .encode_trailers::<&[u8]>(headers.clone(), false)
            .is_none());

        let trailers = vec![];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
    }

    #[test]
    fn chunked_with_invalid_trailers() {
        let encoder = Encoder::chunked();

        let trailers = vec![
            AUTHORIZATION,
            CACHE_CONTROL,
            CONTENT_ENCODING,
            CONTENT_LENGTH,
            CONTENT_RANGE,
            CONTENT_TYPE,
            HOST,
            MAX_FORWARDS,
            SET_COOKIE,
            TRAILER,
            TRANSFER_ENCODING,
            TE,
        ];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
        headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
        headers.insert(HOST, HeaderValue::from_static("header data"));
        headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
        headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
        headers.insert(TRAILER, HeaderValue::from_static("header data"));
        headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(TE, HeaderValue::from_static("header data"));

        assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
    }

    #[test]
    fn chunked_with_title_case_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderName::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("header data"),
        )]);
        let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
    }

    #[test]
    fn chunked_trailers_case_insensitive_matching() {
        // Regression test for issue #4010: HTTP/1.1 trailers are case-sensitive
        //
        // Previously, the Trailer header values were stored as HeaderValue (preserving case)
        // and compared against HeaderName (which is always lowercase). This caused trailers
        // declared as "Chunky-Trailer" to not match actual trailers sent as "chunky-trailer".
        //
        // The fix converts Trailer header values to HeaderName during parsing, which
        // normalizes the case and enables proper case-insensitive matching.
        //
        // Note: HeaderName::from_static() requires lowercase input. In real usage,
        // HeaderName::from_bytes() is used to parse the Trailer header value, which
        // normalizes mixed-case input like "Chunky-Trailer" to "chunky-trailer".
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderName::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        // The actual trailer being sent
        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("trailer value"),
        )]);

        let buf = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();
        let mut dst = Vec::new();
        dst.put(buf);
        assert_eq!(dst, b"0\r\nchunky-trailer: trailer value\r\n\r\n");
    }
}