1use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::future::Future;
46use std::io;
47use std::pin::Pin;
48use std::sync::{Arc, Mutex};
49use std::task::{Context, Poll};
50
51use crate::rt::{Read, ReadBufCursor, Write};
52use bytes::Bytes;
53use tokio::sync::oneshot;
54
55use crate::common::io::Rewind;
56use crate::common::lock::LockResultExt;
57
58pub struct Upgraded {
67 io: Rewind<Box<dyn Io + Send>>,
68}
69
70#[derive(Clone)]
74pub struct OnUpgrade {
75 rx: Option<Arc<Mutex<oneshot::Receiver<crate::Result<Upgraded>>>>>,
76}
77
78#[derive(Debug)]
83#[non_exhaustive]
84pub struct Parts<T> {
85 pub io: T,
87 pub read_buf: Bytes,
96}
97
98pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
107 msg.on_upgrade()
108}
109
/// Sending half of an upgrade, created together with its `OnUpgrade` by
/// `pending()`.
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
pub(super) struct Pending {
    /// Delivers the upgrade outcome to the paired `OnUpgrade`.
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
117
/// Creates a linked `(Pending, OnUpgrade)` pair over a fresh oneshot channel.
///
/// The receiver is wrapped in `Arc<Mutex<..>>` so `OnUpgrade` can be cloned.
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let shared_receiver = Arc::new(Mutex::new(receiver));
    let on_upgrade = OnUpgrade {
        rx: Some(shared_receiver),
    };
    (Pending { tx: sender }, on_upgrade)
}
131
132impl Upgraded {
135 #[cfg(all(
136 any(feature = "client", feature = "server"),
137 any(feature = "http1", feature = "http2")
138 ))]
139 pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
140 where
141 T: Read + Write + Unpin + Send + 'static,
142 {
143 Upgraded {
144 io: Rewind::new_buffered(Box::new(io), read_buf),
145 }
146 }
147
148 pub fn downcast<T: Read + Write + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
153 let (io, buf) = self.io.into_inner();
154 match io.__hyper_downcast() {
155 Ok(t) => Ok(Parts {
156 io: *t,
157 read_buf: buf,
158 }),
159 Err(io) => Err(Upgraded {
160 io: Rewind::new_buffered(io, buf),
161 }),
162 }
163 }
164}
165
166impl Read for Upgraded {
167 fn poll_read(
168 mut self: Pin<&mut Self>,
169 cx: &mut Context<'_>,
170 buf: ReadBufCursor<'_>,
171 ) -> Poll<io::Result<()>> {
172 Pin::new(&mut self.io).poll_read(cx, buf)
173 }
174}
175
176impl Write for Upgraded {
177 fn poll_write(
178 mut self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &[u8],
181 ) -> Poll<io::Result<usize>> {
182 Pin::new(&mut self.io).poll_write(cx, buf)
183 }
184
185 fn poll_write_vectored(
186 mut self: Pin<&mut Self>,
187 cx: &mut Context<'_>,
188 bufs: &[io::IoSlice<'_>],
189 ) -> Poll<io::Result<usize>> {
190 Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
191 }
192
193 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194 Pin::new(&mut self.io).poll_flush(cx)
195 }
196
197 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198 Pin::new(&mut self.io).poll_shutdown(cx)
199 }
200
201 fn is_write_vectored(&self) -> bool {
202 self.io.is_write_vectored()
203 }
204}
205
206impl fmt::Debug for Upgraded {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 f.debug_struct("Upgraded").finish()
209 }
210}
211
212impl OnUpgrade {
215 pub(super) fn none() -> Self {
216 OnUpgrade { rx: None }
217 }
218
219 #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))]
220 pub(super) fn is_none(&self) -> bool {
221 self.rx.is_none()
222 }
223}
224
225impl Future for OnUpgrade {
226 type Output = Result<Upgraded, crate::Error>;
227
228 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
229 match self.rx {
230 Some(ref rx) => Pin::new(&mut *rx.lock().panic_if_poisoned())
231 .poll(cx)
232 .map(|res| match res {
233 Ok(Ok(upgraded)) => Ok(upgraded),
234 Ok(Err(err)) => Err(err),
235 Err(_oneshot_canceled) => {
236 Err(crate::Error::new_canceled().with(UpgradeExpected))
237 }
238 }),
239 None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
240 }
241 }
242}
243
244impl fmt::Debug for OnUpgrade {
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 f.debug_struct("OnUpgrade").finish()
247 }
248}
249
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2")
))]
impl Pending {
    /// Completes the upgrade by sending `upgraded` to the paired `OnUpgrade`.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        // `send` only fails when the receiving side is gone, i.e. nobody is
        // waiting for the upgrade any more, so the result is ignored on purpose.
        let _ = self.tx.send(Ok(upgraded));
    }

    /// Does not fulfill the upgrade. Instead, the paired `OnUpgrade` resolves to
    /// a `new_user_manual_upgrade()` error, signalling the upgrade is handled
    /// manually.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        // This fn is already gated on `http1`, so the trace needs no extra cfg.
        trace!("pending upgrade handled manually");
        // Ignored for the same reason as in `fulfill`: no receiver, no listener.
        let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade()));
    }
}
271
/// Error cause attached when an `OnUpgrade`'s channel is canceled: the sending
/// `Pending` went away without delivering an upgrade result.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "upgrade expected but not completed";
        f.write_str(MESSAGE)
    }
}

// There is no underlying cause, so the default `source()` (`None`) is kept.
impl StdError for UpgradeExpected {}
288
289pub(super) trait Io: Read + Write + Unpin + 'static {
292 fn __hyper_type_id(&self) -> TypeId {
293 TypeId::of::<Self>()
294 }
295}
296
297impl<T: Read + Write + Unpin + 'static> Io for T {}
298
299impl dyn Io + Send {
300 fn __hyper_is<T: Io>(&self) -> bool {
301 let t = TypeId::of::<T>();
302 self.__hyper_type_id() == t
303 }
304
305 fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
306 if self.__hyper_is::<T>() {
307 unsafe {
309 let raw: *mut dyn Io = Box::into_raw(self);
310 Ok(Box::from_raw(raw as *mut T))
311 }
312 } else {
313 Err(self)
314 }
315 }
316}
317
318mod sealed {
319 use super::OnUpgrade;
320
321 pub trait CanUpgrade {
322 fn on_upgrade(self) -> OnUpgrade;
323 }
324
325 impl<B> CanUpgrade for http::Request<B> {
326 fn on_upgrade(mut self) -> OnUpgrade {
327 self.extensions_mut()
328 .remove::<OnUpgrade>()
329 .unwrap_or_else(OnUpgrade::none)
330 }
331 }
332
333 impl<B> CanUpgrade for &'_ mut http::Request<B> {
334 fn on_upgrade(self) -> OnUpgrade {
335 self.extensions_mut()
336 .remove::<OnUpgrade>()
337 .unwrap_or_else(OnUpgrade::none)
338 }
339 }
340
341 impl<B> CanUpgrade for http::Response<B> {
342 fn on_upgrade(mut self) -> OnUpgrade {
343 self.extensions_mut()
344 .remove::<OnUpgrade>()
345 .unwrap_or_else(OnUpgrade::none)
346 }
347 }
348
349 impl<B> CanUpgrade for &'_ mut http::Response<B> {
350 fn on_upgrade(self) -> OnUpgrade {
351 self.extensions_mut()
352 .remove::<OnUpgrade>()
353 .unwrap_or_else(OnUpgrade::none)
354 }
355 }
356}
357
#[cfg(all(
    any(feature = "client", feature = "server"),
    any(feature = "http1", feature = "http2"),
))]
#[cfg(test)]
mod tests {
    use super::*;

    /// A wrong-type downcast hands the `Upgraded` back; the right type succeeds.
    #[test]
    fn upgraded_downcast() {
        let upgraded = Upgraded::new(Mock, Bytes::new());

        let upgraded = upgraded
            .downcast::<crate::common::io::Compat<std::io::Cursor<Vec<u8>>>>()
            .unwrap_err();

        upgraded.downcast::<Mock>().unwrap();
    }

    /// The buffered bytes must survive a failed downcast (which rebuilds the
    /// `Upgraded`) and come out in `Parts::read_buf` on the successful one.
    #[test]
    fn upgraded_downcast_preserves_read_buf() {
        let upgraded = Upgraded::new(Mock, Bytes::from_static(b"leftover"));

        let upgraded = upgraded
            .downcast::<crate::common::io::Compat<std::io::Cursor<Vec<u8>>>>()
            .unwrap_err();

        let parts = upgraded.downcast::<Mock>().unwrap();
        assert_eq!(parts.read_buf, Bytes::from_static(b"leftover"));
    }

    /// Minimal IO stand-in: writes succeed, everything else must not be called.
    struct Mock;

    impl Read for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: ReadBufCursor<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl Write for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            // Pretend the whole buffer was written.
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}