// hyper/upgrade.rs

//! HTTP Upgrades.
//!
//! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since
//! several concepts in HTTP allow for first talking HTTP, and then converting
//! to a different protocol, this module conflates them into a single API.
//! Those include:
//!
//! - HTTP/1.1 Upgrades
//! - HTTP `CONNECT`
//!
//! You are responsible for any other prerequisites to establish an upgrade,
//! such as sending the appropriate headers, methods, and status codes. You can
//! then use [`on`][] to grab a `Future` which will resolve to the upgraded
//! connection object, or an error if the upgrade fails.
//!
//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism
//!
//! # Client
//!
//! Sending an HTTP upgrade from the [`client`](super::client) involves setting
//! either the appropriate method, when using `CONNECT`, or headers such as
//! `Upgrade` and `Connection`, on the `http::Request`. Once you receive the
//! `http::Response`, you must check that the server agreed to the upgrade
//! (such as with a `101` status code), and then get the `Future` from the
//! `Response`.
//!
//! # Server
//!
//! Receiving upgrade requests in a server requires you to check the relevant
//! headers in a `Request`, and if an upgrade should be done, you then send the
//! corresponding headers in a response. To then wait for hyper to finish the
//! upgrade, you call `on()` with the `Request`, and then can spawn a task
//! awaiting it.
//!
//! # Example
//!
//! See [this example][example] showing how upgrades work with both
//! Clients and Servers.
//!
//! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs

42use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::future::Future;
46use std::io;
47use std::pin::Pin;
48use std::sync::{Arc, Mutex};
49use std::task::{Context, Poll};
50
51use crate::rt::{Read, ReadBufCursor, Write};
52use bytes::Bytes;
53use tokio::sync::oneshot;
54
55use crate::common::io::Rewind;
56use crate::common::lock::LockResultExt;
57
58/// An upgraded HTTP connection.
59///
60/// This type holds a trait object internally of the original IO that
61/// was used to speak HTTP before the upgrade. It can be used directly
62/// as a [`Read`] or [`Write`] for convenience.
63///
64/// Alternatively, if the exact type is known, this can be deconstructed
65/// into its parts.
66pub struct Upgraded {
67    io: Rewind<Box<dyn Io + Send>>,
68}
69
70/// A future for a possible HTTP upgrade.
71///
72/// If no upgrade was available, or it doesn't succeed, yields an `Error`.
73#[derive(Clone)]
74pub struct OnUpgrade {
75    rx: Option<Arc<Mutex<oneshot::Receiver<crate::Result<Upgraded>>>>>,
76}
77
78/// The deconstructed parts of an [`Upgraded`] type.
79///
80/// Includes the original IO type, and a read buffer of bytes that the
81/// HTTP state machine may have already read before completing an upgrade.
82#[derive(Debug)]
83#[non_exhaustive]
84pub struct Parts<T> {
85    /// The original IO object used before the upgrade.
86    pub io: T,
87    /// A buffer of bytes that have been read but not processed as HTTP.
88    ///
89    /// For instance, if the `Connection` is used for an HTTP upgrade request,
90    /// it is possible the server sent back the first bytes of the new protocol
91    /// along with the response upgrade.
92    ///
93    /// You will want to check for any existing bytes if you plan to continue
94    /// communicating on the IO object.
95    pub read_buf: Bytes,
96}
97
98/// Gets a pending HTTP upgrade from this message.
99///
100/// This can be called on the following types:
101///
102/// - `http::Request<B>`
103/// - `http::Response<B>`
104/// - `&mut http::Request<B>`
105/// - `&mut http::Response<B>`
106pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
107    msg.on_upgrade()
108}
109
/// The sending half of a pending upgrade.
///
/// Consuming it (see `Pending::fulfill` / `Pending::manual`) delivers the
/// outcome to the paired [`OnUpgrade`].
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
pub(super) struct Pending {
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
117
/// Creates a linked `(Pending, OnUpgrade)` pair connected by a oneshot
/// channel: whatever the `Pending` half sends is what the `OnUpgrade` future
/// resolves to.
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let shared_rx = Arc::new(Mutex::new(receiver));
    let on_upgrade = OnUpgrade { rx: Some(shared_rx) };
    (Pending { tx: sender }, on_upgrade)
}
131
132// ===== impl Upgraded =====
133
134impl Upgraded {
135    #[cfg(all(
136        any(feature = "client", feature = "server"),
137        any(feature = "http1", feature = "http2")
138    ))]
139    pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
140    where
141        T: Read + Write + Unpin + Send + 'static,
142    {
143        Upgraded {
144            io: Rewind::new_buffered(Box::new(io), read_buf),
145        }
146    }
147
148    /// Tries to downcast the internal trait object to the type passed.
149    ///
150    /// On success, returns the downcasted parts. On error, returns the
151    /// `Upgraded` back.
152    pub fn downcast<T: Read + Write + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
153        let (io, buf) = self.io.into_inner();
154        match io.__hyper_downcast() {
155            Ok(t) => Ok(Parts {
156                io: *t,
157                read_buf: buf,
158            }),
159            Err(io) => Err(Upgraded {
160                io: Rewind::new_buffered(io, buf),
161            }),
162        }
163    }
164}
165
166impl Read for Upgraded {
167    fn poll_read(
168        mut self: Pin<&mut Self>,
169        cx: &mut Context<'_>,
170        buf: ReadBufCursor<'_>,
171    ) -> Poll<io::Result<()>> {
172        Pin::new(&mut self.io).poll_read(cx, buf)
173    }
174}
175
176impl Write for Upgraded {
177    fn poll_write(
178        mut self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &[u8],
181    ) -> Poll<io::Result<usize>> {
182        Pin::new(&mut self.io).poll_write(cx, buf)
183    }
184
185    fn poll_write_vectored(
186        mut self: Pin<&mut Self>,
187        cx: &mut Context<'_>,
188        bufs: &[io::IoSlice<'_>],
189    ) -> Poll<io::Result<usize>> {
190        Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
191    }
192
193    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194        Pin::new(&mut self.io).poll_flush(cx)
195    }
196
197    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198        Pin::new(&mut self.io).poll_shutdown(cx)
199    }
200
201    fn is_write_vectored(&self) -> bool {
202        self.io.is_write_vectored()
203    }
204}
205
206impl fmt::Debug for Upgraded {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        f.debug_struct("Upgraded").finish()
209    }
210}
211
212// ===== impl OnUpgrade =====
213
214impl OnUpgrade {
215    pub(super) fn none() -> Self {
216        OnUpgrade { rx: None }
217    }
218
219    #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))]
220    pub(super) fn is_none(&self) -> bool {
221        self.rx.is_none()
222    }
223}
224
225impl Future for OnUpgrade {
226    type Output = Result<Upgraded, crate::Error>;
227
228    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
229        match self.rx {
230            Some(ref rx) => Pin::new(&mut *rx.lock().panic_if_poisoned())
231                .poll(cx)
232                .map(|res| match res {
233                    Ok(Ok(upgraded)) => Ok(upgraded),
234                    Ok(Err(err)) => Err(err),
235                    Err(_oneshot_canceled) => {
236                        Err(crate::Error::new_canceled().with(UpgradeExpected))
237                    }
238                }),
239            None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
240        }
241    }
242}
243
244impl fmt::Debug for OnUpgrade {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        f.debug_struct("OnUpgrade").finish()
247    }
248}
249
// ===== impl Pending =====

#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2")
))]
impl Pending {
    /// Completes the upgrade by sending `upgraded` to the paired
    /// [`OnUpgrade`].
    ///
    /// A failed send (for example, when no receiver remains) is deliberately
    /// ignored.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        let _ = self.tx.send(Ok(upgraded));
    }

    /// Doesn't fulfill the pending upgrade. Instead it signals that upgrades
    /// are handled manually: the paired [`OnUpgrade`] resolves to a "manual
    /// upgrade" user error.
    ///
    /// A failed send is deliberately ignored, as in `fulfill`.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        trace!("pending upgrade handled manually");
        let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade()));
    }
}
271
// ===== impl UpgradeExpected =====

/// Error cause attached when an upgrade was expected, but the sending side
/// went away without completing it.
///
/// This likely means the actual `Conn` future wasn't polled and upgraded.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

// No underlying source: this type is itself the root cause.
impl StdError for UpgradeExpected {}
288
289// ===== impl Io =====
290
291pub(super) trait Io: Read + Write + Unpin + 'static {
292    fn __hyper_type_id(&self) -> TypeId {
293        TypeId::of::<Self>()
294    }
295}
296
297impl<T: Read + Write + Unpin + 'static> Io for T {}
298
299impl dyn Io + Send {
300    fn __hyper_is<T: Io>(&self) -> bool {
301        let t = TypeId::of::<T>();
302        self.__hyper_type_id() == t
303    }
304
305    fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
306        if self.__hyper_is::<T>() {
307            // Taken from `std::error::Error::downcast()`.
308            unsafe {
309                let raw: *mut dyn Io = Box::into_raw(self);
310                Ok(Box::from_raw(raw as *mut T))
311            }
312        } else {
313            Err(self)
314        }
315    }
316}
317
318mod sealed {
319    use super::OnUpgrade;
320
321    pub trait CanUpgrade {
322        fn on_upgrade(self) -> OnUpgrade;
323    }
324
325    impl<B> CanUpgrade for http::Request<B> {
326        fn on_upgrade(mut self) -> OnUpgrade {
327            self.extensions_mut()
328                .remove::<OnUpgrade>()
329                .unwrap_or_else(OnUpgrade::none)
330        }
331    }
332
333    impl<B> CanUpgrade for &'_ mut http::Request<B> {
334        fn on_upgrade(self) -> OnUpgrade {
335            self.extensions_mut()
336                .remove::<OnUpgrade>()
337                .unwrap_or_else(OnUpgrade::none)
338        }
339    }
340
341    impl<B> CanUpgrade for http::Response<B> {
342        fn on_upgrade(mut self) -> OnUpgrade {
343            self.extensions_mut()
344                .remove::<OnUpgrade>()
345                .unwrap_or_else(OnUpgrade::none)
346        }
347    }
348
349    impl<B> CanUpgrade for &'_ mut http::Response<B> {
350        fn on_upgrade(self) -> OnUpgrade {
351            self.extensions_mut()
352                .remove::<OnUpgrade>()
353                .unwrap_or_else(OnUpgrade::none)
354        }
355    }
356}
357
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn upgraded_downcast() {
        let upgraded = Upgraded::new(Mock, Bytes::new());

        let upgraded = upgraded
            .downcast::<crate::common::io::Compat<std::io::Cursor<Vec<u8>>>>()
            .unwrap_err();

        upgraded.downcast::<Mock>().unwrap();
    }

    #[test]
    fn upgraded_downcast_preserves_read_buf() {
        let upgraded = Upgraded::new(Mock, Bytes::from_static(b"early bytes"));

        // A failed downcast must hand back an `Upgraded` that still owns the
        // buffered bytes...
        let upgraded = upgraded
            .downcast::<crate::common::io::Compat<std::io::Cursor<Vec<u8>>>>()
            .unwrap_err();

        // ...so a later successful downcast exposes them in `Parts::read_buf`.
        let parts = upgraded.downcast::<Mock>().unwrap();
        assert_eq!(parts.read_buf, Bytes::from_static(b"early bytes"));
    }

    // TODO: replace with tokio_test::io when it can test write_buf
    struct Mock;

    impl Read for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: ReadBufCursor<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl Write for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            // Pretend every write fully succeeds.
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}