async_tungstenite/
compat.rs1#[allow(unused_imports)]
2use log::*;
3use std::io::{Read, Write};
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6
7use atomic_waker::AtomicWaker;
8use futures_io::{AsyncRead, AsyncWrite};
9use futures_task::{waker_ref, ArcWake};
10use std::sync::Arc;
11use tungstenite::Error as WsError;
12
/// Selects which side of the I/O a waker belongs to.
///
/// In `AllowStd::set_waker`, `Read` selects the `read_waker` slot of both
/// waker proxies and `Write` selects their `write_waker` slot. In
/// `AllowStd::with_context`, the same tag picks which proxy's waker is handed
/// to the wrapped stream while it is polled.
///
/// The tag is a trivially copyable value, so it derives the usual value traits
/// (and `Debug`, matching `AllowStd`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextWaker {
    /// Read side: `read_waker` slots / `AsyncRead` polling.
    Read,
    /// Write side: `write_waker` slots / `AsyncWrite` polling.
    Write,
}
17
18#[derive(Debug)]
19pub(crate) struct AllowStd<S> {
20    inner: S,
21    write_waker_proxy: Arc<WakerProxy>,
45    read_waker_proxy: Arc<WakerProxy>,
46}
47
/// Crate-internal hook for installing a single waker on an I/O adapter.
///
/// Only compiled with the `handshake` feature.
#[cfg(feature = "handshake")]
pub(crate) trait SetWaker {
    /// Registers `waker` with the implementor.
    fn set_waker(&self, waker: &Waker);
}

#[cfg(feature = "handshake")]
impl<S> SetWaker for AllowStd<S> {
    /// Stores `waker` in the `ContextWaker::Read` slot of both waker proxies.
    fn set_waker(&self, waker: &Waker) {
        // Call the inherent two-argument method by path. This makes it
        // obvious that this is a delegation, not a recursive call to the
        // trait method of the same name.
        AllowStd::set_waker(self, ContextWaker::Read, waker);
    }
}
64
65impl<S> AllowStd<S> {
66    pub(crate) fn new(inner: S, waker: &Waker) -> Self {
67        let res = Self {
68            inner,
69            write_waker_proxy: Default::default(),
70            read_waker_proxy: Default::default(),
71        };
72
73        res.write_waker_proxy.read_waker.register(waker);
76        res.read_waker_proxy.read_waker.register(waker);
77
78        res
79    }
80
81    pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
90        match kind {
91            ContextWaker::Read => {
92                self.write_waker_proxy.read_waker.register(waker);
93                self.read_waker_proxy.read_waker.register(waker);
94            }
95            ContextWaker::Write => {
96                self.write_waker_proxy.write_waker.register(waker);
97                self.read_waker_proxy.write_waker.register(waker);
98            }
99        }
100    }
101
102    pub(crate) fn into_inner(self) -> S {
104        self.inner
105    }
106}
107
108#[derive(Debug, Default)]
113struct WakerProxy {
114    read_waker: AtomicWaker,
115    write_waker: AtomicWaker,
116}
117
118impl ArcWake for WakerProxy {
119    fn wake_by_ref(arc_self: &Arc<Self>) {
120        arc_self.read_waker.wake();
121        arc_self.write_waker.wake();
122    }
123}
124
125impl<S> AllowStd<S>
126where
127    S: Unpin,
128{
129    fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
130    where
131        F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
132    {
133        #[cfg(feature = "verbose-logging")]
134        trace!("{}:{} AllowStd.with_context", file!(), line!());
135        let waker = match kind {
136            ContextWaker::Read => waker_ref(&self.read_waker_proxy),
137            ContextWaker::Write => waker_ref(&self.write_waker_proxy),
138        };
139        let mut context = Context::from_waker(&waker);
140        f(&mut context, Pin::new(&mut self.inner))
141    }
142
143    pub(crate) fn get_mut(&mut self) -> &mut S {
144        &mut self.inner
145    }
146
147    pub(crate) fn get_ref(&self) -> &S {
148        &self.inner
149    }
150}
151
152impl<S> Read for AllowStd<S>
153where
154    S: AsyncRead + Unpin,
155{
156    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
157        #[cfg(feature = "verbose-logging")]
158        trace!("{}:{} Read.read", file!(), line!());
159        match self.with_context(ContextWaker::Read, |ctx, stream| {
160            #[cfg(feature = "verbose-logging")]
161            trace!(
162                "{}:{} Read.with_context read -> poll_read",
163                file!(),
164                line!()
165            );
166            stream.poll_read(ctx, buf)
167        }) {
168            Poll::Ready(r) => r,
169            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
170        }
171    }
172}
173
174impl<S> Write for AllowStd<S>
175where
176    S: AsyncWrite + Unpin,
177{
178    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
179        #[cfg(feature = "verbose-logging")]
180        trace!("{}:{} Write.write", file!(), line!());
181        match self.with_context(ContextWaker::Write, |ctx, stream| {
182            #[cfg(feature = "verbose-logging")]
183            trace!(
184                "{}:{} Write.with_context write -> poll_write",
185                file!(),
186                line!()
187            );
188            stream.poll_write(ctx, buf)
189        }) {
190            Poll::Ready(r) => r,
191            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
192        }
193    }
194
195    fn flush(&mut self) -> std::io::Result<()> {
196        #[cfg(feature = "verbose-logging")]
197        trace!("{}:{} Write.flush", file!(), line!());
198        match self.with_context(ContextWaker::Write, |ctx, stream| {
199            #[cfg(feature = "verbose-logging")]
200            trace!(
201                "{}:{} Write.with_context flush -> poll_flush",
202                file!(),
203                line!()
204            );
205            stream.poll_flush(ctx)
206        }) {
207            Poll::Ready(r) => r,
208            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
209        }
210    }
211}
212
213pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
214    match r {
215        Ok(v) => Poll::Ready(Ok(v)),
216        Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
217            #[cfg(feature = "verbose-logging")]
218            trace!("WouldBlock");
219            Poll::Pending
220        }
221        Err(e) => Poll::Ready(Err(e)),
222    }
223}