hyper/proto/h2/
upgrade.rs

1use std::future::Future;
2use std::io::Cursor;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::{Buf, Bytes};
7use futures_channel::{mpsc, oneshot};
8use futures_core::{ready, Stream};
9use h2::{Reason, RecvStream, SendStream};
10use pin_project_lite::pin_project;
11
12use super::ping::Recorder;
13use super::SendBuf;
14use crate::rt::{Read, ReadBufCursor, Write};
15
16pub(super) fn pair<B>(
17    send_stream: SendStream<SendBuf<B>>,
18    recv_stream: RecvStream,
19    ping: Recorder,
20) -> (H2Upgraded, UpgradedSendStreamTask<B>) {
21    let (tx, rx) = mpsc::channel(1);
22    let (error_tx, error_rx) = oneshot::channel();
23
24    (
25        H2Upgraded {
26            send_stream: UpgradedSendStreamBridge { tx, error_rx },
27            recv_stream,
28            ping,
29            buf: Bytes::new(),
30        },
31        UpgradedSendStreamTask {
32            h2_tx: send_stream,
33            rx,
34            error_tx: Some(error_tx),
35        },
36    )
37}
38
39pub(super) struct H2Upgraded {
40    ping: Recorder,
41    send_stream: UpgradedSendStreamBridge,
42    recv_stream: RecvStream,
43    buf: Bytes,
44}
45
46struct UpgradedSendStreamBridge {
47    tx: mpsc::Sender<Cursor<Box<[u8]>>>,
48    error_rx: oneshot::Receiver<crate::Error>,
49}
50
51pin_project! {
52    #[must_use = "futures do nothing unless polled"]
53    pub struct UpgradedSendStreamTask<B> {
54        #[pin]
55        h2_tx: SendStream<SendBuf<B>>,
56        #[pin]
57        rx: mpsc::Receiver<Cursor<Box<[u8]>>>,
58        error_tx: Option<oneshot::Sender<crate::Error>>,
59    }
60}
61
62// ===== impl UpgradedSendStreamTask =====
63
64impl<B> UpgradedSendStreamTask<B>
65where
66    B: Buf,
67{
68    fn tick(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
69        let mut me = self.project();
70
71        // this is a manual `select()` over 3 "futures", so we always need
72        // to be sure they are ready and/or we are waiting notification of
73        // one of the sides hanging up, so the task doesn't live around
74        // longer than it's meant to.
75        loop {
76            // we don't have the next chunk of data yet, so just reserve 1 byte to make
77            // sure there's some capacity available. h2 will handle the capacity management
78            // for the actual body chunk.
79            me.h2_tx.reserve_capacity(1);
80
81            if me.h2_tx.capacity() == 0 {
82                // poll_capacity oddly needs a loop
83                'capacity: loop {
84                    match me.h2_tx.poll_capacity(cx) {
85                        Poll::Ready(Some(Ok(0))) => {}
86                        Poll::Ready(Some(Ok(_))) => break,
87                        Poll::Ready(Some(Err(e))) => {
88                            return Poll::Ready(Err(crate::Error::new_body_write(e)))
89                        }
90                        Poll::Ready(None) => {
91                            // None means the stream is no longer in a
92                            // streaming state, we either finished it
93                            // somehow, or the remote reset us.
94                            return Poll::Ready(Err(crate::Error::new_body_write(
95                                "send stream capacity unexpectedly closed",
96                            )));
97                        }
98                        Poll::Pending => break 'capacity,
99                    }
100                }
101            }
102
103            match me.h2_tx.poll_reset(cx) {
104                Poll::Ready(Ok(reason)) => {
105                    trace!("stream received RST_STREAM: {:?}", reason);
106                    return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from(
107                        reason,
108                    ))));
109                }
110                Poll::Ready(Err(err)) => {
111                    return Poll::Ready(Err(crate::Error::new_body_write(err)))
112                }
113                Poll::Pending => (),
114            }
115
116            match me.rx.as_mut().poll_next(cx) {
117                Poll::Ready(Some(cursor)) => {
118                    me.h2_tx
119                        .send_data(SendBuf::Cursor(cursor), false)
120                        .map_err(crate::Error::new_body_write)?;
121                }
122                Poll::Ready(None) => {
123                    me.h2_tx
124                        .send_data(SendBuf::None, true)
125                        .map_err(crate::Error::new_body_write)?;
126                    return Poll::Ready(Ok(()));
127                }
128                Poll::Pending => {
129                    return Poll::Pending;
130                }
131            }
132        }
133    }
134}
135
136impl<B> Future for UpgradedSendStreamTask<B>
137where
138    B: Buf,
139{
140    type Output = ();
141
142    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
143        match self.as_mut().tick(cx) {
144            Poll::Ready(Ok(())) => Poll::Ready(()),
145            Poll::Ready(Err(err)) => {
146                if let Some(tx) = self.error_tx.take() {
147                    let _oh_well = tx.send(err);
148                }
149                Poll::Ready(())
150            }
151            Poll::Pending => Poll::Pending,
152        }
153    }
154}
155
156// ===== impl H2Upgraded =====
157
158impl Read for H2Upgraded {
159    fn poll_read(
160        mut self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162        mut read_buf: ReadBufCursor<'_>,
163    ) -> Poll<Result<(), std::io::Error>> {
164        if self.buf.is_empty() {
165            self.buf = loop {
166                match ready!(self.recv_stream.poll_data(cx)) {
167                    None => return Poll::Ready(Ok(())),
168                    Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => {
169                        continue
170                    }
171                    Some(Ok(buf)) => {
172                        self.ping.record_data(buf.len());
173                        break buf;
174                    }
175                    Some(Err(e)) => {
176                        return Poll::Ready(match e.reason() {
177                            Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()),
178                            Some(Reason::STREAM_CLOSED) => {
179                                Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
180                            }
181                            _ => Err(h2_to_io_error(e)),
182                        })
183                    }
184                }
185            };
186        }
187        let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
188        read_buf.put_slice(&self.buf[..cnt]);
189        self.buf.advance(cnt);
190        let _ = self.recv_stream.flow_control().release_capacity(cnt);
191        Poll::Ready(Ok(()))
192    }
193}
194
195impl Write for H2Upgraded {
196    fn poll_write(
197        mut self: Pin<&mut Self>,
198        cx: &mut Context<'_>,
199        buf: &[u8],
200    ) -> Poll<Result<usize, std::io::Error>> {
201        if buf.is_empty() {
202            return Poll::Ready(Ok(0));
203        }
204
205        match self.send_stream.tx.poll_ready(cx) {
206            Poll::Ready(Ok(())) => {}
207            Poll::Ready(Err(_task_dropped)) => {
208                // if the task dropped, check if there was an error
209                // otherwise i guess its a broken pipe
210                return match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
211                    Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
212                    Poll::Ready(Err(_task_dropped)) => {
213                        Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
214                    }
215                    Poll::Pending => Poll::Pending,
216                };
217            }
218            Poll::Pending => return Poll::Pending,
219        }
220
221        let n = buf.len();
222        match self.send_stream.tx.start_send(Cursor::new(buf.into())) {
223            Ok(()) => Poll::Ready(Ok(n)),
224            Err(_task_dropped) => {
225                // if the task dropped, check if there was an error
226                // otherwise i guess its a broken pipe
227                match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
228                    Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
229                    Poll::Ready(Err(_task_dropped)) => {
230                        Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()))
231                    }
232                    Poll::Pending => Poll::Pending,
233                }
234            }
235        }
236    }
237
238    fn poll_flush(
239        mut self: Pin<&mut Self>,
240        cx: &mut Context<'_>,
241    ) -> Poll<Result<(), std::io::Error>> {
242        match self.send_stream.tx.poll_ready(cx) {
243            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
244            Poll::Ready(Err(_task_dropped)) => {
245                // if the task dropped, check if there was an error
246                // otherwise it was a clean close
247                match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
248                    Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
249                    Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())),
250                    Poll::Pending => Poll::Pending,
251                }
252            }
253            Poll::Pending => Poll::Pending,
254        }
255    }
256
257    fn poll_shutdown(
258        mut self: Pin<&mut Self>,
259        cx: &mut Context<'_>,
260    ) -> Poll<Result<(), std::io::Error>> {
261        self.send_stream.tx.close_channel();
262        match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
263            Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
264            Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())),
265            Poll::Pending => Poll::Pending,
266        }
267    }
268}
269
270fn io_error(e: crate::Error) -> std::io::Error {
271    std::io::Error::new(std::io::ErrorKind::Other, e)
272}
273
274fn h2_to_io_error(e: h2::Error) -> std::io::Error {
275    if e.is_io() {
276        e.into_io().unwrap()
277    } else {
278        std::io::Error::new(std::io::ErrorKind::Other, e)
279    }
280}