1use std::collections::HashSet;
2use std::fmt;
3use std::io::IoSlice;
4
5use bytes::buf::{Chain, Take};
6use bytes::{Buf, Bytes};
7use http::{
8 header::{
9 AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
10 CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
11 },
12 HeaderMap, HeaderName,
13};
14
15use super::io::WriteBuf;
16use super::role::{write_headers, write_headers_title_case};
17
/// A `'static` byte slice, used for the fixed framing bytes (such as `b"\r\n"` or
/// `b"0\r\n\r\n"`) that are chained around encoded body data.
type StaticBuf = &'static [u8];
19
/// Frames an HTTP/1 message body according to the body kind chosen for it
/// (chunked, fixed length, or — server only — close-delimited).
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Encoder {
    // Framing mode, plus per-mode state (remaining length, declared trailer names).
    kind: Kind,
    // Set through `set_last` (server only); `encode_and_end` returns `!is_last` when it
    // completes the body.
    is_last: bool,
}
26
/// A piece of body output ready to be written: the caller's buffer, possibly truncated,
/// together with any framing bytes. Implements `Buf` by delegating to the active `BufKind`.
#[derive(Debug)]
pub(crate) struct EncodedBuf<B> {
    kind: BufKind<B>,
}
31
/// Error returned by `Encoder::end` when a fixed-length body is ended early; the value
/// is the number of bytes that were still expected.
#[derive(Debug)]
pub(crate) struct NotEof(u64);
34
/// How the body is delimited on the wire.
#[derive(Debug, PartialEq, Clone)]
enum Kind {
    /// Chunked transfer coding. The optional list holds the trailer field names that
    /// were declared for this message; `encode_trailers` only emits fields in this list,
    /// and emits nothing when it is `None`.
    Chunked(Option<Vec<HeaderName>>),
    /// Fixed-length body; the value is the number of bytes still allowed to be written.
    Length(u64),
    /// Close-delimited body (server only): data is passed through unframed.
    #[cfg(feature = "server")]
    CloseDelimited,
}
50
/// The concrete buffer shape behind an `EncodedBuf`.
#[derive(Debug)]
enum BufKind<B> {
    /// The caller's buffer, written unchanged.
    Exact(B),
    /// The caller's buffer truncated to the remaining fixed length.
    Limited(Take<B>),
    /// Hex chunk-size line, then the data, then CRLF (or the CRLF + last-chunk terminator
    /// when produced by `encode_and_end`).
    Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
    /// The terminating zero-length chunk, `0\r\n\r\n`.
    ChunkedEnd(StaticBuf),
    /// `0\r\n`, the serialized trailer fields, then the final CRLF.
    Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
}
59
60impl Encoder {
61 fn new(kind: Kind) -> Encoder {
62 Encoder {
63 kind,
64 is_last: false,
65 }
66 }
67 pub(crate) fn chunked() -> Encoder {
68 Encoder::new(Kind::Chunked(None))
69 }
70
71 pub(crate) fn length(len: u64) -> Encoder {
72 Encoder::new(Kind::Length(len))
73 }
74
75 #[cfg(feature = "server")]
76 pub(crate) fn close_delimited() -> Encoder {
77 Encoder::new(Kind::CloseDelimited)
78 }
79
80 pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderName>) -> Encoder {
81 match self.kind {
82 Kind::Chunked(_) => Encoder {
83 kind: Kind::Chunked(Some(trailers)),
84 is_last: self.is_last,
85 },
86 _ => self,
87 }
88 }
89
90 pub(crate) fn is_eof(&self) -> bool {
91 matches!(self.kind, Kind::Length(0))
92 }
93
94 #[cfg(feature = "server")]
95 pub(crate) fn set_last(mut self, is_last: bool) -> Self {
96 self.is_last = is_last;
97 self
98 }
99
100 pub(crate) fn is_last(&self) -> bool {
101 self.is_last
102 }
103
104 pub(crate) fn is_close_delimited(&self) -> bool {
105 match self.kind {
106 #[cfg(feature = "server")]
107 Kind::CloseDelimited => true,
108 _ => false,
109 }
110 }
111
112 pub(crate) fn is_chunked(&self) -> bool {
113 matches!(self.kind, Kind::Chunked(_))
114 }
115
116 pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
117 match self.kind {
118 Kind::Length(0) => Ok(None),
119 Kind::Chunked(_) => Ok(Some(EncodedBuf {
120 kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
121 })),
122 #[cfg(feature = "server")]
123 Kind::CloseDelimited => Ok(None),
124 Kind::Length(n) => Err(NotEof(n)),
125 }
126 }
127
128 pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
129 where
130 B: Buf,
131 {
132 let len = msg.remaining();
133 debug_assert!(len > 0, "encode() called with empty buf");
134
135 let kind = match self.kind {
136 Kind::Chunked(_) => {
137 trace!("encoding chunked {}B", len);
138 let buf = ChunkSize::new(len)
139 .chain(msg)
140 .chain(b"\r\n" as &'static [u8]);
141 BufKind::Chunked(buf)
142 }
143 Kind::Length(ref mut remaining) => {
144 trace!("sized write, len = {}", len);
145 if len as u64 > *remaining {
146 let limit = *remaining as usize;
147 *remaining = 0;
148 BufKind::Limited(msg.take(limit))
149 } else {
150 *remaining -= len as u64;
151 BufKind::Exact(msg)
152 }
153 }
154 #[cfg(feature = "server")]
155 Kind::CloseDelimited => {
156 trace!("close delimited write {}B", len);
157 BufKind::Exact(msg)
158 }
159 };
160 EncodedBuf { kind }
161 }
162
163 pub(crate) fn encode_trailers<B>(
164 &self,
165 trailers: HeaderMap,
166 title_case_headers: bool,
167 ) -> Option<EncodedBuf<B>> {
168 trace!("encoding trailers");
169 match &self.kind {
170 Kind::Chunked(Some(allowed_trailer_fields)) => {
171 let allowed_set: HashSet<&HeaderName> = allowed_trailer_fields.iter().collect();
172
173 let mut cur_name = None;
174 let mut allowed_trailers = HeaderMap::new();
175
176 for (opt_name, value) in trailers {
177 if let Some(n) = opt_name {
178 cur_name = Some(n);
179 }
180 let name = cur_name.as_ref().expect("current header name");
181
182 if allowed_set.contains(name) {
183 if is_valid_trailer_field(name) {
184 allowed_trailers.insert(name, value);
185 } else {
186 debug!("trailer field is not valid: {}", &name);
187 }
188 } else {
189 debug!("trailer header name not found in trailer header: {}", &name);
190 }
191 }
192
193 let mut buf = Vec::new();
194 if title_case_headers {
195 write_headers_title_case(&allowed_trailers, &mut buf);
196 } else {
197 write_headers(&allowed_trailers, &mut buf);
198 }
199
200 if buf.is_empty() {
201 return None;
202 }
203
204 Some(EncodedBuf {
205 kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
206 })
207 }
208 Kind::Chunked(None) => {
209 debug!("attempted to encode trailers, but the trailer header is not set");
210 None
211 }
212 _ => {
213 debug!("attempted to encode trailers for non-chunked response");
214 None
215 }
216 }
217 }
218
219 pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
220 where
221 B: Buf,
222 {
223 let len = msg.remaining();
224 debug_assert!(len > 0, "encode() called with empty buf");
225
226 match self.kind {
227 Kind::Chunked(_) => {
228 trace!("encoding chunked {}B", len);
229 let buf = ChunkSize::new(len)
230 .chain(msg)
231 .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
232 dst.buffer(buf);
233 !self.is_last
234 }
235 Kind::Length(remaining) => {
236 use std::cmp::Ordering;
237
238 trace!("sized write, len = {}", len);
239 match (len as u64).cmp(&remaining) {
240 Ordering::Equal => {
241 dst.buffer(msg);
242 !self.is_last
243 }
244 Ordering::Greater => {
245 dst.buffer(msg.take(remaining as usize));
246 !self.is_last
247 }
248 Ordering::Less => {
249 dst.buffer(msg);
250 false
251 }
252 }
253 }
254 #[cfg(feature = "server")]
255 Kind::CloseDelimited => {
256 trace!("close delimited write {}B", len);
257 dst.buffer(msg);
258 false
259 }
260 }
261 }
262}
263
264fn is_valid_trailer_field(name: &HeaderName) -> bool {
265 !matches!(
266 *name,
267 AUTHORIZATION
268 | CACHE_CONTROL
269 | CONTENT_ENCODING
270 | CONTENT_LENGTH
271 | CONTENT_RANGE
272 | CONTENT_TYPE
273 | HOST
274 | MAX_FORWARDS
275 | SET_COOKIE
276 | TRAILER
277 | TRANSFER_ENCODING
278 | TE
279 )
280}
281
282impl<B> Buf for EncodedBuf<B>
283where
284 B: Buf,
285{
286 #[inline]
287 fn remaining(&self) -> usize {
288 match self.kind {
289 BufKind::Exact(ref b) => b.remaining(),
290 BufKind::Limited(ref b) => b.remaining(),
291 BufKind::Chunked(ref b) => b.remaining(),
292 BufKind::ChunkedEnd(ref b) => b.remaining(),
293 BufKind::Trailers(ref b) => b.remaining(),
294 }
295 }
296
297 #[inline]
298 fn chunk(&self) -> &[u8] {
299 match self.kind {
300 BufKind::Exact(ref b) => b.chunk(),
301 BufKind::Limited(ref b) => b.chunk(),
302 BufKind::Chunked(ref b) => b.chunk(),
303 BufKind::ChunkedEnd(ref b) => b.chunk(),
304 BufKind::Trailers(ref b) => b.chunk(),
305 }
306 }
307
308 #[inline]
309 fn advance(&mut self, cnt: usize) {
310 match self.kind {
311 BufKind::Exact(ref mut b) => b.advance(cnt),
312 BufKind::Limited(ref mut b) => b.advance(cnt),
313 BufKind::Chunked(ref mut b) => b.advance(cnt),
314 BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
315 BufKind::Trailers(ref mut b) => b.advance(cnt),
316 }
317 }
318
319 #[inline]
320 fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
321 match self.kind {
322 BufKind::Exact(ref b) => b.chunks_vectored(dst),
323 BufKind::Limited(ref b) => b.chunks_vectored(dst),
324 BufKind::Chunked(ref b) => b.chunks_vectored(dst),
325 BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
326 BufKind::Trailers(ref b) => b.chunks_vectored(dst),
327 }
328 }
329}
330
/// Width of `usize` in bytes on the target platform.
///
/// Computed with `size_of` instead of one `cfg(target_pointer_width = ..)` constant per
/// width, so the crate also builds on pointer widths other than 32 and 64 bits.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

/// Maximum number of hexadecimal digits needed to print any `usize` (two per byte).
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
339
/// Inline buffer holding the chunk-size line (`len` in uppercase hex followed by CRLF)
/// that precedes each chunk's data; it is chained in front of the data as a `Buf`.
#[derive(Clone, Copy)]
struct ChunkSize {
    // Room for the widest hex `usize` plus the trailing `\r\n`.
    bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
    // Read cursor, advanced by `Buf::advance`.
    pos: u8,
    // Number of bytes written into `bytes`.
    len: u8,
}
346
347impl ChunkSize {
348 fn new(len: usize) -> ChunkSize {
349 use std::fmt::Write;
350 let mut size = ChunkSize {
351 bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
352 pos: 0,
353 len: 0,
354 };
355 write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
356 size
357 }
358}
359
360impl Buf for ChunkSize {
361 #[inline]
362 fn remaining(&self) -> usize {
363 (self.len - self.pos).into()
364 }
365
366 #[inline]
367 fn chunk(&self) -> &[u8] {
368 &self.bytes[self.pos.into()..self.len.into()]
369 }
370
371 #[inline]
372 fn advance(&mut self, cnt: usize) {
373 assert!(cnt <= self.remaining());
374 self.pos += cnt as u8; }
376}
377
378impl fmt::Debug for ChunkSize {
379 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380 f.debug_struct("ChunkSize")
381 .field("bytes", &&self.bytes[..self.len.into()])
382 .field("pos", &self.pos)
383 .finish()
384 }
385}
386
387impl fmt::Write for ChunkSize {
388 fn write_str(&mut self, num: &str) -> fmt::Result {
389 use std::io::Write;
390 (&mut self.bytes[self.len.into()..])
391 .write_all(num.as_bytes())
392 .expect("&mut [u8].write() cannot error");
393 self.len += num.len() as u8; Ok(())
395 }
396}
397
398impl<B: Buf> From<B> for EncodedBuf<B> {
399 fn from(buf: B) -> Self {
400 EncodedBuf {
401 kind: BufKind::Exact(buf),
402 }
403 }
404}
405
406impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
407 fn from(buf: Take<B>) -> Self {
408 EncodedBuf {
409 kind: BufKind::Limited(buf),
410 }
411 }
412}
413
414impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
415 fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
416 EncodedBuf {
417 kind: BufKind::Chunked(buf),
418 }
419 }
420}
421
422impl fmt::Display for NotEof {
423 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424 write!(f, "early end, expected {} more bytes", self.0)
425 }
426}
427
428impl std::error::Error for NotEof {}
429
// Unit tests for `Encoder`: body framing for each body kind, and trailer filtering.
#[cfg(test)]
mod tests {
    use bytes::BufMut;
    use http::{
        header::{
            AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
            CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
        },
        HeaderMap, HeaderName, HeaderValue,
    };

    use super::super::io::Cursor;
    use super::Encoder;

    // Each chunk gets an uppercase hex size line and trailing CRLF; `end` appends the
    // zero-length last chunk.
    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    // A fixed-length encoder truncates writes past its length and errors on an early `end`.
    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        // Only 1 of the 3 bytes fits in the remaining length.
        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    // Close-delimited bodies pass data through unchanged, never report EOF, and may end
    // at any point.
    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }

    // Only declared trailer names are emitted; undeclared ones are dropped.
    #[test]
    fn chunked_with_valid_trailers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderName::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            ),
            (
                HeaderName::from_static("should-not-be-included"),
                HeaderValue::from_static("oops"),
            ),
        ]);

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
    }

    // Several declared trailers are all emitted, each on its own line.
    #[test]
    fn chunked_with_multiple_trailer_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![
            HeaderName::from_static("chunky-trailer"),
            HeaderName::from_static("chunky-trailer-2"),
        ];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![
            (
                HeaderName::from_static("chunky-trailer"),
                HeaderValue::from_static("header data"),
            ),
            (
                HeaderName::from_static("chunky-trailer-2"),
                HeaderValue::from_static("more header data"),
            ),
        ]);

        let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(
            dst,
            b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n"
        );
    }

    // With no declared trailer list, or an empty one, nothing is emitted.
    #[test]
    fn chunked_with_no_trailer_header() {
        let encoder = Encoder::chunked();

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("header data"),
        )]);

        assert!(encoder
            .encode_trailers::<&[u8]>(headers.clone(), false)
            .is_none());

        let trailers = vec![];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
    }

    // Fields rejected by `is_valid_trailer_field` are dropped even when declared.
    #[test]
    fn chunked_with_invalid_trailers() {
        let encoder = Encoder::chunked();

        let trailers = vec![
            AUTHORIZATION,
            CACHE_CONTROL,
            CONTENT_ENCODING,
            CONTENT_LENGTH,
            CONTENT_RANGE,
            CONTENT_TYPE,
            HOST,
            MAX_FORWARDS,
            SET_COOKIE,
            TRAILER,
            TRANSFER_ENCODING,
            TE,
        ];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
        headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
        headers.insert(HOST, HeaderValue::from_static("header data"));
        headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
        headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
        headers.insert(TRAILER, HeaderValue::from_static("header data"));
        headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
        headers.insert(TE, HeaderValue::from_static("header data"));

        assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
    }

    // `title_case_headers = true` writes title-cased field names.
    #[test]
    fn chunked_with_title_case_headers() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderName::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("header data"),
        )]);
        let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();

        let mut dst = Vec::new();
        dst.put(buf1);
        assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
    }

    // NOTE(review): both the declared name and the trailer use the same lowercase
    // literal, so this test does not itself exercise mixed-case input.
    #[test]
    fn chunked_trailers_case_insensitive_matching() {
        let encoder = Encoder::chunked();
        let trailers = vec![HeaderName::from_static("chunky-trailer")];
        let encoder = encoder.into_chunked_with_trailing_fields(trailers);

        let headers = HeaderMap::from_iter(vec![(
            HeaderName::from_static("chunky-trailer"),
            HeaderValue::from_static("trailer value"),
        )]);

        let buf = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();
        let mut dst = Vec::new();
        dst.put(buf);
        assert_eq!(dst, b"0\r\nchunky-trailer: trailer value\r\n\r\n");
    }
}