hyper/proto/h2/
upgrade.rs1use std::future::Future;
2use std::io::Cursor;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::{Buf, Bytes};
7use futures_channel::{mpsc, oneshot};
8use futures_core::{ready, Stream};
9use h2::{Reason, RecvStream, SendStream};
10use pin_project_lite::pin_project;
11
12use super::ping::Recorder;
13use super::SendBuf;
14use crate::rt::{Read, ReadBufCursor, Write};
15
16pub(super) fn pair<B>(
17 send_stream: SendStream<SendBuf<B>>,
18 recv_stream: RecvStream,
19 ping: Recorder,
20) -> (H2Upgraded, UpgradedSendStreamTask<B>) {
21 let (tx, rx) = mpsc::channel(1);
22 let (error_tx, error_rx) = oneshot::channel();
23
24 (
25 H2Upgraded {
26 send_stream: UpgradedSendStreamBridge { tx, error_rx },
27 recv_stream,
28 ping,
29 buf: Bytes::new(),
30 },
31 UpgradedSendStreamTask {
32 h2_tx: send_stream,
33 rx,
34 error_tx: Some(error_tx),
35 },
36 )
37}
38
39pub(super) struct H2Upgraded {
40 ping: Recorder,
41 send_stream: UpgradedSendStreamBridge,
42 recv_stream: RecvStream,
43 buf: Bytes,
44}
45
46struct UpgradedSendStreamBridge {
47 tx: mpsc::Sender<Cursor<Box<[u8]>>>,
48 error_rx: oneshot::Receiver<crate::Error>,
49}
50
51pin_project! {
52 #[must_use = "futures do nothing unless polled"]
53 pub struct UpgradedSendStreamTask<B> {
54 #[pin]
55 h2_tx: SendStream<SendBuf<B>>,
56 #[pin]
57 rx: mpsc::Receiver<Cursor<Box<[u8]>>>,
58 error_tx: Option<oneshot::Sender<crate::Error>>,
59 }
60}
61
62impl<B> UpgradedSendStreamTask<B>
65where
66 B: Buf,
67{
68 fn tick(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
69 let mut me = self.project();
70
71 loop {
76 me.h2_tx.reserve_capacity(1);
80
81 if me.h2_tx.capacity() == 0 {
82 'capacity: loop {
84 match me.h2_tx.poll_capacity(cx) {
85 Poll::Ready(Some(Ok(0))) => {}
86 Poll::Ready(Some(Ok(_))) => break,
87 Poll::Ready(Some(Err(e))) => {
88 return Poll::Ready(Err(crate::Error::new_body_write(e)))
89 }
90 Poll::Ready(None) => {
91 return Poll::Ready(Err(crate::Error::new_body_write(
95 "send stream capacity unexpectedly closed",
96 )));
97 }
98 Poll::Pending => break 'capacity,
99 }
100 }
101 }
102
103 match me.h2_tx.poll_reset(cx) {
104 Poll::Ready(Ok(reason)) => {
105 trace!("stream received RST_STREAM: {:?}", reason);
106 return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(
107 reason,
108 ))));
109 }
110 Poll::Ready(Err(err)) => {
111 return Poll::Ready(Err(crate::Error::new_body_write(err)))
112 }
113 Poll::Pending => (),
114 }
115
116 match me.rx.as_mut().poll_next(cx) {
117 Poll::Ready(Some(cursor)) => {
118 me.h2_tx
119 .send_data(SendBuf::Cursor(cursor), false)
120 .map_err(crate::Error::new_body_write)?;
121 }
122 Poll::Ready(None) => {
123 me.h2_tx
124 .send_data(SendBuf::None, true)
125 .map_err(crate::Error::new_body_write)?;
126 return Poll::Ready(Ok(()));
127 }
128 Poll::Pending => {
129 return Poll::Pending;
130 }
131 }
132 }
133 }
134}
135
136impl<B> Future for UpgradedSendStreamTask<B>
137where
138 B: Buf,
139{
140 type Output = ();
141
142 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
143 match self.as_mut().tick(cx) {
144 Poll::Ready(Ok(())) => Poll::Ready(()),
145 Poll::Ready(Err(err)) => {
146 if let Some(tx) = self.error_tx.take() {
147 let _oh_well = tx.send(err);
148 }
149 Poll::Ready(())
150 }
151 Poll::Pending => Poll::Pending,
152 }
153 }
154}
155
156impl Read for H2Upgraded {
159 fn poll_read(
160 mut self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 mut read_buf: ReadBufCursor<'_>,
163 ) -> Poll<Result<(), std::io::Error>> {
164 if self.buf.is_empty() {
165 self.buf = loop {
166 match ready!(self.recv_stream.poll_data(cx)) {
167 None => return Poll::Ready(Ok(())),
168 Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => {
169 continue
170 }
171 Some(Ok(buf)) => {
172 self.ping.record_data(buf.len());
173 break buf;
174 }
175 Some(Err(e)) => {
176 return Poll::Ready(match e.reason() {
177 Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()),
178 Some(Reason::STREAM_CLOSED) => {
179 Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
180 }
181 _ => Err(h2_to_io_error(e)),
182 })
183 }
184 }
185 };
186 }
187 let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
188 read_buf.put_slice(&self.buf[..cnt]);
189 self.buf.advance(cnt);
190 let _ = self.recv_stream.flow_control().release_capacity(cnt);
191 Poll::Ready(Ok(()))
192 }
193}
194
195impl Write for H2Upgraded {
196 fn poll_write(
197 mut self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 buf: &[u8],
200 ) -> Poll<Result<usize, std::io::Error>> {
201 if buf.is_empty() {
202 return Poll::Ready(Ok(0));
203 }
204
205 match self.send_stream.tx.poll_ready(cx) {
206 Poll::Ready(Ok(())) => {}
207 Poll::Ready(Err(_task_dropped)) => {
208 return match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
211 Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
212 Poll::Ready(Err(_task_dropped)) => {
213 Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
214 }
215 Poll::Pending => Poll::Pending,
216 };
217 }
218 Poll::Pending => return Poll::Pending,
219 }
220
221 let n = buf.len();
222 match self.send_stream.tx.start_send(Cursor::new(buf.into())) {
223 Ok(()) => Poll::Ready(Ok(n)),
224 Err(_task_dropped) => {
225 match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
228 Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
229 Poll::Ready(Err(_task_dropped)) => {
230 Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
231 }
232 Poll::Pending => Poll::Pending,
233 }
234 }
235 }
236 }
237
238 fn poll_flush(
239 mut self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 ) -> Poll<Result<(), std::io::Error>> {
242 match self.send_stream.tx.poll_ready(cx) {
243 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
244 Poll::Ready(Err(_task_dropped)) => {
245 match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
248 Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
249 Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())),
250 Poll::Pending => Poll::Pending,
251 }
252 }
253 Poll::Pending => Poll::Pending,
254 }
255 }
256
257 fn poll_shutdown(
258 mut self: Pin<&mut Self>,
259 cx: &mut Context<'_>,
260 ) -> Poll<Result<(), std::io::Error>> {
261 self.send_stream.tx.close_channel();
262 match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
263 Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
264 Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())),
265 Poll::Pending => Poll::Pending,
266 }
267 }
268}
269
270fn io_error(e: crate::Error) -> std::io::Error {
271 std::io::Error::new(std::io::ErrorKind::Other, e)
272}
273
274fn h2_to_io_error(e: h2::Error) -> std::io::Error {
275 if e.is_io() {
276 e.into_io().unwrap()
277 } else {
278 std::io::Error::new(std::io::ErrorKind::Other, e)
279 }
280}