async_tungstenite/
compat.rs1#[allow(unused_imports)]
2use log::*;
3use std::io::{Read, Write};
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6
7use atomic_waker::AtomicWaker;
8use futures_io::{AsyncRead, AsyncWrite};
9use futures_task::{waker_ref, ArcWake};
10use std::sync::Arc;
11use tungstenite::Error as WsError;
12
/// Names the side of the websocket that is driving an I/O operation.
///
/// `AllowStd::set_waker` uses it to pick the waker slot (`read_waker` or
/// `write_waker`) to fill on both proxies. `AllowStd::with_context` uses it
/// to pick the proxy (`read_waker_proxy` or `write_waker_proxy`) handed to
/// the inner stream.
pub(crate) enum ContextWaker {
    /// The reading side; also used by the `SetWaker` handshake hook.
    Read,
    /// The writing side.
    Write,
}
17
18#[derive(Debug)]
19pub(crate) struct AllowStd<S> {
20 inner: S,
21 write_waker_proxy: Arc<WakerProxy>,
45 read_waker_proxy: Arc<WakerProxy>,
46}
47
/// Crate-internal hook for registering a single waker on a wrapped stream.
///
/// Only compiled with the `handshake` feature. Presumably it is used by the
/// handshake machinery to register the handshaking task's waker; confirm
/// against the callers.
///
/// NOTE(review): the only implementation here stores the waker in the
/// `ContextWaker::Read` slots. Avoid calling it from several tasks at once,
/// because each registration replaces the previous waker in those slots.
#[cfg(feature = "handshake")]
pub(crate) trait SetWaker {
    /// Registers `waker` so it is woken when the wrapped stream makes progress.
    fn set_waker(&self, waker: &Waker);
}
57
#[cfg(feature = "handshake")]
impl<S> SetWaker for AllowStd<S> {
    /// Stores `waker` in the read slots of both proxies.
    ///
    /// The explicit path call makes it obvious that this forwards to the
    /// two-argument inherent `AllowStd::set_waker`, not back to this method.
    fn set_waker(&self, waker: &Waker) {
        AllowStd::set_waker(self, ContextWaker::Read, waker);
    }
}
64
65impl<S> AllowStd<S> {
66 pub(crate) fn new(inner: S, waker: &Waker) -> Self {
67 let res = Self {
68 inner,
69 write_waker_proxy: Default::default(),
70 read_waker_proxy: Default::default(),
71 };
72
73 res.write_waker_proxy.read_waker.register(waker);
76 res.read_waker_proxy.read_waker.register(waker);
77
78 res
79 }
80
81 pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
90 match kind {
91 ContextWaker::Read => {
92 self.write_waker_proxy.read_waker.register(waker);
93 self.read_waker_proxy.read_waker.register(waker);
94 }
95 ContextWaker::Write => {
96 self.write_waker_proxy.write_waker.register(waker);
97 self.read_waker_proxy.write_waker.register(waker);
98 }
99 }
100 }
101}
102
103#[derive(Debug, Default)]
108struct WakerProxy {
109 read_waker: AtomicWaker,
110 write_waker: AtomicWaker,
111}
112
113impl ArcWake for WakerProxy {
114 fn wake_by_ref(arc_self: &Arc<Self>) {
115 arc_self.read_waker.wake();
116 arc_self.write_waker.wake();
117 }
118}
119
120impl<S> AllowStd<S>
121where
122 S: Unpin,
123{
124 fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
125 where
126 F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
127 {
128 #[cfg(feature = "verbose-logging")]
129 trace!("{}:{} AllowStd.with_context", file!(), line!());
130 let waker = match kind {
131 ContextWaker::Read => waker_ref(&self.read_waker_proxy),
132 ContextWaker::Write => waker_ref(&self.write_waker_proxy),
133 };
134 let mut context = Context::from_waker(&waker);
135 f(&mut context, Pin::new(&mut self.inner))
136 }
137
138 pub(crate) fn get_mut(&mut self) -> &mut S {
139 &mut self.inner
140 }
141
142 pub(crate) fn get_ref(&self) -> &S {
143 &self.inner
144 }
145}
146
147impl<S> Read for AllowStd<S>
148where
149 S: AsyncRead + Unpin,
150{
151 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
152 #[cfg(feature = "verbose-logging")]
153 trace!("{}:{} Read.read", file!(), line!());
154 match self.with_context(ContextWaker::Read, |ctx, stream| {
155 #[cfg(feature = "verbose-logging")]
156 trace!(
157 "{}:{} Read.with_context read -> poll_read",
158 file!(),
159 line!()
160 );
161 stream.poll_read(ctx, buf)
162 }) {
163 Poll::Ready(r) => r,
164 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
165 }
166 }
167}
168
169impl<S> Write for AllowStd<S>
170where
171 S: AsyncWrite + Unpin,
172{
173 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
174 #[cfg(feature = "verbose-logging")]
175 trace!("{}:{} Write.write", file!(), line!());
176 match self.with_context(ContextWaker::Write, |ctx, stream| {
177 #[cfg(feature = "verbose-logging")]
178 trace!(
179 "{}:{} Write.with_context write -> poll_write",
180 file!(),
181 line!()
182 );
183 stream.poll_write(ctx, buf)
184 }) {
185 Poll::Ready(r) => r,
186 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
187 }
188 }
189
190 fn flush(&mut self) -> std::io::Result<()> {
191 #[cfg(feature = "verbose-logging")]
192 trace!("{}:{} Write.flush", file!(), line!());
193 match self.with_context(ContextWaker::Write, |ctx, stream| {
194 #[cfg(feature = "verbose-logging")]
195 trace!(
196 "{}:{} Write.with_context flush -> poll_flush",
197 file!(),
198 line!()
199 );
200 stream.poll_flush(ctx)
201 }) {
202 Poll::Ready(r) => r,
203 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
204 }
205 }
206}
207
208pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
209 match r {
210 Ok(v) => Poll::Ready(Ok(v)),
211 Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
212 #[cfg(feature = "verbose-logging")]
213 trace!("WouldBlock");
214 Poll::Pending
215 }
216 Err(e) => Poll::Ready(Err(e)),
217 }
218}