1use crate::compat::AllowStd;
2#[cfg(feature = "handshake")]
3use crate::compat::SetWaker;
4use crate::WebSocketStream;
5use futures_io::{AsyncRead, AsyncWrite};
6#[allow(unused_imports)]
7use log::*;
8use std::future::Future;
9use std::io::{Read, Write};
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use tungstenite::WebSocket;
13#[cfg(feature = "handshake")]
14use tungstenite::{
15 handshake::{
16 client::Response, server::Callback, HandshakeError as Error, HandshakeRole,
17 MidHandshake as WsHandshake,
18 },
19 ClientHandshake, ServerHandshake,
20};
21
22pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S>
23where
24 F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
25 S: AsyncRead + AsyncWrite + Unpin,
26{
27 let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream }));
28
29 let ws = start.await;
30
31 WebSocketStream::new(ws)
32}
33
/// Future that builds a [`WebSocket`] directly from a raw stream, skipping the
/// handshake.
///
/// The `Option` holds the inputs until the first poll takes them out. It is
/// `None` afterwards, and polling again panics.
struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>);
/// Inputs consumed by [`SkippedHandshakeFuture`] on its first poll.
struct SkippedHandshakeFutureInner<F, S> {
    // Constructor called with the `AllowStd`-wrapped stream.
    f: F,
    // The raw async stream to wrap.
    stream: S,
}
39
40impl<F, S> Future for SkippedHandshakeFuture<F, S>
41where
42 F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
43 S: Unpin,
44 AllowStd<S>: Read + Write,
45{
46 type Output = WebSocket<AllowStd<S>>;
47
48 fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
49 let inner = self
50 .get_mut()
51 .0
52 .take()
53 .expect("future polled after completion");
54 #[cfg(feature = "verbose-logging")]
55 trace!("Setting context when skipping handshake");
56 let stream = AllowStd::new(inner.stream, ctx.waker());
57
58 Poll::Ready((inner.f)(stream))
59 }
60}
61
/// Future that resumes an interrupted tungstenite handshake.
///
/// Holds the mid-handshake state between polls. `poll` takes the state out and
/// stores a fresh one whenever the handshake is interrupted again. It is
/// therefore `None` once the handshake has finished or failed, and polling
/// again panics.
#[cfg(feature = "handshake")]
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>);
64
/// Outcome of the first, synchronous attempt to start a handshake.
#[cfg(feature = "handshake")]
enum StartedHandshake<Role: HandshakeRole> {
    /// The handshake completed within the first attempt.
    Done(Role::FinalResult),
    /// The handshake was interrupted (`HandshakeError::Interrupted`) and must
    /// be resumed from this state.
    Mid(WsHandshake<Role>),
}
70
/// Future that performs the first step of a handshake by calling a
/// user-supplied start function on the wrapped stream.
///
/// The `Option` holds the inputs until the first poll takes them out. It is
/// `None` afterwards, and polling again panics.
#[cfg(feature = "handshake")]
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
/// Inputs consumed by [`StartedHandshakeFuture`] on its first poll.
#[cfg(feature = "handshake")]
struct StartedHandshakeFutureInner<F, S> {
    // Function that starts the handshake on the `AllowStd`-wrapped stream.
    f: F,
    // The raw async stream to wrap.
    stream: S,
}
78
/// Drives a tungstenite handshake to completion asynchronously.
///
/// The first poll calls `f` to start the handshake. If that attempt is
/// interrupted, the returned mid-handshake state is resumed by [`MidHandshake`]
/// on later polls until it finishes.
///
/// # Errors
///
/// Returns `Error::Failure` if either the initial attempt or a resumed step
/// fails.
#[cfg(feature = "handshake")]
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
    Role: HandshakeRole + Unpin,
    Role::InternalStream: SetWaker + Unpin,
    F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let first_step = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));

    match first_step.await? {
        StartedHandshake::Mid(interrupted) => MidHandshake::<Role>(Some(interrupted)).await,
        StartedHandshake::Done(finished) => Ok(finished),
    }
}
97
/// Runs the client side of the WebSocket handshake over `stream`.
///
/// `f` starts the tungstenite client handshake on the wrapped stream. On
/// success, returns the connected [`WebSocketStream`] together with the
/// handshake [`Response`].
///
/// # Errors
///
/// Propagates any handshake failure from `f` or from a resumed step.
#[cfg(feature = "handshake")]
pub(crate) async fn client_handshake<F, S>(
    stream: S,
    f: F,
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>>
where
    F: FnOnce(
            AllowStd<S>,
        ) -> Result<
            <ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult,
            Error<ClientHandshake<AllowStd<S>>>,
        > + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let (websocket, response) = handshake(stream, f).await?;
    Ok((WebSocketStream::new(websocket), response))
}
116
/// Runs the server side of the WebSocket handshake over `stream`.
///
/// `f` starts the tungstenite server handshake (parameterised by the callback
/// type `C`) on the wrapped stream. On success, returns the accepted
/// [`WebSocketStream`].
///
/// # Errors
///
/// Propagates any handshake failure from `f` or from a resumed step.
#[cfg(feature = "handshake")]
pub(crate) async fn server_handshake<C, F, S>(
    stream: S,
    f: F,
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>>
where
    C: Callback + Unpin,
    F: FnOnce(
            AllowStd<S>,
        ) -> Result<
            <ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult,
            Error<ServerHandshake<AllowStd<S>, C>>,
        > + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome: Result<WebSocket<AllowStd<S>>, _> = handshake(stream, f).await;
    outcome.map(WebSocketStream::new)
}
135
/// Always resolves on its first poll. It wraps the stream in [`AllowStd`] with
/// the current waker, calls the start function and classifies the outcome:
/// completed (`Done`), interrupted (`Mid`), or failed (`Err(Failure)`).
///
/// # Panics
///
/// Panics if polled again after returning `Poll::Ready`.
#[cfg(feature = "handshake")]
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
    Role: HandshakeRole,
    Role::InternalStream: SetWaker,
    F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
    S: Unpin,
    AllowStd<S>: Read + Write,
{
    type Output = Result<StartedHandshake<Role>, Error<Role>>;

    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let StartedHandshakeFutureInner { f, stream } =
            self.0.take().expect("future polled after completion");
        #[cfg(feature = "verbose-logging")]
        trace!("Setting ctx when starting handshake");
        let wrapped = AllowStd::new(stream, ctx.waker());

        let outcome: Self::Output = match f(wrapped) {
            Err(Error::Failure(e)) => Err(Error::Failure(e)),
            Err(Error::Interrupted(mid)) => Ok(StartedHandshake::Mid(mid)),
            Ok(done) => Ok(StartedHandshake::Done(done)),
        };
        Poll::Ready(outcome)
    }
}
160
/// Resumes the stored handshake on each poll.
///
/// Every poll first registers the current waker on the underlying stream via
/// [`SetWaker`], then advances the handshake. If it is interrupted again, the
/// new mid-handshake state is stored and `Poll::Pending` is returned. This
/// relies on the registered waker being woken when the stream can make
/// progress, which is assumed behaviour of the `SetWaker` implementation and
/// is not visible here.
///
/// # Panics
///
/// Panics if polled again after returning `Poll::Ready`.
#[cfg(feature = "handshake")]
impl<Role> Future for MidHandshake<Role>
where
    Role: HandshakeRole + Unpin,
    Role::InternalStream: SetWaker + Unpin,
{
    type Output = Result<Role::FinalResult, Error<Role>>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut in_progress = self
            .as_mut()
            .0
            .take()
            .expect("future polled after completion");

        #[cfg(feature = "verbose-logging")]
        trace!("Setting context in handshake");
        in_progress.get_mut().get_mut().set_waker(cx.waker());

        match in_progress.handshake() {
            Err(Error::Interrupted(next)) => {
                // Not finished yet: keep the new state for the next poll.
                self.0 = Some(next);
                Poll::Pending
            }
            Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
            Ok(finished) => Poll::Ready(Ok(finished)),
        }
    }
}