Skip to main content

h2/codec/
framed_read.rs

1use crate::frame::{self, Frame, Kind, Reason};
2use crate::frame::{
3    DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
4};
5use crate::proto::Error;
6
7use crate::hpack;
8
9use futures_core::Stream;
10
11use bytes::{Buf, BytesMut};
12
13use std::io;
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::io::AsyncRead;
18use tokio_util::codec::FramedRead as InnerFramedRead;
19use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
20
/// Default upper bound, in bytes, on a decoded header list: 16 MiB.
///
/// This "sane default" is taken from golang http2.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
23
/// Reads HTTP/2 frames from a length-delimited byte stream.
///
/// Wraps a `tokio_util` `FramedRead` driven by a `LengthDelimitedCodec` and
/// turns each raw chunk it yields into a [`Frame`], reassembling header blocks
/// that are split across CONTINUATION frames.
#[derive(Debug)]
pub struct FramedRead<T> {
    /// Underlying reader that splits the byte stream into raw frame chunks.
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    // hpack decoder state
    hpack: hpack::Decoder,

    /// Header list size limit passed to HPACK decoding of header blocks.
    max_header_list_size: usize,

    /// Maximum number of non-final CONTINUATION frames accepted for a single
    /// header block. Derived from `max_header_list_size` and the max frame
    /// size, and recomputed whenever either of them changes.
    max_continuation_frames: usize,

    /// Header block currently being reassembled, if any.
    partial: Option<Partial>,
}
37
/// Partially loaded headers frame
///
/// Holds a HEADERS or PUSH_PROMISE frame that arrived without END_HEADERS
/// while the CONTINUATION frames that complete it are still being received.
#[derive(Debug)]
struct Partial {
    /// The HEADERS/PUSH_PROMISE frame that opened the header block. Its raw
    /// payload is not stored here; header fields decoded so far are loaded
    /// into it incrementally via `load_hpack`.
    frame: Continuable,

    /// Partial header payload: HPACK bytes left in the buffer after the last
    /// `load_hpack` call, to be decoded once more CONTINUATION data arrives.
    buf: BytesMut,

    /// Number of non-final CONTINUATION frames received so far for this
    /// block. It is compared against `max_continuation_frames` to detect a
    /// CONTINUATION flood.
    continuation_frames_count: usize,
}
49
/// A frame whose header block may be continued by CONTINUATION frames.
#[derive(Debug)]
enum Continuable {
    /// A HEADERS frame.
    Headers(frame::Headers),
    /// A PUSH_PROMISE frame.
    PushPromise(frame::PushPromise),
}
55
56impl<T> FramedRead<T> {
57    pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
58        let max_header_list_size = DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE;
59        let max_continuation_frames =
60            calc_max_continuation_frames(max_header_list_size, inner.decoder().max_frame_length());
61        FramedRead {
62            inner,
63            hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
64            max_header_list_size,
65            max_continuation_frames,
66            partial: None,
67        }
68    }
69
70    pub fn get_ref(&self) -> &T {
71        self.inner.get_ref()
72    }
73
74    pub fn get_mut(&mut self) -> &mut T {
75        self.inner.get_mut()
76    }
77
78    /// Returns the current max frame size setting
79    #[inline]
80    pub fn max_frame_size(&self) -> usize {
81        self.inner.decoder().max_frame_length()
82    }
83
84    /// Updates the max frame size setting.
85    ///
86    /// Must be within 16,384 and 16,777,215.
87    #[inline]
88    pub fn set_max_frame_size(&mut self, val: usize) {
89        assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
90        self.inner.decoder_mut().set_max_frame_length(val);
91        // Update max CONTINUATION frames too, since its based on this
92        self.max_continuation_frames = calc_max_continuation_frames(self.max_header_list_size, val);
93    }
94
95    /// Update the max header list size setting.
96    #[inline]
97    pub fn set_max_header_list_size(&mut self, val: usize) {
98        self.max_header_list_size = val;
99        // Update max CONTINUATION frames too, since its based on this
100        self.max_continuation_frames = calc_max_continuation_frames(val, self.max_frame_size());
101    }
102
103    /// Update the header table size setting.
104    #[inline]
105    pub fn set_header_table_size(&mut self, val: usize) {
106        self.hpack.queue_size_update(val);
107    }
108}
109
/// Computes how many CONTINUATION frames a single header block may use.
///
/// The budget has three parts:
/// - the number of `frame_max`-sized frames needed to carry a
///   `header_max`-byte header list (at least 1);
/// - plus 25% padding for imperfectly packed frames;
/// - never less than `MIN_CONTINUATION_FRAMES`.
///
/// The addition saturates instead of overflowing. A `frame_max` of zero
/// yields the minimum budget instead of panicking on division by zero.
fn calc_max_continuation_frames(header_max: usize, frame_max: usize) -> usize {
    /// Floor on the budget, so small header list limits still allow a few frames.
    const MIN_CONTINUATION_FRAMES: usize = 5;

    // At least this many frames needed to use max header list size.
    // `checked_div` guards against a zero frame size, which would otherwise panic.
    let min_frames_for_list = header_max.checked_div(frame_max).unwrap_or(0).max(1);
    // Some padding for imperfectly packed frames: 25% without floats.
    let padding = min_frames_for_list >> 2;
    min_frames_for_list
        .saturating_add(padding)
        .max(MIN_CONTINUATION_FRAMES)
}
118
/// Decodes a frame.
///
/// This method is intentionally de-generified and outlined because it is very large.
///
/// `bytes` is one raw chunk yielded by the length-delimited reader. It still
/// starts with the frame header: the header is parsed with
/// `frame::Head::parse` and skipped via `frame::HEADER_LEN`.
///
/// A HEADERS or PUSH_PROMISE frame without END_HEADERS is not returned
/// immediately. It is stashed in `partial_inout` and completed by the
/// following CONTINUATION frames. While such a header block is in progress,
/// any frame other than CONTINUATION is a connection error.
///
/// # Returns
///
/// * `Ok(Some(frame))` — a complete frame is ready.
/// * `Ok(None)` — nothing to yield yet. Either a header block is still
///   incomplete, or the frame kind is unknown and was ignored.
///
/// # Errors
///
/// * `Error::library_go_away(Reason::PROTOCOL_ERROR)` for malformed frames
///   and connection-level protocol violations.
/// * `Error::library_reset(stream_id, Reason::PROTOCOL_ERROR)` for
///   stream-level errors: invalid dependency IDs and malformed header blocks.
/// * `Error::library_go_away_data(Reason::ENHANCE_YOUR_CALM, ..)` when the
///   peer sends more than `max_continuation_frames` CONTINUATION frames, or a
///   header list over the abuse limit.
/// * `Error::library_go_away(Reason::COMPRESSION_ERROR)` when an already
///   over-size header block keeps growing past `max_header_list_size`.
fn decode_frame(
    hpack: &mut hpack::Decoder,
    max_header_list_size: usize,
    max_continuation_frames: usize,
    partial_inout: &mut Option<Partial>,
    mut bytes: BytesMut,
) -> Result<Option<Frame>, Error> {
    let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
    let _e = span.enter();

    tracing::trace!("decoding frame from {}B", bytes.len());

    // Parse the head
    let head = frame::Head::parse(&bytes);

    // A header block must be contiguous: once a HEADERS/PUSH_PROMISE without
    // END_HEADERS has been seen, only CONTINUATION frames may follow.
    if partial_inout.is_some() && head.kind() != Kind::Continuation {
        proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
        return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
    }

    let kind = head.kind();

    tracing::trace!(frame.kind = ?kind);

    // Shared handling for HEADERS and PUSH_PROMISE frames:
    // - parse the frame;
    // - decode as much of the HPACK block as is available;
    // - then either evaluate to the finished frame (END_HEADERS set), or stash
    //   it in `partial_inout` and return `Ok(None)` from `decode_frame`.
    macro_rules! header_block {
        ($frame:ident, $head:ident, $bytes:ident) => ({
            // Drop the frame header
            $bytes.advance(frame::HEADER_LEN);

            // Parse the header frame w/o parsing the payload
            let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
                Ok(res) => res,
                Err(frame::Error::InvalidDependencyId) => {
                    proto_err!(stream: "invalid HEADERS dependency ID");
                    // A stream cannot depend on itself. An endpoint MUST
                    // treat this as a stream error (Section 5.4.2) of type
                    // `PROTOCOL_ERROR`.
                    return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR));
                },
                Err(e) => {
                    proto_err!(conn: "failed to load frame; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            let is_end_headers = frame.is_end_headers();

            // Load the HPACK encoded headers
            match frame.load_hpack(&mut payload, max_header_list_size, hpack) {
                Ok(_) => {},
                // Running out of input is expected when the block continues
                // in CONTINUATION frames; the leftover bytes stay in `payload`.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
                Err(frame::Error::MalformedMessage) => {
                    let id = $head.stream_id();
                    proto_err!(stream: "malformed header block; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                },
                Err(frame::Error::HeaderListWayTooLarge) => {
                    proto_err!(conn: "decoded header list size over abuse limit");
                    return Err(Error::library_go_away_data(
                        Reason::ENHANCE_YOUR_CALM,
                        "header_list_way_too_large",
                    ));
                },
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                frame.into()
            } else {
                tracing::trace!("loaded partial header block");
                // Defer returning the frame
                *partial_inout = Some(Partial {
                    frame: Continuable::$frame(frame),
                    buf: payload,
                    continuation_frames_count: 0,
                });

                return Ok(None);
            }
        });
    }

    let frame = match kind {
        Kind::Settings => {
            let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Ping => {
            let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load PING frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::WindowUpdate => {
            let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Data => {
            // DATA takes ownership of the payload (frozen into immutable bytes),
            // so the header is advanced past rather than sliced off.
            bytes.advance(frame::HEADER_LEN);
            let res = frame::Data::load(head, bytes.freeze());

            // TODO: Should this always be connection level? Probably not...
            res.map_err(|e| {
                proto_err!(conn: "failed to load DATA frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Headers => header_block!(Headers, head, bytes),
        Kind::Reset => {
            let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load RESET frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::GoAway => {
            let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::PushPromise => header_block!(PushPromise, head, bytes),
        Kind::Priority => {
            if head.stream_id() == 0 {
                // Invalid stream identifier
                proto_err!(conn: "invalid stream ID 0");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
                Ok(frame) => frame.into(),
                Err(frame::Error::InvalidDependencyId) => {
                    // A stream cannot depend on itself. An endpoint MUST
                    // treat this as a stream error (Section 5.4.2) of type
                    // `PROTOCOL_ERROR`.
                    let id = head.stream_id();
                    proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(e) => {
                    proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }
        }
        Kind::Continuation => {
            // 0x4 is the END_HEADERS flag bit.
            let is_end_headers = (head.flag() & 0x4) == 0x4;

            let mut partial = match partial_inout.take() {
                Some(partial) => partial,
                None => {
                    proto_err!(conn: "received unexpected CONTINUATION frame");
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            // The stream identifiers must match
            if partial.frame.stream_id() != head.stream_id() {
                proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            // Check for CONTINUATION flood. Only non-final frames count; on
            // END_HEADERS the counter is reset, and the partial is consumed below.
            if is_end_headers {
                partial.continuation_frames_count = 0;
            } else {
                let cnt = partial.continuation_frames_count + 1;
                if cnt > max_continuation_frames {
                    tracing::debug!("too_many_continuations, max = {}", max_continuation_frames);
                    return Err(Error::library_go_away_data(
                        Reason::ENHANCE_YOUR_CALM,
                        "too_many_continuations",
                    ));
                } else {
                    partial.continuation_frames_count = cnt;
                }
            }

            // Extend the buf
            if partial.buf.is_empty() {
                // Nothing left over: reuse this frame's payload without copying.
                partial.buf = bytes.split_off(frame::HEADER_LEN);
            } else {
                if partial.frame.is_over_size() {
                    // If there was left over bytes previously, they may be
                    // needed to continue decoding, even though we will
                    // be ignoring this frame. This is done to keep the HPACK
                    // decoder state up-to-date.
                    //
                    // Still, we need to be careful, because if a malicious
                    // attacker were to try to send a gigantic string, such
                    // that it fits over multiple header blocks, we could
                    // grow memory uncontrollably again, and that'd be a shame.
                    //
                    // Instead, we use a simple heuristic to determine if
                    // we should continue to ignore decoding, or to tell
                    // the attacker to go away.
                    //
                    // NOTE(review): `bytes.len()` here still includes the
                    // frame header (`frame::HEADER_LEN` bytes), so this check
                    // is slightly conservative.
                    if partial.buf.len() + bytes.len() > max_header_list_size {
                        proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
                        return Err(Error::library_go_away(Reason::COMPRESSION_ERROR));
                    }
                }
                partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
            }

            match partial
                .frame
                .load_hpack(&mut partial.buf, max_header_list_size, hpack)
            {
                Ok(_) => {}
                // More CONTINUATION frames are expected; keep the leftover bytes.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}
                Err(frame::Error::MalformedMessage) => {
                    let id = head.stream_id();
                    proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(frame::Error::HeaderListWayTooLarge) => {
                    proto_err!(conn: "decoded CONTINUATION header list size over abuse limit");
                    return Err(Error::library_go_away_data(
                        Reason::ENHANCE_YOUR_CALM,
                        "header_list_way_too_large",
                    ));
                }
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                partial.frame.into()
            } else {
                // Put the partial back so the next CONTINUATION can extend it.
                *partial_inout = Some(partial);
                return Ok(None);
            }
        }
        Kind::Unknown => {
            // Unknown frames are ignored
            return Ok(None);
        }
    };

    Ok(Some(frame))
}
385
386impl<T> Stream for FramedRead<T>
387where
388    T: AsyncRead + Unpin,
389{
390    type Item = Result<Frame, Error>;
391
392    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
393        let span = tracing::trace_span!("FramedRead::poll_next");
394        let _e = span.enter();
395        loop {
396            tracing::trace!("poll");
397            let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
398                Some(Ok(bytes)) => bytes,
399                Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
400                None => return Poll::Ready(None),
401            };
402
403            tracing::trace!(read.bytes = bytes.len());
404            let Self {
405                ref mut hpack,
406                max_header_list_size,
407                ref mut partial,
408                max_continuation_frames,
409                ..
410            } = *self;
411            if let Some(frame) = decode_frame(
412                hpack,
413                max_header_list_size,
414                max_continuation_frames,
415                partial,
416                bytes,
417            )? {
418                tracing::debug!(?frame, "received");
419                return Poll::Ready(Some(Ok(frame)));
420            }
421        }
422    }
423}
424
425fn map_err(err: io::Error) -> Error {
426    if let io::ErrorKind::InvalidData = err.kind() {
427        if let Some(custom) = err.get_ref() {
428            if custom.is::<LengthDelimitedCodecError>() {
429                return Error::library_go_away(Reason::FRAME_SIZE_ERROR);
430            }
431        }
432    }
433    err.into()
434}
435
436// ===== impl Continuable =====
437
438impl Continuable {
439    fn stream_id(&self) -> frame::StreamId {
440        match *self {
441            Continuable::Headers(ref h) => h.stream_id(),
442            Continuable::PushPromise(ref p) => p.stream_id(),
443        }
444    }
445
446    fn is_over_size(&self) -> bool {
447        match *self {
448            Continuable::Headers(ref h) => h.is_over_size(),
449            Continuable::PushPromise(ref p) => p.is_over_size(),
450        }
451    }
452
453    fn load_hpack(
454        &mut self,
455        src: &mut BytesMut,
456        max_header_list_size: usize,
457        decoder: &mut hpack::Decoder,
458    ) -> Result<(), frame::Error> {
459        match *self {
460            Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
461            Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
462        }
463    }
464}
465
466impl<T> From<Continuable> for Frame<T> {
467    fn from(cont: Continuable) -> Self {
468        match cont {
469            Continuable::Headers(mut headers) => {
470                headers.set_end_headers();
471                headers.into()
472            }
473            Continuable::PushPromise(mut push) => {
474                push.set_end_headers();
475                push.into()
476            }
477        }
478    }
479}