1use crate::frame::{self, Frame, Kind, Reason};
2use crate::frame::{
3 DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
4};
5use crate::proto::Error;
6
7use crate::hpack;
8
9use futures_core::Stream;
10
11use bytes::{Buf, BytesMut};
12
13use std::io;
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::io::AsyncRead;
18use tokio_util::codec::FramedRead as InnerFramedRead;
19use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
20
/// Default limit, in bytes, for the size of a decoded header list (16 MiB).
///
/// Used by `FramedRead::new` until `set_max_header_list_size` overrides it.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
23
/// Reads length-delimited byte chunks from `T` and decodes them into `Frame`s.
///
/// HEADERS / PUSH_PROMISE blocks that span several frames are reassembled
/// internally; see the `partial` field.
#[derive(Debug)]
pub struct FramedRead<T> {
    /// Length-delimited reader that yields the raw bytes of one frame at a time.
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    /// HPACK decoder used for every header block read from this connection.
    hpack: hpack::Decoder,

    /// Limit passed to `load_hpack` when decoding header blocks.
    max_header_list_size: usize,

    /// Maximum number of non-final CONTINUATION frames accepted for a single
    /// header block; derived from `max_header_list_size` and the max frame size.
    max_continuation_frames: usize,

    /// Header block still waiting for CONTINUATION frames, if any.
    partial: Option<Partial>,
}
37
/// State of a header block whose END_HEADERS flag has not been seen yet.
#[derive(Debug)]
struct Partial {
    /// The HEADERS or PUSH_PROMISE frame that opened the header block.
    frame: Continuable,

    /// Header block bytes carried over between frames (what `load_hpack`
    /// left in the buffer, plus payloads appended from CONTINUATION frames).
    buf: BytesMut,

    /// Number of non-final CONTINUATION frames received so far for this block.
    continuation_frames_count: usize,
}
49
/// A frame whose header block may be continued by CONTINUATION frames.
#[derive(Debug)]
enum Continuable {
    /// A HEADERS frame.
    Headers(frame::Headers),
    /// A PUSH_PROMISE frame.
    PushPromise(frame::PushPromise),
}
55
56impl<T> FramedRead<T> {
57 pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
58 let max_header_list_size = DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE;
59 let max_continuation_frames =
60 calc_max_continuation_frames(max_header_list_size, inner.decoder().max_frame_length());
61 FramedRead {
62 inner,
63 hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
64 max_header_list_size,
65 max_continuation_frames,
66 partial: None,
67 }
68 }
69
70 pub fn get_ref(&self) -> &T {
71 self.inner.get_ref()
72 }
73
74 pub fn get_mut(&mut self) -> &mut T {
75 self.inner.get_mut()
76 }
77
78 #[inline]
80 pub fn max_frame_size(&self) -> usize {
81 self.inner.decoder().max_frame_length()
82 }
83
84 #[inline]
88 pub fn set_max_frame_size(&mut self, val: usize) {
89 assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
90 self.inner.decoder_mut().set_max_frame_length(val);
91 self.max_continuation_frames = calc_max_continuation_frames(self.max_header_list_size, val);
93 }
94
95 #[inline]
97 pub fn set_max_header_list_size(&mut self, val: usize) {
98 self.max_header_list_size = val;
99 self.max_continuation_frames = calc_max_continuation_frames(val, self.max_frame_size());
101 }
102
103 #[inline]
105 pub fn set_header_table_size(&mut self, val: usize) {
106 self.hpack.queue_size_update(val);
107 }
108}
109
/// Computes how many non-final CONTINUATION frames are tolerated for one
/// header block.
///
/// `header_max` is the max header list size and `frame_max` the max frame
/// size. The result is the number of full-size frames needed to carry
/// `header_max`, plus 25% slack for imperfectly packed frames, and never less
/// than 5. The addition saturates instead of overflowing.
///
/// A `frame_max` of zero is treated as one, so a degenerate codec
/// configuration cannot trigger a division-by-zero panic.
fn calc_max_continuation_frames(header_max: usize, frame_max: usize) -> usize {
    // Smallest limit ever returned, even for tiny header lists.
    const MIN_LIMIT: usize = 5;

    // At least this many frames are needed to use the whole header list size.
    let min_frames_for_list = (header_max / frame_max.max(1)).max(1);
    // 25% padding, computed without floats.
    let padding = min_frames_for_list / 4;

    min_frames_for_list.saturating_add(padding).max(MIN_LIMIT)
}
118
/// Decodes a single frame from `bytes`.
///
/// `bytes` holds one complete frame as produced by the length-delimited
/// codec: a frame head of `frame::HEADER_LEN` bytes followed by the payload.
///
/// * `hpack` - decoder used for HEADERS / PUSH_PROMISE / CONTINUATION blocks.
/// * `max_header_list_size` - limit forwarded to `load_hpack`; also used to
///   cap buffered CONTINUATION data for an over-size header block.
/// * `max_continuation_frames` - maximum number of non-final CONTINUATION
///   frames accepted for one header block.
/// * `partial_inout` - header block in progress. It is set when a header
///   block lacks END_HEADERS and is consumed by CONTINUATION frames.
///
/// Returns `Ok(Some(frame))` for a fully decoded frame, and `Ok(None)` when
/// the frame only contributed to a still-incomplete header block or its kind
/// is unknown. Returns `Err` with a connection-level (`library_go_away*`) or
/// stream-level (`library_reset`) error when the frame is rejected.
fn decode_frame(
    hpack: &mut hpack::Decoder,
    max_header_list_size: usize,
    max_continuation_frames: usize,
    partial_inout: &mut Option<Partial>,
    mut bytes: BytesMut,
) -> Result<Option<Frame>, Error> {
    let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
    let _e = span.enter();

    tracing::trace!("decoding frame from {}B", bytes.len());

    // Parse the frame head.
    let head = frame::Head::parse(&bytes);

    // While a header block is in progress, only CONTINUATION frames are
    // accepted; anything else is a connection error.
    if partial_inout.is_some() && head.kind() != Kind::Continuation {
        proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
        return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
    }

    let kind = head.kind();

    tracing::trace!(frame.kind = ?kind);

    // Shared decoding logic for HEADERS and PUSH_PROMISE. Evaluates to the
    // finished `Frame` when END_HEADERS is set; otherwise stores the partial
    // state in `partial_inout` and returns `Ok(None)` from `decode_frame`.
    macro_rules! header_block {
        ($frame:ident, $head:ident, $bytes:ident) => ({
            // Drop the frame head so that `$bytes` starts at the payload.
            $bytes.advance(frame::HEADER_LEN);

            // Parse the frame-specific part; `payload` is what remains for
            // HPACK decoding.
            let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
                Ok(res) => res,
                Err(frame::Error::InvalidDependencyId) => {
                    proto_err!(stream: "invalid HEADERS dependency ID");
                    // An invalid dependency ID only resets the stream.
                    return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR));
                },
                Err(e) => {
                    proto_err!(conn: "failed to load frame; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            let is_end_headers = frame.is_end_headers();

            // Load the HPACK-encoded headers.
            match frame.load_hpack(&mut payload, max_header_list_size, hpack) {
                Ok(_) => {},
                // Running out of input is fine while more frames are expected.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
                Err(frame::Error::MalformedMessage) => {
                    let id = $head.stream_id();
                    proto_err!(stream: "malformed header block; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                },
                Err(frame::Error::HeaderListWayTooLarge) => {
                    proto_err!(conn: "decoded header list size over abuse limit");
                    return Err(Error::library_go_away_data(
                        Reason::ENHANCE_YOUR_CALM,
                        "header_list_way_too_large",
                    ));
                },
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                frame.into()
            } else {
                tracing::trace!("loaded partial header block");
                // Defer returning the frame until END_HEADERS arrives; keep
                // whatever `load_hpack` left in `payload` for the next frame.
                *partial_inout = Some(Partial {
                    frame: Continuable::$frame(frame),
                    buf: payload,
                    continuation_frames_count: 0,
                });

                return Ok(None);
            }
        });
    }

    let frame = match kind {
        Kind::Settings => {
            let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Ping => {
            let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load PING frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::WindowUpdate => {
            let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Data => {
            // DATA takes ownership of the payload, so strip the head and
            // freeze the remainder instead of borrowing a slice.
            bytes.advance(frame::HEADER_LEN);
            let res = frame::Data::load(head, bytes.freeze());

            res.map_err(|e| {
                proto_err!(conn: "failed to load DATA frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Headers => header_block!(Headers, head, bytes),
        Kind::Reset => {
            let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load RESET frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::GoAway => {
            let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::PushPromise => header_block!(PushPromise, head, bytes),
        Kind::Priority => {
            // A PRIORITY frame on stream 0 is a connection error.
            if head.stream_id() == 0 {
                proto_err!(conn: "invalid stream ID 0");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
                Ok(frame) => frame.into(),
                Err(frame::Error::InvalidDependencyId) => {
                    // An invalid dependency ID only resets the stream.
                    let id = head.stream_id();
                    proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(e) => {
                    proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }
        }
        Kind::Continuation => {
            // 0x4 is the END_HEADERS flag bit.
            let is_end_headers = (head.flag() & 0x4) == 0x4;

            // Take the partial state out; it is only put back below when more
            // CONTINUATION frames are expected, so every error return leaves
            // `partial_inout` empty.
            let mut partial = match partial_inout.take() {
                Some(partial) => partial,
                None => {
                    proto_err!(conn: "received unexpected CONTINUATION frame");
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            // The stream identifiers must match.
            if partial.frame.stream_id() != head.stream_id() {
                proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            // Guard against a flood of CONTINUATION frames.
            if is_end_headers {
                partial.continuation_frames_count = 0;
            } else {
                let cnt = partial.continuation_frames_count + 1;
                if cnt > max_continuation_frames {
                    tracing::debug!("too_many_continuations, max = {}", max_continuation_frames);
                    return Err(Error::library_go_away_data(
                        Reason::ENHANCE_YOUR_CALM,
                        "too_many_continuations",
                    ));
                } else {
                    partial.continuation_frames_count = cnt;
                }
            }

            // Extend the buffered header block with this frame's payload.
            if partial.buf.is_empty() {
                // Nothing left over: reuse this frame's payload buffer as is.
                partial.buf = bytes.split_off(frame::HEADER_LEN);
            } else {
                if partial.frame.is_over_size() {
                    // The block is already over size, so cap how much data is
                    // buffered for it. Note that `bytes.len()` here still
                    // includes the frame head.
                    if partial.buf.len() + bytes.len() > max_header_list_size {
                        proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
                        return Err(Error::library_go_away(Reason::COMPRESSION_ERROR));
                    }
                }
                partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
            }

            match partial
                .frame
                .load_hpack(&mut partial.buf, max_header_list_size, hpack)
            {
                Ok(_) => {}
                // Running out of input is fine while more frames are expected.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}
                Err(frame::Error::MalformedMessage) => {
                    let id = head.stream_id();
                    proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(frame::Error::HeaderListWayTooLarge) => {
                    proto_err!(conn: "decoded CONTINUATION header list size over abuse limit");
                    return Err(Error::library_go_away_data(
                        Reason::ENHANCE_YOUR_CALM,
                        "header_list_way_too_large",
                    ));
                }
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                partial.frame.into()
            } else {
                // Still incomplete: put the state back and wait for more.
                *partial_inout = Some(partial);
                return Ok(None);
            }
        }
        Kind::Unknown => {
            // Frames of unknown kind are skipped without an error.
            return Ok(None);
        }
    };

    Ok(Some(frame))
}
385
386impl<T> Stream for FramedRead<T>
387where
388 T: AsyncRead + Unpin,
389{
390 type Item = Result<Frame, Error>;
391
392 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
393 let span = tracing::trace_span!("FramedRead::poll_next");
394 let _e = span.enter();
395 loop {
396 tracing::trace!("poll");
397 let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
398 Some(Ok(bytes)) => bytes,
399 Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
400 None => return Poll::Ready(None),
401 };
402
403 tracing::trace!(read.bytes = bytes.len());
404 let Self {
405 ref mut hpack,
406 max_header_list_size,
407 ref mut partial,
408 max_continuation_frames,
409 ..
410 } = *self;
411 if let Some(frame) = decode_frame(
412 hpack,
413 max_header_list_size,
414 max_continuation_frames,
415 partial,
416 bytes,
417 )? {
418 tracing::debug!(?frame, "received");
419 return Poll::Ready(Some(Ok(frame)));
420 }
421 }
422 }
423}
424
425fn map_err(err: io::Error) -> Error {
426 if let io::ErrorKind::InvalidData = err.kind() {
427 if let Some(custom) = err.get_ref() {
428 if custom.is::<LengthDelimitedCodecError>() {
429 return Error::library_go_away(Reason::FRAME_SIZE_ERROR);
430 }
431 }
432 }
433 err.into()
434}
435
436impl Continuable {
439 fn stream_id(&self) -> frame::StreamId {
440 match *self {
441 Continuable::Headers(ref h) => h.stream_id(),
442 Continuable::PushPromise(ref p) => p.stream_id(),
443 }
444 }
445
446 fn is_over_size(&self) -> bool {
447 match *self {
448 Continuable::Headers(ref h) => h.is_over_size(),
449 Continuable::PushPromise(ref p) => p.is_over_size(),
450 }
451 }
452
453 fn load_hpack(
454 &mut self,
455 src: &mut BytesMut,
456 max_header_list_size: usize,
457 decoder: &mut hpack::Decoder,
458 ) -> Result<(), frame::Error> {
459 match *self {
460 Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
461 Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
462 }
463 }
464}
465
466impl<T> From<Continuable> for Frame<T> {
467 fn from(cont: Continuable) -> Self {
468 match cont {
469 Continuable::Headers(mut headers) => {
470 headers.set_end_headers();
471 headers.into()
472 }
473 Continuable::PushPromise(mut push) => {
474 push.set_end_headers();
475 push.into()
476 }
477 }
478 }
479}