async_tungstenite/
compat.rs1#[allow(unused_imports)]
2use log::*;
3use std::io::{Read, Write};
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6
7use atomic_waker::AtomicWaker;
8use futures_io::{AsyncRead, AsyncWrite};
9use futures_task::{waker_ref, ArcWake};
10use std::sync::Arc;
11use tungstenite::Error as WsError;
12
/// Identifies which side of the stream an operation belongs to.
///
/// `AllowStd::set_waker` uses it to choose the waker slot (`read_waker` or
/// `write_waker`) to register into, and `AllowStd::with_context` uses it to
/// choose which waker proxy is handed to the inner stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextWaker {
    /// Read-side operation.
    Read,
    /// Write-side operation.
    Write,
}
17
18#[derive(Debug)]
19pub(crate) struct AllowStd<S> {
20 inner: S,
21 write_waker_proxy: Arc<WakerProxy>,
45 read_waker_proxy: Arc<WakerProxy>,
46}
47
/// Crate-internal hook for registering a waker on a wrapped stream.
///
/// Only compiled when the `handshake` feature is enabled.
// NOTE(review): presumably consumed by the handshake code; confirm against the caller.
#[cfg(feature = "handshake")]
pub(crate) trait SetWaker {
    /// Registers `waker` so it is notified on later wake-ups.
    fn set_waker(&self, waker: &Waker);
}
57
#[cfg(feature = "handshake")]
impl<S> SetWaker for AllowStd<S> {
    /// Registers `waker` in the read slot of both waker proxies.
    fn set_waker(&self, waker: &Waker) {
        // Explicitly call the inherent two-argument `AllowStd::set_waker`
        // (not this trait method) with the read slot selected.
        AllowStd::set_waker(self, ContextWaker::Read, waker);
    }
}
64
65impl<S> AllowStd<S> {
66 pub(crate) fn new(inner: S, waker: &Waker) -> Self {
67 let res = Self {
68 inner,
69 write_waker_proxy: Default::default(),
70 read_waker_proxy: Default::default(),
71 };
72
73 res.write_waker_proxy.read_waker.register(waker);
76 res.read_waker_proxy.read_waker.register(waker);
77
78 res
79 }
80
81 pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
90 match kind {
91 ContextWaker::Read => {
92 self.write_waker_proxy.read_waker.register(waker);
93 self.read_waker_proxy.read_waker.register(waker);
94 }
95 ContextWaker::Write => {
96 self.write_waker_proxy.write_waker.register(waker);
97 self.read_waker_proxy.write_waker.register(waker);
98 }
99 }
100 }
101
102 pub(crate) fn into_inner(self) -> S {
104 self.inner
105 }
106}
107
108#[derive(Debug, Default)]
113struct WakerProxy {
114 read_waker: AtomicWaker,
115 write_waker: AtomicWaker,
116}
117
118impl ArcWake for WakerProxy {
119 fn wake_by_ref(arc_self: &Arc<Self>) {
120 arc_self.read_waker.wake();
121 arc_self.write_waker.wake();
122 }
123}
124
125impl<S> AllowStd<S>
126where
127 S: Unpin,
128{
129 fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
130 where
131 F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
132 {
133 #[cfg(feature = "verbose-logging")]
134 trace!("{}:{} AllowStd.with_context", file!(), line!());
135 let waker = match kind {
136 ContextWaker::Read => waker_ref(&self.read_waker_proxy),
137 ContextWaker::Write => waker_ref(&self.write_waker_proxy),
138 };
139 let mut context = Context::from_waker(&waker);
140 f(&mut context, Pin::new(&mut self.inner))
141 }
142
143 pub(crate) fn get_mut(&mut self) -> &mut S {
144 &mut self.inner
145 }
146
147 pub(crate) fn get_ref(&self) -> &S {
148 &self.inner
149 }
150}
151
152impl<S> Read for AllowStd<S>
153where
154 S: AsyncRead + Unpin,
155{
156 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
157 #[cfg(feature = "verbose-logging")]
158 trace!("{}:{} Read.read", file!(), line!());
159 match self.with_context(ContextWaker::Read, |ctx, stream| {
160 #[cfg(feature = "verbose-logging")]
161 trace!(
162 "{}:{} Read.with_context read -> poll_read",
163 file!(),
164 line!()
165 );
166 stream.poll_read(ctx, buf)
167 }) {
168 Poll::Ready(r) => r,
169 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
170 }
171 }
172}
173
174impl<S> Write for AllowStd<S>
175where
176 S: AsyncWrite + Unpin,
177{
178 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
179 #[cfg(feature = "verbose-logging")]
180 trace!("{}:{} Write.write", file!(), line!());
181 match self.with_context(ContextWaker::Write, |ctx, stream| {
182 #[cfg(feature = "verbose-logging")]
183 trace!(
184 "{}:{} Write.with_context write -> poll_write",
185 file!(),
186 line!()
187 );
188 stream.poll_write(ctx, buf)
189 }) {
190 Poll::Ready(r) => r,
191 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
192 }
193 }
194
195 fn flush(&mut self) -> std::io::Result<()> {
196 #[cfg(feature = "verbose-logging")]
197 trace!("{}:{} Write.flush", file!(), line!());
198 match self.with_context(ContextWaker::Write, |ctx, stream| {
199 #[cfg(feature = "verbose-logging")]
200 trace!(
201 "{}:{} Write.with_context flush -> poll_flush",
202 file!(),
203 line!()
204 );
205 stream.poll_flush(ctx)
206 }) {
207 Poll::Ready(r) => r,
208 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
209 }
210 }
211}
212
213pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
214 match r {
215 Ok(v) => Poll::Ready(Ok(v)),
216 Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
217 #[cfg(feature = "verbose-logging")]
218 trace!("WouldBlock");
219 Poll::Pending
220 }
221 Err(e) => Poll::Ready(Err(e)),
222 }
223}