1use crate::codec::UserError;
2use crate::frame::{Reason, StreamId};
3use crate::{client, server};
4
5use crate::frame::DEFAULT_INITIAL_WINDOW_SIZE;
6use crate::proto::*;
7
8use bytes::Bytes;
9use futures_core::Stream;
10use std::io;
11use std::marker::PhantomData;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15use tokio::io::AsyncRead;
16
17#[derive(Debug)]
19pub(crate) struct Connection<T, P, B: Buf = Bytes>
20where
21 P: Peer,
22{
23 codec: Codec<T, Prioritized<B>>,
25
26 inner: ConnectionInner<P, B>,
27}
28
29#[derive(Debug)]
32struct ConnectionInner<P, B: Buf = Bytes>
33where
34 P: Peer,
35{
36 state: State,
38
39 error: Option<frame::GoAway>,
44
45 go_away: GoAway,
47
48 ping_pong: PingPong,
50
51 settings: Settings,
53
54 streams: Streams<B, P>,
56
57 span: tracing::Span,
59
60 _phantom: PhantomData<P>,
62}
63
64struct DynConnection<'a, B: Buf = Bytes> {
65 state: &'a mut State,
66
67 go_away: &'a mut GoAway,
68
69 streams: DynStreams<'a, B>,
70
71 error: &'a mut Option<frame::GoAway>,
72
73 ping_pong: &'a mut PingPong,
74}
75
76#[derive(Debug, Clone)]
77pub(crate) struct Config {
78 pub next_stream_id: StreamId,
79 pub initial_max_send_streams: usize,
80 pub max_send_buffer_size: usize,
81 pub reset_stream_duration: Duration,
82 pub reset_stream_max: usize,
83 pub remote_reset_stream_max: usize,
84 pub local_error_reset_streams_max: Option<usize>,
85 pub settings: frame::Settings,
86}
87
88#[derive(Debug)]
89enum State {
90 Open,
92
93 Closing(Reason, Initiator),
95
96 Closed(Reason, Initiator),
98}
99
100impl<T, P, B> Connection<T, P, B>
101where
102 T: AsyncRead + AsyncWrite + Unpin,
103 P: Peer,
104 B: Buf,
105{
106 pub fn new(codec: Codec<T, Prioritized<B>>, config: Config) -> Connection<T, P, B> {
107 fn streams_config(config: &Config) -> streams::Config {
108 streams::Config {
109 initial_max_send_streams: config.initial_max_send_streams,
110 local_max_buffer_size: config.max_send_buffer_size,
111 local_next_stream_id: config.next_stream_id,
112 local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
113 extended_connect_protocol_enabled: config
114 .settings
115 .is_extended_connect_protocol_enabled()
116 .unwrap_or(false),
117 local_reset_duration: config.reset_stream_duration,
118 local_reset_max: config.reset_stream_max,
119 remote_reset_max: config.remote_reset_stream_max,
120 remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
121 remote_max_initiated: config
122 .settings
123 .max_concurrent_streams()
124 .map(|max| max as usize),
125 local_max_error_reset_streams: config.local_error_reset_streams_max,
126 }
127 }
128 let streams = Streams::new(streams_config(&config));
129 let span = tracing::debug_span!(parent: None, "Connection", peer = %P::NAME);
130 span.follows_from(tracing::Span::current());
131 Connection {
132 codec,
133 inner: ConnectionInner {
134 state: State::Open,
135 error: None,
136 go_away: GoAway::new(),
137 ping_pong: PingPong::new(),
138 settings: Settings::new(config.settings),
139 streams,
140 span,
141 _phantom: PhantomData,
142 },
143 }
144 }
145
146 pub(crate) fn set_target_window_size(&mut self, size: WindowSize) {
148 let _res = self.inner.streams.set_target_connection_window_size(size);
149 debug_assert!(_res.is_ok());
151 }
152
153 pub(crate) fn set_initial_window_size(&mut self, size: WindowSize) -> Result<(), UserError> {
155 let mut settings = frame::Settings::default();
156 settings.set_initial_window_size(Some(size));
157 self.inner.settings.send_settings(settings)
158 }
159
160 pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> {
162 let mut settings = frame::Settings::default();
163 settings.set_enable_connect_protocol(Some(1));
164 self.inner.settings.send_settings(settings)
165 }
166
167 pub(crate) fn max_send_streams(&self) -> usize {
170 self.inner.streams.max_send_streams()
171 }
172
173 pub(crate) fn max_recv_streams(&self) -> usize {
176 self.inner.streams.max_recv_streams()
177 }
178
179 #[cfg(feature = "unstable")]
180 pub fn num_wired_streams(&self) -> usize {
181 self.inner.streams.num_wired_streams()
182 }
183
184 fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
189 let _e = self.inner.span.enter();
190 let span = tracing::trace_span!("poll_ready");
191 let _e = span.enter();
192 ready!(self.inner.ping_pong.send_pending_pong(cx, &mut self.codec))?;
194 ready!(self.inner.ping_pong.send_pending_ping(cx, &mut self.codec))?;
195 ready!(self
196 .inner
197 .settings
198 .poll_send(cx, &mut self.codec, &mut self.inner.streams))?;
199 ready!(self.inner.streams.send_pending_refusal(cx, &mut self.codec))?;
200
201 Poll::Ready(Ok(()))
202 }
203
204 fn poll_go_away(&mut self, cx: &mut Context) -> Poll<Option<io::Result<Reason>>> {
209 self.inner.go_away.send_pending_go_away(cx, &mut self.codec)
210 }
211
212 pub fn go_away_from_user(&mut self, e: Reason) {
213 self.inner.as_dyn().go_away_from_user(e)
214 }
215
216 fn take_error(&mut self, ours: Reason, initiator: Initiator) -> Result<(), Error> {
217 let (debug_data, theirs) = self
218 .inner
219 .error
220 .take()
221 .as_ref()
222 .map_or((Bytes::new(), Reason::NO_ERROR), |frame| {
223 (frame.debug_data().clone(), frame.reason())
224 });
225
226 match (ours, theirs) {
227 (Reason::NO_ERROR, Reason::NO_ERROR) => Ok(()),
228 (ours, Reason::NO_ERROR) => Err(Error::GoAway(Bytes::new(), ours, initiator)),
229 (_, theirs) => Err(Error::remote_go_away(debug_data, theirs)),
234 }
235 }
236
237 pub fn maybe_close_connection_if_no_streams(&mut self) {
240 if !self.inner.streams.has_streams_or_other_references() {
243 self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
244 }
245 }
246
247 pub fn has_streams(&self) -> bool {
249 self.inner.streams.has_streams()
250 }
251
252 pub fn has_streams_or_other_references(&self) -> bool {
254 self.inner.streams.has_streams_or_other_references()
257 }
258
259 pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
260 self.inner.ping_pong.take_user_pings()
261 }
262
263 pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
265 let span = self.inner.span.clone();
270 let _e = span.enter();
271 let span = tracing::trace_span!("poll");
272 let _e = span.enter();
273
274 loop {
275 tracing::trace!(connection.state = ?self.inner.state);
276 match self.inner.state {
278 State::Open => {
280 let result = match self.poll2(cx) {
281 Poll::Ready(result) => result,
282 Poll::Pending => {
284 ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?;
288
289 if (self.inner.error.is_some()
290 || self.inner.go_away.should_close_on_idle())
291 && !self.inner.streams.has_streams()
292 {
293 self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
294 continue;
295 }
296
297 return Poll::Pending;
298 }
299 };
300
301 self.inner.as_dyn().handle_poll2_result(result)?
302 }
303 State::Closing(reason, initiator) => {
304 tracing::trace!("connection closing after flush");
305 ready!(self.codec.shutdown(cx))?;
307
308 self.inner.state = State::Closed(reason, initiator);
310 }
311 State::Closed(reason, initiator) => {
312 return Poll::Ready(self.take_error(reason, initiator));
313 }
314 }
315 }
316 }
317
318 fn poll2(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
319 self.clear_expired_reset_streams();
323
324 loop {
325 if let Some(reason) = ready!(self.poll_go_away(cx)?) {
331 if self.inner.go_away.should_close_now() {
332 if self.inner.go_away.is_user_initiated() {
333 return Poll::Ready(Ok(()));
336 } else {
337 return Poll::Ready(Err(Error::library_go_away(reason)));
338 }
339 }
340 debug_assert_eq!(
342 reason,
343 Reason::NO_ERROR,
344 "graceful GOAWAY should be NO_ERROR"
345 );
346 }
347 ready!(self.poll_ready(cx))?;
348
349 match self
350 .inner
351 .as_dyn()
352 .recv_frame(ready!(Pin::new(&mut self.codec).poll_next(cx)?))?
353 {
354 ReceivedFrame::Settings(frame) => {
355 self.inner.settings.recv_settings(
356 frame,
357 &mut self.codec,
358 &mut self.inner.streams,
359 )?;
360 }
361 ReceivedFrame::Continue => (),
362 ReceivedFrame::Done => {
363 return Poll::Ready(Ok(()));
364 }
365 }
366 }
367 }
368
369 fn clear_expired_reset_streams(&mut self) {
370 self.inner.streams.clear_expired_reset_streams();
371 }
372}
373
374impl<P, B> ConnectionInner<P, B>
375where
376 P: Peer,
377 B: Buf,
378{
379 fn as_dyn(&mut self) -> DynConnection<'_, B> {
380 let ConnectionInner {
381 state,
382 go_away,
383 streams,
384 error,
385 ping_pong,
386 ..
387 } = self;
388 let streams = streams.as_dyn();
389 DynConnection {
390 state,
391 go_away,
392 streams,
393 error,
394 ping_pong,
395 }
396 }
397}
398
399impl<B> DynConnection<'_, B>
400where
401 B: Buf,
402{
403 fn go_away(&mut self, id: StreamId, e: Reason) {
404 let frame = frame::GoAway::new(id, e);
405 self.streams.send_go_away(id);
406 self.go_away.go_away(frame);
407 }
408
409 fn go_away_now(&mut self, e: Reason) {
410 let last_processed_id = self.streams.last_processed_id();
411 let frame = frame::GoAway::new(last_processed_id, e);
412 self.go_away.go_away_now(frame);
413 }
414
415 fn go_away_now_data(&mut self, e: Reason, data: Bytes) {
416 let last_processed_id = self.streams.last_processed_id();
417 let frame = frame::GoAway::with_debug_data(last_processed_id, e, data);
418 self.go_away.go_away_now(frame);
419 }
420
421 fn go_away_from_user(&mut self, e: Reason) {
422 let last_processed_id = self.streams.last_processed_id();
423 let frame = frame::GoAway::new(last_processed_id, e);
424 self.go_away.go_away_from_user(frame);
425
426 self.streams.handle_error(Error::user_go_away(e));
428 }
429
430 fn handle_poll2_result(&mut self, result: Result<(), Error>) -> Result<(), Error> {
431 match result {
432 Ok(()) => {
434 *self.state = State::Closing(Reason::NO_ERROR, Initiator::Library);
435 Ok(())
436 }
437 Err(Error::GoAway(debug_data, reason, initiator)) => {
441 self.handle_go_away(reason, debug_data, initiator);
442 Ok(())
443 }
444 Err(Error::Reset(id, reason, initiator)) => {
449 if initiator == Initiator::Remote {
450 tracing::trace!(?id, ?reason, ?initiator, "stream reset");
451 return Ok(());
452 }
453
454 debug_assert_eq!(initiator, Initiator::Library);
455 tracing::trace!(?id, ?reason, ?initiator, "stream error");
456 match self.streams.send_reset(id, reason) {
457 Ok(()) => (),
458 Err(crate::proto::error::GoAway { debug_data, reason }) => {
459 self.handle_go_away(reason, debug_data, Initiator::Library);
460 }
461 }
462 Ok(())
463 }
464 Err(Error::Io(kind, inner)) => {
469 tracing::debug!(error = ?kind, "Connection::poll; IO error");
470 let e = Error::Io(kind, inner);
471
472 self.streams.handle_error(e.clone());
474
475 if self.streams.is_buffer_empty()
482 && matches!(kind, io::ErrorKind::UnexpectedEof)
483 && (self.streams.is_server()
484 || self.error.as_ref().map(|f| f.reason() == Reason::NO_ERROR)
485 == Some(true))
486 {
487 *self.state = State::Closed(Reason::NO_ERROR, Initiator::Library);
488 return Ok(());
489 }
490
491 Err(e)
493 }
494 }
495 }
496
497 fn handle_go_away(&mut self, reason: Reason, debug_data: Bytes, initiator: Initiator) {
498 let e = Error::GoAway(debug_data.clone(), reason, initiator);
499 tracing::debug!(error = ?e, "Connection::poll; connection error");
500
501 if self
504 .go_away
505 .going_away()
506 .map_or(false, |frame| frame.reason() == reason)
507 {
508 tracing::trace!(" -> already going away");
509 *self.state = State::Closing(reason, initiator);
510 return;
511 }
512
513 self.streams.handle_error(e);
515 self.go_away_now_data(reason, debug_data);
516 }
517
518 fn recv_frame(&mut self, frame: Option<Frame>) -> Result<ReceivedFrame, Error> {
519 use crate::frame::Frame::*;
520 match frame {
521 Some(Headers(frame)) => {
522 tracing::trace!(?frame, "recv HEADERS");
523 self.streams.recv_headers(frame)?;
524 }
525 Some(Data(frame)) => {
526 tracing::trace!(?frame, "recv DATA");
527 self.streams.recv_data(frame)?;
528 }
529 Some(Reset(frame)) => {
530 tracing::trace!(?frame, "recv RST_STREAM");
531 self.streams.recv_reset(frame)?;
532 }
533 Some(PushPromise(frame)) => {
534 tracing::trace!(?frame, "recv PUSH_PROMISE");
535 self.streams.recv_push_promise(frame)?;
536 }
537 Some(Settings(frame)) => {
538 tracing::trace!(?frame, "recv SETTINGS");
539 return Ok(ReceivedFrame::Settings(frame));
540 }
541 Some(GoAway(frame)) => {
542 tracing::trace!(?frame, "recv GOAWAY");
543 self.streams.recv_go_away(&frame)?;
548 *self.error = Some(frame);
549 }
550 Some(Ping(frame)) => {
551 tracing::trace!(?frame, "recv PING");
552 let status = self.ping_pong.recv_ping(frame);
553 if status.is_shutdown() {
554 assert!(
555 self.go_away.is_going_away(),
556 "received unexpected shutdown ping"
557 );
558
559 let last_processed_id = self.streams.last_processed_id();
560 self.go_away(last_processed_id, Reason::NO_ERROR);
561 }
562 }
563 Some(WindowUpdate(frame)) => {
564 tracing::trace!(?frame, "recv WINDOW_UPDATE");
565 self.streams.recv_window_update(frame)?;
566 }
567 Some(Priority(frame)) => {
568 tracing::trace!(?frame, "recv PRIORITY");
569 }
571 None => {
572 tracing::trace!("codec closed");
573 self.streams.recv_eof(false).expect("mutex poisoned");
574 return Ok(ReceivedFrame::Done);
575 }
576 }
577 Ok(ReceivedFrame::Continue)
578 }
579}
580
581enum ReceivedFrame {
582 Settings(frame::Settings),
583 Continue,
584 Done,
585}
586
587impl<T, B> Connection<T, client::Peer, B>
588where
589 T: AsyncRead + AsyncWrite,
590 B: Buf,
591{
592 pub(crate) fn streams(&self) -> &Streams<B, client::Peer> {
593 &self.inner.streams
594 }
595}
596
597impl<T, B> Connection<T, server::Peer, B>
598where
599 T: AsyncRead + AsyncWrite + Unpin,
600 B: Buf,
601{
602 pub fn next_incoming(&mut self) -> Option<StreamRef<B>> {
603 self.inner.streams.next_incoming()
604 }
605
606 pub fn go_away_gracefully(&mut self) {
608 if self.inner.go_away.is_going_away() {
609 return;
611 }
612
613 self.inner.as_dyn().go_away(StreamId::MAX, Reason::NO_ERROR);
625
626 self.inner.ping_pong.ping_shutdown();
629 }
630}
631
632impl<T, P, B> Drop for Connection<T, P, B>
633where
634 P: Peer,
635 B: Buf,
636{
637 fn drop(&mut self) {
638 let _ = self.inner.streams.recv_eof(true);
640 }
641}