async_tungstenite/
compat.rs

1#[allow(unused_imports)]
2use log::*;
3use std::io::{Read, Write};
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6
7use atomic_waker::AtomicWaker;
8use futures_io::{AsyncRead, AsyncWrite};
9use futures_task::{waker_ref, ArcWake};
10use std::sync::Arc;
11use tungstenite::Error as WsError;
12
/// Which side of the websocket an I/O operation is performed for.
///
/// `Read` covers the `Stream` side (and handshaking), `Write` the `Sink` side.
/// It selects both which waker slot gets registered in `AllowStd::set_waker`
/// and which waker proxy is handed to the inner stream in `with_context`.
///
/// Fieldless, so it is cheap to copy and compare; `Debug` lets it show up in
/// trace logging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextWaker {
    Read,
    Write,
}
17
18#[derive(Debug)]
19pub(crate) struct AllowStd<S> {
20    inner: S,
21    // We have the problem that external read operations (i.e. the Stream impl)
22    // can trigger both read (AsyncRead) and write (AsyncWrite) operations on
23    // the underyling stream. At the same time write operations (i.e. the Sink
24    // impl) can trigger write operations (AsyncWrite) too.
25    // Both the Stream and the Sink can be used on two different tasks, but it
26    // is required that AsyncRead and AsyncWrite are only ever used by a single
27    // task (or better: with a single waker) at a time.
28    //
29    // Doing otherwise would cause only the latest waker to be remembered, so
30    // in our case either the Stream or the Sink impl would potentially wait
31    // forever to be woken up because only the other one would've been woken
32    // up.
33    //
34    // To solve this we implement a waker proxy that has two slots (one for
35    // read, one for write) to store wakers. One waker proxy is always passed
36    // to the AsyncRead, the other to AsyncWrite so that they will only ever
37    // have to store a single waker, but internally we dispatch any wakeups to
38    // up to two actual wakers (one from the Sink impl and one from the Stream
39    // impl).
40    //
41    // write_waker_proxy is always used for AsyncWrite, read_waker_proxy for
42    // AsyncRead. The read_waker slots of both are used for the Stream impl
43    // (and handshaking), the write_waker slots for the Sink impl.
44    write_waker_proxy: Arc<WakerProxy>,
45    read_waker_proxy: Arc<WakerProxy>,
46}
47
/// Crate-internal hook used only by the handshake module to register the
/// waker of the context that drives the handshake.
///
/// The implementation for `AllowStd` stores it in the read waker slot, though
/// any slot would do.
///
/// Never use this from multiple tasks at the same time!
#[cfg(feature = "handshake")]
pub(crate) trait SetWaker {
    /// Registers `waker` as the waker to notify for handshake I/O.
    fn set_waker(&self, waker: &Waker);
}
57
#[cfg(feature = "handshake")]
impl<S> SetWaker for AllowStd<S> {
    /// Handshake I/O is treated like a read operation: `waker` goes into the
    /// read slot of both waker proxies.
    fn set_waker(&self, waker: &Waker) {
        // Written as a type-qualified call so it is obvious that the inherent,
        // two-argument `AllowStd::set_waker` is meant, not this trait method.
        AllowStd::set_waker(self, ContextWaker::Read, waker)
    }
}
64
65impl<S> AllowStd<S> {
66    pub(crate) fn new(inner: S, waker: &Waker) -> Self {
67        let res = Self {
68            inner,
69            write_waker_proxy: Default::default(),
70            read_waker_proxy: Default::default(),
71        };
72
73        // Register the handshake waker as read waker for both proxies,
74        // see also the SetWaker trait.
75        res.write_waker_proxy.read_waker.register(waker);
76        res.read_waker_proxy.read_waker.register(waker);
77
78        res
79    }
80
81    // Set the read or write waker for our proxies.
82    //
83    // Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream
84    // impl on the WebSocketStream.
85    // Reading can also cause writes to happen, e.g. in case of Message::Ping handling.
86    //
87    // Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
88    // WebSocketStream.
89    pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
90        match kind {
91            ContextWaker::Read => {
92                self.write_waker_proxy.read_waker.register(waker);
93                self.read_waker_proxy.read_waker.register(waker);
94            }
95            ContextWaker::Write => {
96                self.write_waker_proxy.write_waker.register(waker);
97                self.read_waker_proxy.write_waker.register(waker);
98            }
99        }
100    }
101}
102
103// Proxy Waker that we pass to the internal AsyncRead/Write of the
104// stream underlying the websocket. We have two slots here for the
105// actual wakers to allow external read operations to trigger both
106// reads and writes, and the same for writes.
107#[derive(Debug, Default)]
108struct WakerProxy {
109    read_waker: AtomicWaker,
110    write_waker: AtomicWaker,
111}
112
113impl ArcWake for WakerProxy {
114    fn wake_by_ref(arc_self: &Arc<Self>) {
115        arc_self.read_waker.wake();
116        arc_self.write_waker.wake();
117    }
118}
119
120impl<S> AllowStd<S>
121where
122    S: Unpin,
123{
124    fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
125    where
126        F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
127    {
128        #[cfg(feature = "verbose-logging")]
129        trace!("{}:{} AllowStd.with_context", file!(), line!());
130        let waker = match kind {
131            ContextWaker::Read => waker_ref(&self.read_waker_proxy),
132            ContextWaker::Write => waker_ref(&self.write_waker_proxy),
133        };
134        let mut context = Context::from_waker(&waker);
135        f(&mut context, Pin::new(&mut self.inner))
136    }
137
138    pub(crate) fn get_mut(&mut self) -> &mut S {
139        &mut self.inner
140    }
141
142    pub(crate) fn get_ref(&self) -> &S {
143        &self.inner
144    }
145}
146
147impl<S> Read for AllowStd<S>
148where
149    S: AsyncRead + Unpin,
150{
151    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
152        #[cfg(feature = "verbose-logging")]
153        trace!("{}:{} Read.read", file!(), line!());
154        match self.with_context(ContextWaker::Read, |ctx, stream| {
155            #[cfg(feature = "verbose-logging")]
156            trace!(
157                "{}:{} Read.with_context read -> poll_read",
158                file!(),
159                line!()
160            );
161            stream.poll_read(ctx, buf)
162        }) {
163            Poll::Ready(r) => r,
164            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
165        }
166    }
167}
168
169impl<S> Write for AllowStd<S>
170where
171    S: AsyncWrite + Unpin,
172{
173    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
174        #[cfg(feature = "verbose-logging")]
175        trace!("{}:{} Write.write", file!(), line!());
176        match self.with_context(ContextWaker::Write, |ctx, stream| {
177            #[cfg(feature = "verbose-logging")]
178            trace!(
179                "{}:{} Write.with_context write -> poll_write",
180                file!(),
181                line!()
182            );
183            stream.poll_write(ctx, buf)
184        }) {
185            Poll::Ready(r) => r,
186            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
187        }
188    }
189
190    fn flush(&mut self) -> std::io::Result<()> {
191        #[cfg(feature = "verbose-logging")]
192        trace!("{}:{} Write.flush", file!(), line!());
193        match self.with_context(ContextWaker::Write, |ctx, stream| {
194            #[cfg(feature = "verbose-logging")]
195            trace!(
196                "{}:{} Write.with_context flush -> poll_flush",
197                file!(),
198                line!()
199            );
200            stream.poll_flush(ctx)
201        }) {
202            Poll::Ready(r) => r,
203            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
204        }
205    }
206}
207
208pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
209    match r {
210        Ok(v) => Poll::Ready(Ok(v)),
211        Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
212            #[cfg(feature = "verbose-logging")]
213            trace!("WouldBlock");
214            Poll::Pending
215        }
216        Err(e) => Poll::Ready(Err(e)),
217    }
218}