async_tungstenite/
compat.rs

1#[allow(unused_imports)]
2use log::*;
3use std::io::{Read, Write};
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6
7use atomic_waker::AtomicWaker;
8use futures_io::{AsyncRead, AsyncWrite};
9use futures_task::{waker_ref, ArcWake};
10use std::sync::Arc;
11use tungstenite::Error as WsError;
12
/// Selects which waker slot of the two waker proxies an operation registers in.
///
/// `Read` is used by read-side operations (the `Stream` impl and handshaking),
/// `Write` by write-side operations (the `Sink` impl).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextWaker {
    /// Slot used by read-side (and handshake) operations.
    Read,
    /// Slot used by write-side operations.
    Write,
}
17
18#[derive(Debug)]
19pub(crate) struct AllowStd<S> {
20    inner: S,
21    // We have the problem that external read operations (i.e. the Stream impl)
22    // can trigger both read (AsyncRead) and write (AsyncWrite) operations on
23    // the underyling stream. At the same time write operations (i.e. the Sink
24    // impl) can trigger write operations (AsyncWrite) too.
25    // Both the Stream and the Sink can be used on two different tasks, but it
26    // is required that AsyncRead and AsyncWrite are only ever used by a single
27    // task (or better: with a single waker) at a time.
28    //
29    // Doing otherwise would cause only the latest waker to be remembered, so
30    // in our case either the Stream or the Sink impl would potentially wait
31    // forever to be woken up because only the other one would've been woken
32    // up.
33    //
34    // To solve this we implement a waker proxy that has two slots (one for
35    // read, one for write) to store wakers. One waker proxy is always passed
36    // to the AsyncRead, the other to AsyncWrite so that they will only ever
37    // have to store a single waker, but internally we dispatch any wakeups to
38    // up to two actual wakers (one from the Sink impl and one from the Stream
39    // impl).
40    //
41    // write_waker_proxy is always used for AsyncWrite, read_waker_proxy for
42    // AsyncRead. The read_waker slots of both are used for the Stream impl
43    // (and handshaking), the write_waker slots for the Sink impl.
44    write_waker_proxy: Arc<WakerProxy>,
45    read_waker_proxy: Arc<WakerProxy>,
46}
47
/// Internal hook used only by the handshake module to register the waker of
/// the context that drives the handshake.
///
/// The waker goes into the read waker slot. The choice of slot is arbitrary;
/// any slot would work.
///
/// Never call this from more than one task at the same time.
#[cfg(feature = "handshake")]
pub(crate) trait SetWaker {
    /// Registers `waker` as the waker for the ongoing handshake.
    fn set_waker(&self, waker: &Waker);
}
57
#[cfg(feature = "handshake")]
impl<S> SetWaker for AllowStd<S> {
    /// Delegates to the inherent `AllowStd::set_waker`, using the read slot.
    fn set_waker(&self, waker: &Waker) {
        let slot = ContextWaker::Read;
        // Type-relative path: inherent associated items take precedence over
        // trait items, so this calls the three-argument inherent method.
        Self::set_waker(self, slot, waker);
    }
}
64
65impl<S> AllowStd<S> {
66    pub(crate) fn new(inner: S, waker: &Waker) -> Self {
67        let res = Self {
68            inner,
69            write_waker_proxy: Default::default(),
70            read_waker_proxy: Default::default(),
71        };
72
73        // Register the handshake waker as read waker for both proxies,
74        // see also the SetWaker trait.
75        res.write_waker_proxy.read_waker.register(waker);
76        res.read_waker_proxy.read_waker.register(waker);
77
78        res
79    }
80
81    // Set the read or write waker for our proxies.
82    //
83    // Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream
84    // impl on the WebSocketStream.
85    // Reading can also cause writes to happen, e.g. in case of Message::Ping handling.
86    //
87    // Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
88    // WebSocketStream.
89    pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
90        match kind {
91            ContextWaker::Read => {
92                self.write_waker_proxy.read_waker.register(waker);
93                self.read_waker_proxy.read_waker.register(waker);
94            }
95            ContextWaker::Write => {
96                self.write_waker_proxy.write_waker.register(waker);
97                self.read_waker_proxy.write_waker.register(waker);
98            }
99        }
100    }
101
102    /// Returns the underlying stream.
103    pub(crate) fn into_inner(self) -> S {
104        self.inner
105    }
106}
107
108// Proxy Waker that we pass to the internal AsyncRead/Write of the
109// stream underlying the websocket. We have two slots here for the
110// actual wakers to allow external read operations to trigger both
111// reads and writes, and the same for writes.
112#[derive(Debug, Default)]
113struct WakerProxy {
114    read_waker: AtomicWaker,
115    write_waker: AtomicWaker,
116}
117
118impl ArcWake for WakerProxy {
119    fn wake_by_ref(arc_self: &Arc<Self>) {
120        arc_self.read_waker.wake();
121        arc_self.write_waker.wake();
122    }
123}
124
125impl<S> AllowStd<S>
126where
127    S: Unpin,
128{
129    fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
130    where
131        F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
132    {
133        #[cfg(feature = "verbose-logging")]
134        trace!("{}:{} AllowStd.with_context", file!(), line!());
135        let waker = match kind {
136            ContextWaker::Read => waker_ref(&self.read_waker_proxy),
137            ContextWaker::Write => waker_ref(&self.write_waker_proxy),
138        };
139        let mut context = Context::from_waker(&waker);
140        f(&mut context, Pin::new(&mut self.inner))
141    }
142
143    pub(crate) fn get_mut(&mut self) -> &mut S {
144        &mut self.inner
145    }
146
147    pub(crate) fn get_ref(&self) -> &S {
148        &self.inner
149    }
150}
151
152impl<S> Read for AllowStd<S>
153where
154    S: AsyncRead + Unpin,
155{
156    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
157        #[cfg(feature = "verbose-logging")]
158        trace!("{}:{} Read.read", file!(), line!());
159        match self.with_context(ContextWaker::Read, |ctx, stream| {
160            #[cfg(feature = "verbose-logging")]
161            trace!(
162                "{}:{} Read.with_context read -> poll_read",
163                file!(),
164                line!()
165            );
166            stream.poll_read(ctx, buf)
167        }) {
168            Poll::Ready(r) => r,
169            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
170        }
171    }
172}
173
174impl<S> Write for AllowStd<S>
175where
176    S: AsyncWrite + Unpin,
177{
178    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
179        #[cfg(feature = "verbose-logging")]
180        trace!("{}:{} Write.write", file!(), line!());
181        match self.with_context(ContextWaker::Write, |ctx, stream| {
182            #[cfg(feature = "verbose-logging")]
183            trace!(
184                "{}:{} Write.with_context write -> poll_write",
185                file!(),
186                line!()
187            );
188            stream.poll_write(ctx, buf)
189        }) {
190            Poll::Ready(r) => r,
191            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
192        }
193    }
194
195    fn flush(&mut self) -> std::io::Result<()> {
196        #[cfg(feature = "verbose-logging")]
197        trace!("{}:{} Write.flush", file!(), line!());
198        match self.with_context(ContextWaker::Write, |ctx, stream| {
199            #[cfg(feature = "verbose-logging")]
200            trace!(
201                "{}:{} Write.with_context flush -> poll_flush",
202                file!(),
203                line!()
204            );
205            stream.poll_flush(ctx)
206        }) {
207            Poll::Ready(r) => r,
208            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
209        }
210    }
211}
212
213pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
214    match r {
215        Ok(v) => Poll::Ready(Ok(v)),
216        Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
217            #[cfg(feature = "verbose-logging")]
218            trace!("WouldBlock");
219            Poll::Pending
220        }
221        Err(e) => Poll::Ready(Err(e)),
222    }
223}