1use super::{
2 store, Buffer, Codec, Config, Counts, Frame, Prioritize, Prioritized, Store, Stream, StreamId,
3 StreamIdOverflow, WindowSize,
4};
5use crate::codec::UserError;
6use crate::frame::{self, Reason};
7use crate::proto::{self, Error, Initiator};
8
9use bytes::Buf;
10use tokio::io::AsyncWrite;
11
12use std::cmp::Ordering;
13use std::io;
14use std::task::{Context, Poll, Waker};
15
/// Send-side (outbound) state shared by all streams on a connection.
///
/// Tracks locally allocated stream ids and the send-related settings the peer
/// has advertised; frame queueing and flow-control bookkeeping are delegated
/// to [`Prioritize`].
#[derive(Debug)]
pub(super) struct Send {
    /// Id to hand out for the next locally initiated stream, or `Err` once the
    /// stream-id space has been exhausted.
    next_stream_id: Result<StreamId, StreamIdOverflow>,

    /// Upper bound on the `last_stream_id` accepted from a peer GOAWAY.
    ///
    /// Starts at `StreamId::MAX` and is set to `last_stream_id` by
    /// `recv_go_away`; a later GOAWAY that would raise it is rejected.
    max_stream_id: StreamId,

    /// Initial send window for streams. Seeded from
    /// `Config::remote_init_window_sz` and replaced whenever the peer's
    /// SETTINGS carry a new initial window size.
    init_window_sz: WindowSize,

    /// Prioritization layer that owns the send queues and capacity accounting.
    prioritize: Prioritize,

    /// Whether the peer permits PUSH_PROMISE frames. Starts `true`; updated
    /// from the peer's SETTINGS.
    is_push_enabled: bool,

    /// Whether the peer has enabled the extended CONNECT protocol. Starts
    /// `false`; updated from the peer's SETTINGS.
    is_extended_connect_protocol_enabled: bool,
}
42
/// Mode argument for `Send::poll_reset`; it is forwarded unchanged to
/// `stream.state.ensure_reason(mode)`.
///
/// NOTE(review): the variant names suggest the calling API is either still
/// awaiting response headers or already streaming a body — confirm against
/// the public call sites and `State::ensure_reason`.
#[derive(Debug)]
pub(crate) enum PollReset {
    AwaitingHeaders,
    Streaming,
}
49
impl Send {
    /// Creates the send-side state from the connection [`Config`].
    ///
    /// The initial send window comes from `config.remote_init_window_sz` and
    /// the first locally allocated stream id from `config.local_next_stream_id`.
    /// `max_stream_id` starts at `StreamId::MAX`. Server push is treated as
    /// enabled and extended CONNECT as disabled until `apply_remote_settings`
    /// says otherwise.
    pub fn new(config: &Config) -> Self {
        Send {
            init_window_sz: config.remote_init_window_sz,
            max_stream_id: StreamId::MAX,
            next_stream_id: Ok(config.local_next_stream_id),
            prioritize: Prioritize::new(config),
            is_push_enabled: true,
            is_extended_connect_protocol_enabled: false,
        }
    }

    /// Returns the initial send window size currently in effect (updated by
    /// `apply_remote_settings`).
    pub fn init_window_sz(&self) -> WindowSize {
        self.init_window_sz
    }

    /// Allocates the next locally initiated stream id and advances the counter.
    ///
    /// # Errors
    ///
    /// `UserError::OverflowedStreamId` once the stream-id space is exhausted.
    pub fn open(&mut self) -> Result<StreamId, UserError> {
        let stream_id = self.ensure_next_stream_id()?;
        // `next_id()` yields `Err` on overflow, which makes every later
        // allocation fail in `ensure_next_stream_id`.
        self.next_stream_id = stream_id.next_id();
        Ok(stream_id)
    }

    /// Reserves the next locally initiated stream id; same allocation logic as
    /// `open`.
    ///
    /// NOTE(review): presumably used for locally reserved (pushed) streams —
    /// confirm at the call sites.
    ///
    /// # Errors
    ///
    /// `UserError::OverflowedStreamId` once the stream-id space is exhausted.
    pub fn reserve_local(&mut self) -> Result<StreamId, UserError> {
        let stream_id = self.ensure_next_stream_id()?;
        self.next_stream_id = stream_id.next_id();
        Ok(stream_id)
    }

    /// Rejects connection-specific header fields, which HTTP/2 forbids
    /// (RFC 9113 §8.2.2; RFC 7540 §8.1.2.2).
    ///
    /// `connection`, `transfer-encoding`, `upgrade`, `keep-alive` and
    /// `proxy-connection` are always rejected. `te` is allowed only when its
    /// value compares equal to `"trailers"`.
    ///
    /// # Errors
    ///
    /// `UserError::MalformedHeaders` if any forbidden field is present.
    fn check_headers(fields: &http::HeaderMap) -> Result<(), UserError> {
        if fields.contains_key(http::header::CONNECTION)
            || fields.contains_key(http::header::TRANSFER_ENCODING)
            || fields.contains_key(http::header::UPGRADE)
            || fields.contains_key("keep-alive")
            || fields.contains_key("proxy-connection")
        {
            tracing::debug!("illegal connection-specific headers found");
            return Err(UserError::MalformedHeaders);
        } else if let Some(te) = fields.get(http::header::TE) {
            if te != "trailers" {
                tracing::debug!("illegal connection-specific headers found");
                return Err(UserError::MalformedHeaders);
            }
        }
        Ok(())
    }

    /// Validates and queues a PUSH_PROMISE frame on `stream`.
    ///
    /// # Errors
    ///
    /// - `UserError::PeerDisabledServerPush` if the peer's SETTINGS disabled push.
    /// - `UserError::MalformedHeaders` if the promised request carries
    ///   connection-specific headers (see `check_headers`).
    pub fn send_push_promise<B>(
        &mut self,
        frame: frame::PushPromise,
        buffer: &mut Buffer<Frame<B>>,
        stream: &mut store::Ptr,
        task: &mut Option<Waker>,
    ) -> Result<(), UserError> {
        if !self.is_push_enabled {
            return Err(UserError::PeerDisabledServerPush);
        }

        tracing::trace!(
            "send_push_promise; frame={:?}; init_window={:?}",
            frame,
            self.init_window_sz
        );

        Self::check_headers(frame.fields())?;

        self.prioritize
            .queue_frame(frame.into(), buffer, stream, task);

        Ok(())
    }

    /// Validates a HEADERS frame, transitions `stream` to the sending state,
    /// and queues the frame.
    ///
    /// Locally initiated streams that are not pending a push are also placed on
    /// the prioritizer's open queue, and the connection task is woken.
    ///
    /// # Errors
    ///
    /// - `UserError::MalformedHeaders` from `check_headers`.
    /// - Any error returned by `stream.state.send_open`.
    pub fn send_headers<B>(
        &mut self,
        frame: frame::Headers,
        buffer: &mut Buffer<Frame<B>>,
        stream: &mut store::Ptr,
        counts: &mut Counts,
        task: &mut Option<Waker>,
    ) -> Result<(), UserError> {
        tracing::trace!(
            "send_headers; frame={:?}; init_window={:?}",
            frame,
            self.init_window_sz
        );

        Self::check_headers(frame.fields())?;

        let end_stream = frame.is_end_stream();

        // Update the stream state before anything is queued.
        stream.state.send_open(end_stream)?;

        let mut pending_open = false;
        if counts.peer().is_local_init(frame.stream_id()) && !stream.is_pending_push {
            self.prioritize.queue_open(stream);
            pending_open = true;
        }

        self.prioritize
            .queue_frame(frame.into(), buffer, stream, task);

        // Wake the connection task explicitly after pushing onto the open queue.
        // NOTE(review): presumably `queue_frame` only notifies for streams placed
        // on pending_send, not pending_open — confirm in `Prioritize`.
        if pending_open {
            if let Some(task) = task.take() {
                task.wake();
            }
        }

        Ok(())
    }

    /// Validates and queues an interim informational (1xx) HEADERS frame.
    ///
    /// Unlike `send_headers`, this does not touch `stream.state`, so several
    /// informational responses can precede the final one. `_counts` is unused.
    /// Debug builds assert that the frame is informational and has no
    /// END_STREAM flag; callers are expected to have validated both.
    ///
    /// # Errors
    ///
    /// `UserError::MalformedHeaders` from `check_headers`.
    pub fn send_interim_informational_headers<B>(
        &mut self,
        frame: frame::Headers,
        buffer: &mut Buffer<Frame<B>>,
        stream: &mut store::Ptr,
        _counts: &mut Counts,
        task: &mut Option<Waker>,
    ) -> Result<(), UserError> {
        tracing::trace!(
            "send_interim_informational_headers; frame={:?}; stream_id={:?}",
            frame,
            frame.stream_id()
        );

        Self::check_headers(frame.fields())?;

        debug_assert!(frame.is_informational(),
            "Frame must be informational (1xx status code) at this point. Validation should happen at the public API boundary.");
        debug_assert!(!frame.is_end_stream(),
            "Informational frames must not have end_stream flag set. Validation should happen at the internal send informational header streams.");

        self.prioritize
            .queue_frame(frame.into(), buffer, stream, task);

        Ok(())
    }

    /// Resets `stream` with `reason`, queueing an explicit RST_STREAM frame
    /// when one is still needed.
    ///
    /// - If the stream is already reset, this is a no-op.
    /// - Otherwise the stream is always marked reset via `set_reset`.
    /// - If the stream was already closed and its `pending_send` queue was
    ///   empty, no RST_STREAM frame is queued.
    /// - Otherwise the stream's queued frames are cleared (unless it is still
    ///   pending open), the RST_STREAM frame is queued, and all of the
    ///   stream's capacity is reclaimed.
    pub fn send_reset<B>(
        &mut self,
        reason: Reason,
        initiator: Initiator,
        buffer: &mut Buffer<Frame<B>>,
        stream: &mut store::Ptr,
        counts: &mut Counts,
        task: &mut Option<Waker>,
    ) {
        // Snapshot the state on entry, before `set_reset` below updates the stream.
        let is_reset = stream.state.is_reset();
        let is_closed = stream.state.is_closed();
        let is_empty = stream.pending_send.is_empty();
        let stream_id = stream.id;

        tracing::trace!(
            "send_reset(..., reason={:?}, initiator={:?}, stream={:?}, ..., \
             is_reset={:?}; is_closed={:?}; pending_send.is_empty={:?}; \
             state={:?} \
             ",
            reason,
            initiator,
            stream_id,
            is_reset,
            is_closed,
            is_empty,
            stream.state
        );

        if is_reset {
            // Never reset a stream twice.
            tracing::trace!(
                " -> not sending RST_STREAM ({:?} is already reset)",
                stream_id
            );
            return;
        }

        // Transition the state to reset regardless of whether a frame is sent.
        stream.set_reset(reason, initiator);

        // A closed stream with nothing left to flush gets no explicit RST_STREAM.
        if is_closed && is_empty {
            tracing::trace!(
                " -> not sending explicit RST_STREAM ({:?} was closed \
                 and send queue was flushed)",
                stream_id
            );
            return;
        }

        // Drop the stream's pending outbound frames before queueing the reset.
        // NOTE(review): the reason for skipping this for streams still pending
        // open is not visible here — confirm in `Prioritize` before changing it.
        if !stream.is_pending_open {
            self.prioritize.clear_queue(buffer, stream);
        }

        let frame = frame::Reset::new(stream.id, reason);

        tracing::trace!("send_reset -- queueing; frame={:?}", frame);
        self.prioritize
            .queue_frame(frame.into(), buffer, stream, task);
        // NOTE(review): the frame is queued before capacity is reclaimed; keep
        // this order unless `reclaim_all_capacity` is verified not to transition
        // the stream in a way that affects queueing.
        self.prioritize.reclaim_all_capacity(stream, counts);
    }

    /// Schedules a reset of `stream` with `reason`.
    ///
    /// No frame is built here. The reason is recorded with
    /// `set_scheduled_reset`, the stream's reserved capacity is reclaimed, and
    /// the stream is handed to the prioritizer via `schedule_send`. Does
    /// nothing if the stream is already closed.
    pub fn schedule_implicit_reset(
        &mut self,
        stream: &mut store::Ptr,
        reason: Reason,
        counts: &mut Counts,
        task: &mut Option<Waker>,
    ) {
        if stream.state.is_closed() {
            // Already closed; nothing more to do.
            return;
        }

        stream.state.set_scheduled_reset(reason);

        self.prioritize.reclaim_reserved_capacity(stream, counts);
        self.prioritize.schedule_send(stream, task);
    }

    /// Queues a DATA frame on `stream`; fully delegated to
    /// `Prioritize::send_data`, whose errors are returned unchanged.
    pub fn send_data<B>(
        &mut self,
        frame: frame::Data<B>,
        buffer: &mut Buffer<Frame<B>>,
        stream: &mut store::Ptr,
        counts: &mut Counts,
        task: &mut Option<Waker>,
    ) -> Result<(), UserError>
    where
        B: Buf,
    {
        self.prioritize
            .send_data(frame, buffer, stream, counts, task)
    }

    /// Queues a trailers HEADERS frame and closes the send side of `stream`.
    ///
    /// # Errors
    ///
    /// `UserError::UnexpectedFrameType` if the stream is not currently in the
    /// send-streaming state.
    pub fn send_trailers<B>(
        &mut self,
        frame: frame::Headers,
        buffer: &mut Buffer<Frame<B>>,
        stream: &mut store::Ptr,
        counts: &mut Counts,
        task: &mut Option<Waker>,
    ) -> Result<(), UserError> {
        if !stream.state.is_send_streaming() {
            return Err(UserError::UnexpectedFrameType);
        }

        stream.state.send_close();

        tracing::trace!("send_trailers -- queuing; frame={:?}", frame);
        self.prioritize
            .queue_frame(frame.into(), buffer, stream, task);

        // Requesting zero capacity after the final frame presumably releases any
        // capacity still reserved for the stream — confirm in
        // `Prioritize::reserve_capacity`.
        self.prioritize.reserve_capacity(0, stream, counts);

        Ok(())
    }

    /// Drives queued frames into the codec `dst`; fully delegated to
    /// `Prioritize::poll_complete`.
    pub fn poll_complete<T, B>(
        &mut self,
        cx: &mut Context,
        buffer: &mut Buffer<Frame<B>>,
        store: &mut Store,
        counts: &mut Counts,
        dst: &mut Codec<T, Prioritized<B>>,
    ) -> Poll<io::Result<()>>
    where
        T: AsyncWrite + Unpin,
        B: Buf,
    {
        self.prioritize
            .poll_complete(cx, buffer, store, counts, dst)
    }

    /// Requests `capacity` bytes of send capacity for `stream`; delegated to
    /// `Prioritize::reserve_capacity`.
    pub fn reserve_capacity(
        &mut self,
        capacity: WindowSize,
        stream: &mut store::Ptr,
        counts: &mut Counts,
    ) {
        self.prioritize.reserve_capacity(capacity, stream, counts)
    }

    /// Polls for newly available send capacity on `stream`.
    ///
    /// Returns:
    /// - `Ready(None)` once the stream is no longer send-streaming.
    /// - `Pending` (registering the waker via `wait_send`) when no capacity
    ///   increase has been flagged since the last successful poll.
    /// - `Pending` when the capacity is zero. The `send_capacity_inc` flag has
    ///   already been consumed in that case.
    /// - `Ready(Some(Ok(capacity)))` otherwise.
    ///
    /// This body never produces `Ready(Some(Err(_)))`.
    pub fn poll_capacity(
        &mut self,
        cx: &Context,
        stream: &mut store::Ptr,
    ) -> Poll<Option<Result<WindowSize, UserError>>> {
        if !stream.state.is_send_streaming() {
            return Poll::Ready(None);
        }

        if !stream.send_capacity_inc {
            stream.wait_send(cx);
            return Poll::Pending;
        }

        stream.send_capacity_inc = false;

        let capacity = self.capacity(stream);

        if capacity == 0 {
            // Don't report an empty capacity; wait for the next increase instead.
            stream.wait_send(cx);
            return Poll::Pending;
        }

        Poll::Ready(Some(Ok(capacity)))
    }

    /// Current send capacity of `stream`, as computed by `Stream::capacity`
    /// with the prioritizer's maximum buffer size.
    pub fn capacity(&self, stream: &mut store::Ptr) -> WindowSize {
        stream.capacity(self.prioritize.max_buffer_size())
    }

    /// Polls for the reason `stream` was reset.
    ///
    /// Returns `Ready(Ok(reason))` once `ensure_reason(mode)` yields a reason.
    /// Otherwise registers the waker via `wait_send` and returns `Pending`.
    /// Errors from `ensure_reason` are converted into `crate::Error` and
    /// returned as `Ready(Err(_))` by the `?` operator.
    pub fn poll_reset(
        &self,
        cx: &Context,
        stream: &mut Stream,
        mode: PollReset,
    ) -> Poll<Result<Reason, crate::Error>> {
        match stream.state.ensure_reason(mode)? {
            Some(reason) => Poll::Ready(Ok(reason)),
            None => {
                stream.wait_send(cx);
                Poll::Pending
            }
        }
    }

    /// Applies a connection-level WINDOW_UPDATE from the peer by passing its
    /// increment to `Prioritize::recv_connection_window_update`.
    pub fn recv_connection_window_update(
        &mut self,
        frame: frame::WindowUpdate,
        store: &mut Store,
        counts: &mut Counts,
    ) -> Result<(), Reason> {
        self.prioritize
            .recv_connection_window_update(frame.size_increment(), store, counts)
    }

    /// Applies a stream-level window increment of `sz` to `stream`.
    ///
    /// If the prioritizer rejects the increment, the stream is reset with
    /// `FLOW_CONTROL_ERROR` (library-initiated) and the prioritizer's error
    /// is returned.
    pub fn recv_stream_window_update<B>(
        &mut self,
        sz: WindowSize,
        buffer: &mut Buffer<Frame<B>>,
        stream: &mut store::Ptr,
        counts: &mut Counts,
        task: &mut Option<Waker>,
    ) -> Result<(), Reason> {
        if let Err(e) = self.prioritize.recv_stream_window_update(sz, stream) {
            tracing::debug!("recv_stream_window_update !!; err={:?}", e);

            self.send_reset(
                Reason::FLOW_CONTROL_ERROR,
                Initiator::Library,
                buffer,
                stream,
                counts,
                task,
            );

            return Err(e);
        }

        Ok(())
    }

    /// Records a GOAWAY from the peer by lowering `max_stream_id` to
    /// `last_stream_id`.
    ///
    /// # Errors
    ///
    /// A library-initiated GOAWAY with `PROTOCOL_ERROR` if `last_stream_id`
    /// exceeds the current `max_stream_id`. This covers a GOAWAY that tries to
    /// raise the value set by an earlier one; RFC 9113 §6.8 forbids increasing
    /// it.
    pub(super) fn recv_go_away(&mut self, last_stream_id: StreamId) -> Result<(), Error> {
        if last_stream_id > self.max_stream_id {
            proto_err!(conn:
                "recv_go_away: last_stream_id ({:?}) > max_stream_id ({:?})",
                last_stream_id, self.max_stream_id,
            );
            return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
        }

        self.max_stream_id = last_stream_id;
        Ok(())
    }

    /// Discards `stream`'s queued outbound frames and reclaims all of its
    /// capacity after an error.
    pub fn handle_error<B>(
        &mut self,
        buffer: &mut Buffer<Frame<B>>,
        stream: &mut store::Ptr,
        counts: &mut Counts,
    ) {
        self.prioritize.clear_queue(buffer, stream);
        self.prioritize.reclaim_all_capacity(stream, counts);
    }

    /// Applies the send-relevant parts of the peer's SETTINGS frame.
    ///
    /// Settings are applied in this order: extended CONNECT, initial window
    /// size, enable-push. If adjusting the windows fails, the error is returned
    /// before the push setting is applied.
    ///
    /// # Errors
    ///
    /// Window-adjustment failures are escalated to library-initiated GOAWAY
    /// errors.
    pub fn apply_remote_settings<B>(
        &mut self,
        settings: &frame::Settings,
        buffer: &mut Buffer<Frame<B>>,
        store: &mut Store,
        counts: &mut Counts,
        task: &mut Option<Waker>,
    ) -> Result<(), Error> {
        if let Some(val) = settings.is_extended_connect_protocol_enabled() {
            self.is_extended_connect_protocol_enabled = val;
        }

        if let Some(val) = settings.initial_window_size() {
            // RFC 9113 §6.9.2: a change to the initial window size adjusts every
            // stream window by the difference between the new and old values.
            let old_val = self.init_window_sz;
            self.init_window_sz = val;

            match val.cmp(&old_val) {
                Ordering::Less => {
                    // Shrink every relevant stream window by `dec`.
                    let dec = old_val - val;
                    tracing::trace!("decrementing all windows; dec={}", dec);

                    // Capacity taken back from streams, returned to the
                    // connection once all streams have been processed.
                    let mut total_reclaimed = 0;
                    store.try_for_each(|mut stream| {
                        let stream = &mut *stream;

                        // Send-closed streams with nothing buffered are skipped;
                        // presumably their window no longer matters since they
                        // will not send again.
                        if stream.state.is_send_closed() && stream.buffered_send_data == 0 {
                            tracing::trace!(
                                "skipping send-closed stream; id={:?}; flow={:?}",
                                stream.id,
                                stream.send_flow
                            );

                            return Ok(());
                        }

                        tracing::trace!(
                            "decrementing stream window; id={:?}; decr={}; flow={:?}",
                            stream.id,
                            dec,
                            stream.send_flow
                        );

                        stream
                            .send_flow
                            .dec_send_window(dec)
                            .map_err(proto::Error::library_go_away)?;

                        // If the stream was assigned more capacity than its new
                        // window permits, claim back the excess.
                        let window_size = stream.send_flow.window_size();
                        let available = stream.send_flow.available().as_size();
                        let reclaimed = if available > window_size {
                            let reclaim = available - window_size;
                            stream
                                .send_flow
                                .claim_capacity(reclaim)
                                .map_err(proto::Error::library_go_away)?;
                            total_reclaimed += reclaim;
                            reclaim
                        } else {
                            0
                        };

                        tracing::trace!(
                            "decremented stream window; id={:?}; decr={}; reclaimed={}; flow={:?}",
                            stream.id,
                            dec,
                            reclaimed,
                            stream.send_flow
                        );

                        Ok::<_, proto::Error>(())
                    })?;

                    // Hand the reclaimed capacity back to the connection.
                    self.prioritize
                        .assign_connection_capacity(total_reclaimed, store, counts);
                }
                Ordering::Greater => {
                    // A larger initial window is applied to every stream as if
                    // the peer had sent a WINDOW_UPDATE of `inc` on each one.
                    let inc = val - old_val;

                    store.try_for_each(|mut stream| {
                        self.recv_stream_window_update(inc, buffer, &mut stream, counts, task)
                            .map_err(Error::library_go_away)
                    })?;
                }
                Ordering::Equal => (),
            }
        }

        if let Some(val) = settings.is_push_enabled() {
            self.is_push_enabled = val
        }

        Ok(())
    }

    /// Clears the prioritizer's pending-capacity, pending-send and
    /// pending-open queues.
    pub fn clear_queues(&mut self, store: &mut Store, counts: &mut Counts) {
        self.prioritize.clear_pending_capacity(store, counts);
        self.prioritize.clear_pending_send(store, counts);
        self.prioritize.clear_pending_open(store, counts);
    }

    /// Fails with `PROTOCOL_ERROR` if `id` is at or beyond the next id this
    /// side would allocate, i.e. it has not been opened yet.
    ///
    /// Once the id space has overflowed, every id is accepted.
    pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> {
        if let Ok(next) = self.next_stream_id {
            if id >= next {
                return Err(Reason::PROTOCOL_ERROR);
            }
        }
        Ok(())
    }

    /// Returns the next stream id to allocate without consuming it.
    ///
    /// # Errors
    ///
    /// `UserError::OverflowedStreamId` once the stream-id space is exhausted.
    pub fn ensure_next_stream_id(&self) -> Result<StreamId, UserError> {
        self.next_stream_id
            .map_err(|_| UserError::OverflowedStreamId)
    }

    /// Returns whether `id` is below the next id to be allocated, i.e. whether
    /// this side could already have created it.
    ///
    /// Always `true` once the id space has overflowed. Debug builds assert that
    /// `id` has the same initiator (client/server) as the local id counter.
    pub fn may_have_created_stream(&self, id: StreamId) -> bool {
        if let Ok(next_id) = self.next_stream_id {
            debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated(),);
            id < next_id
        } else {
            true
        }
    }

    /// Advances the id counter past `id` when `id` is at or beyond the current
    /// next id. Otherwise, or once the id space has overflowed, does nothing.
    ///
    /// Debug builds assert that `id` has the same initiator as the counter.
    pub(super) fn maybe_reset_next_stream_id(&mut self, id: StreamId) {
        if let Ok(next_id) = self.next_stream_id {
            debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated());
            if id >= next_id {
                self.next_stream_id = id.next_id();
            }
        }
    }

    /// Whether the peer's SETTINGS have enabled the extended CONNECT protocol.
    pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
        self.is_extended_connect_protocol_enabled
    }
}