async_tungstenite/handshake.rs

1use crate::compat::AllowStd;
2#[cfg(feature = "handshake")]
3use crate::compat::SetWaker;
4use crate::WebSocketStream;
5use futures_io::{AsyncRead, AsyncWrite};
6#[allow(unused_imports)]
7use log::*;
8use std::future::Future;
9use std::io::{Read, Write};
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use tungstenite::WebSocket;
13#[cfg(feature = "handshake")]
14use tungstenite::{
15    handshake::{
16        client::Response, server::Callback, HandshakeError as Error, HandshakeRole,
17        MidHandshake as WsHandshake,
18    },
19    ClientHandshake, ServerHandshake,
20};
21
22pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S>
23where
24    F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
25    S: AsyncRead + AsyncWrite + Unpin,
26{
27    let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream }));
28
29    let ws = start.await;
30
31    WebSocketStream::new(ws)
32}
33
/// State consumed by [`SkippedHandshakeFuture`] when it is polled.
struct SkippedHandshakeFutureInner<F, S> {
    /// Constructor that builds the WebSocket from the `AllowStd`-wrapped stream.
    f: F,
    /// The raw async stream, not yet wrapped in `AllowStd`.
    stream: S,
}

/// Future that wraps a stream into a WebSocket without performing a handshake.
///
/// The `Option` is taken (left as `None`) once the future has been polled.
struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>);
39
40impl<F, S> Future for SkippedHandshakeFuture<F, S>
41where
42    F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
43    S: Unpin,
44    AllowStd<S>: Read + Write,
45{
46    type Output = WebSocket<AllowStd<S>>;
47
48    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
49        let inner = self
50            .get_mut()
51            .0
52            .take()
53            .expect("future polled after completion");
54        #[cfg(feature = "verbose-logging")]
55        trace!("Setting context when skipping handshake");
56        let stream = AllowStd::new(inner.stream, ctx.waker());
57
58        Poll::Ready((inner.f)(stream))
59    }
60}
61
/// Arguments for [`StartedHandshakeFuture`]: the closure that begins the handshake and
/// the raw stream it runs over.
#[cfg(feature = "handshake")]
struct StartedHandshakeFutureInner<F, S> {
    f: F,
    stream: S,
}

/// Future that begins a handshake on its first poll by invoking the stored closure.
///
/// The `Option` is taken (left as `None`) once the future has been polled.
#[cfg(feature = "handshake")]
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);

/// Result of starting a handshake that did not fail outright.
#[cfg(feature = "handshake")]
enum StartedHandshake<Role: HandshakeRole> {
    /// The handshake finished without being interrupted.
    Done(Role::FinalResult),
    /// The handshake was interrupted; this state can be resumed to continue it.
    Mid(WsHandshake<Role>),
}

/// Future that keeps resuming an interrupted handshake until it completes or fails.
///
/// The `Option` is empty only while a poll is in progress and after completion.
#[cfg(feature = "handshake")]
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>);
78
/// Runs a tungstenite handshake with role `Role` over an async `stream`.
///
/// `f` begins the handshake on the [`AllowStd`]-wrapped stream. If it finishes at once,
/// its result is returned directly. If it is interrupted, the saved partial handshake is
/// driven by [`MidHandshake`] until it completes or fails.
///
/// # Errors
///
/// Returns `Error::Failure` if either phase of the handshake fails.
#[cfg(feature = "handshake")]
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
    Role: HandshakeRole + Unpin,
    Role::InternalStream: SetWaker + Unpin,
    F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let inner = StartedHandshakeFutureInner { f, stream };
    let started: StartedHandshake<Role> = StartedHandshakeFuture(Some(inner)).await?;

    match started {
        StartedHandshake::Done(result) => Ok(result),
        // Interrupted: keep polling the saved mid-handshake state until it settles.
        StartedHandshake::Mid(mid) => MidHandshake::<Role>(Some(mid)).await,
    }
}
97
/// Performs the client side of a WebSocket handshake over `stream`.
///
/// `f` starts the handshake on the [`AllowStd`]-wrapped stream, and the handshake is then
/// driven to completion asynchronously. On success the resulting socket is wrapped in a
/// [`WebSocketStream`] and returned together with the handshake [`Response`].
///
/// # Errors
///
/// Propagates the handshake error if the handshake fails.
#[cfg(feature = "handshake")]
pub(crate) async fn client_handshake<F, S>(
    stream: S,
    f: F,
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>>
where
    F: FnOnce(
            AllowStd<S>,
        ) -> Result<
            <ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult,
            Error<ClientHandshake<AllowStd<S>>>,
        > + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let (socket, response) = handshake(stream, f).await?;
    Ok((WebSocketStream::new(socket), response))
}
116
/// Performs the server side of a WebSocket handshake over `stream`.
///
/// `C` is the [`Callback`] type carried by the server handshake role, and `f` starts the
/// handshake on the [`AllowStd`]-wrapped stream. On success the resulting
/// [`WebSocket`] is wrapped in a [`WebSocketStream`].
///
/// # Errors
///
/// Propagates the handshake error if the handshake fails.
#[cfg(feature = "handshake")]
pub(crate) async fn server_handshake<C, F, S>(
    stream: S,
    f: F,
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>>
where
    C: Callback + Unpin,
    F: FnOnce(
            AllowStd<S>,
        ) -> Result<
            <ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult,
            Error<ServerHandshake<AllowStd<S>, C>>,
        > + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // The error type already matches the return type, so `map` is equivalent to `?` + `Ok`.
    handshake(stream, f).await.map(WebSocketStream::new)
}
135
#[cfg(feature = "handshake")]
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
    Role: HandshakeRole,
    Role::InternalStream: SetWaker,
    F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
    S: Unpin,
    AllowStd<S>: Read + Write,
{
    type Output = Result<StartedHandshake<Role>, Error<Role>>;

    /// Starts the handshake on the first poll and always completes immediately.
    ///
    /// An interrupted handshake is reported as `Ok(StartedHandshake::Mid(..))` so the
    /// caller can resume it; only `Error::Failure` is surfaced as an error.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned [`Poll::Ready`].
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // `F` and `S` are `Unpin`, so unpinning the future is sound.
        let StartedHandshakeFutureInner { f, stream } = self
            .get_mut()
            .0
            .take()
            .expect("future polled after completion");
        #[cfg(feature = "verbose-logging")]
        trace!("Setting ctx when starting handshake");
        let wrapped = AllowStd::new(stream, ctx.waker());

        let outcome = match f(wrapped) {
            Ok(done) => Ok(StartedHandshake::Done(done)),
            Err(Error::Interrupted(mid)) => Ok(StartedHandshake::Mid(mid)),
            Err(Error::Failure(e)) => Err(Error::Failure(e)),
        };
        Poll::Ready(outcome)
    }
}
160
#[cfg(feature = "handshake")]
impl<Role> Future for MidHandshake<Role>
where
    Role: HandshakeRole + Unpin,
    Role::InternalStream: SetWaker + Unpin,
{
    type Output = Result<Role::FinalResult, Error<Role>>;

    /// Resumes the interrupted handshake once per poll.
    ///
    /// The task's waker is installed on the underlying stream before resuming. If the
    /// handshake is interrupted again, its new state is stored and `Poll::Pending` is
    /// returned; otherwise the final result or `Error::Failure` is returned.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned [`Poll::Ready`].
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut current = self.0.take().expect("future polled after completion");

        #[cfg(feature = "verbose-logging")]
        trace!("Setting context in handshake");
        // Register this task's waker on the stream so the next wake-up re-polls us.
        current.get_mut().get_mut().set_waker(cx.waker());

        match current.handshake() {
            Err(Error::Interrupted(resumable)) => {
                // Park the partial handshake until the stream wakes us again.
                self.0 = Some(resumable);
                Poll::Pending
            }
            Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
            Ok(stream) => Poll::Ready(Ok(stream)),
        }
    }
}