1use tungstenite::client::IntoClientRequest;
3use tungstenite::handshake::client::{Request, Response};
4use tungstenite::handshake::server::{Callback, NoCallback};
5use tungstenite::protocol::WebSocketConfig;
6use tungstenite::Error;
7
8use tokio::net::TcpStream;
9
10use super::{domain, port, WebSocketStream};
11
12use futures_io::{AsyncRead, AsyncWrite};
13
// TLS backend selection.
//
// Exactly one `tls` module is compiled. When several TLS features are enabled
// at the same time, the `cfg` predicates below resolve the choice in this
// order of precedence:
//   1. `tokio-native-tls`
//   2. any of the `tokio-rustls-*` features
//   3. `tokio-openssl`
//   4. `async-tls`
//   5. `tokio/dummy_tls.rs`, used when none of the above features is enabled
// Each lower-priority predicate excludes every higher-priority feature. This
// is what prevents two `mod tls` items from being defined at once.

// Priority 1: native-tls.
#[cfg(feature = "tokio-native-tls")]
#[path = "tokio/native_tls.rs"]
mod tls;

// Priority 2: rustls (any of the `tokio-rustls-*` features), unless native-tls
// is also enabled.
#[cfg(all(
    any(
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ),
    not(feature = "tokio-native-tls")
))]
#[path = "tokio/rustls.rs"]
mod tls;

// Priority 3: OpenSSL, only when neither native-tls nor any rustls feature is
// enabled.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
#[path = "tokio/openssl.rs"]
mod tls;

// Priority 4: async-tls, only when none of the higher-priority backends is
// enabled.
#[cfg(all(
    feature = "async-tls",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots",
        feature = "tokio-openssl"
    ))
))]
#[path = "tokio/async_tls.rs"]
mod tls;

// Fallback when no TLS feature is enabled. This module must still provide
// `AutoStream` and `client_async_tls_with_connector_and_config`, because both
// are imported from `self::tls` unconditionally further down. It presumably
// only supports plain (non-TLS) streams; confirm in `tokio/dummy_tls.rs`.
#[cfg(not(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl",
    feature = "async-tls"
)))]
#[path = "tokio/dummy_tls.rs"]
mod tls;
68
69pub use self::tls::client_async_tls_with_connector_and_config;
70
71#[cfg(any(
72 feature = "tokio-native-tls",
73 feature = "tokio-rustls-manual-roots",
74 feature = "tokio-rustls-native-certs",
75 feature = "tokio-rustls-platform-verifier",
76 feature = "tokio-rustls-webpki-roots",
77 feature = "tokio-openssl",
78 feature = "async-tls"
79))]
80use self::tls::Connector;
81
82use self::tls::AutoStream;
83
84pub async fn client_async<'a, R, S>(
97 request: R,
98 stream: S,
99) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
100where
101 R: IntoClientRequest + Unpin,
102 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
103{
104 client_async_with_config(request, stream, None).await
105}
106
107pub async fn client_async_with_config<'a, R, S>(
110 request: R,
111 stream: S,
112 config: Option<WebSocketConfig>,
113) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
114where
115 R: IntoClientRequest + Unpin,
116 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
117{
118 crate::client_async_with_config(request, TokioAdapter::new(stream), config).await
119}
120
121pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
133where
134 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
135{
136 accept_hdr_async(stream, NoCallback).await
137}
138
139pub async fn accept_async_with_config<S>(
142 stream: S,
143 config: Option<WebSocketConfig>,
144) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
145where
146 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
147{
148 accept_hdr_async_with_config(stream, NoCallback, config).await
149}
150
151pub async fn accept_hdr_async<S, C>(
157 stream: S,
158 callback: C,
159) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
160where
161 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
162 C: Callback + Unpin,
163{
164 accept_hdr_async_with_config(stream, callback, None).await
165}
166
167pub async fn accept_hdr_async_with_config<S, C>(
170 stream: S,
171 callback: C,
172 config: Option<WebSocketConfig>,
173) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
174where
175 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
176 C: Callback + Unpin,
177{
178 crate::accept_hdr_async_with_config(TokioAdapter::new(stream), callback, config).await
179}
180
/// Stream type produced by the `client_async_tls*` functions: the `AutoStream`
/// of whichever `tls` backend module was selected at compile time, wrapping `S`.
pub type ClientStream<S> = AutoStream<S>;
183
/// Runs the client handshake over `stream`, delegating to the compiled-in TLS
/// backend with no custom connector and the default WebSocket configuration.
///
/// # Errors
///
/// Propagates any [`Error`] from the TLS backend or the handshake.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")),
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    // Neither a connector nor a config: let the backend pick its defaults.
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
205
/// Like [`client_async_tls`], but lets the caller supply a [`WebSocketConfig`].
///
/// The TLS backend is still used without a custom connector.
///
/// # Errors
///
/// Propagates any [`Error`] from the TLS backend or the handshake.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")),
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let default_connector = None;
    client_async_tls_with_connector_and_config(request, stream, default_connector, config).await
}
229
/// Like [`client_async_tls`], but lets the caller supply the TLS [`Connector`].
///
/// Unlike the connector-less variants, this one is also available with
/// `tokio-rustls-manual-roots`, where the caller provides the trust roots via
/// the connector. The default WebSocket configuration is used.
///
/// # Errors
///
/// Propagates any [`Error`] from the TLS backend or the handshake.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let default_config = None;
    client_async_tls_with_connector_and_config(request, stream, connector, default_config).await
}
253
/// OpenSSL flavour of `client_async_tls`: runs the client handshake over
/// `stream` with no custom connector and the default WebSocket configuration.
///
/// `S` must additionally be `Debug + Send + Sync`, presumably because the
/// OpenSSL backend requires those bounds (confirm in `tokio/openssl.rs`).
///
/// # Errors
///
/// Propagates any [`Error`] from the TLS backend or the handshake.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
283
/// OpenSSL flavour of `client_async_tls_with_config`: like `client_async_tls`,
/// but with a caller-supplied [`WebSocketConfig`].
///
/// # Errors
///
/// Propagates any [`Error`] from the TLS backend or the handshake.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let default_connector = None;
    client_async_tls_with_connector_and_config(request, stream, default_connector, config).await
}
315
/// OpenSSL flavour of `client_async_tls_with_connector`: like
/// `client_async_tls`, but with a caller-supplied TLS [`Connector`].
///
/// # Errors
///
/// Propagates any [`Error`] from the TLS backend or the handshake.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let default_config = None;
    client_async_tls_with_connector_and_config(request, stream, connector, default_config).await
}
347
/// Stream type returned by the `connect_async*` functions: a `ClientStream`
/// over a tokio `TcpStream`.
pub type ConnectStream = ClientStream<TcpStream>;
350
351pub async fn connect_async<R>(
371 request: R,
372) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
373where
374 R: IntoClientRequest + Unpin,
375{
376 connect_async_with_config(request, None).await
377}
378
379pub async fn connect_async_with_config<R>(
381 request: R,
382 config: Option<WebSocketConfig>,
383) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
384where
385 R: IntoClientRequest + Unpin,
386{
387 let request: Request = request.into_client_request()?;
388
389 let domain = domain(&request)?;
390 let port = port(&request)?;
391
392 let try_socket = TcpStream::connect((domain.as_str(), port)).await;
393 let socket = try_socket.map_err(Error::Io)?;
394 client_async_tls_with_connector_and_config(request, socket, None, config).await
395}
396
/// Like [`connect_async`], but lets the caller supply the TLS [`Connector`].
///
/// The default WebSocket configuration is used.
///
/// # Errors
///
/// Propagates request-conversion, connection, TLS, and handshake errors as
/// [`Error`].
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
pub async fn connect_async_with_tls_connector<R>(
    request: R,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let default_config = None;
    connect_async_with_tls_connector_and_config(request, connector, default_config).await
}
416
/// Opens a TCP connection to the request's host and port, then runs the
/// client handshake using the given TLS [`Connector`] and [`WebSocketConfig`].
///
/// Either argument may be `None` to use the backend or library default.
///
/// # Errors
///
/// Returns [`Error::Io`] if the TCP connection fails. All other failures
/// (request conversion, missing host/port, TLS, handshake) are propagated
/// unchanged.
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
pub async fn connect_async_with_tls_connector_and_config<R>(
    request: R,
    connector: Option<Connector>,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let request: Request = request.into_client_request()?;

    // Host is resolved before port, matching the helpers' error precedence.
    let (host, port_number) = (domain(&request)?, port(&request)?);

    let socket = TcpStream::connect((host.as_str(), port_number))
        .await
        .map_err(Error::Io)?;

    client_async_tls_with_connector_and_config(request, socket, connector, config).await
}
444
445use std::pin::Pin;
446use std::task::{Context, Poll};
447
pin_project_lite::pin_project! {
    /// Bridges tokio's `AsyncRead`/`AsyncWrite` traits and the `futures-io`
    /// `AsyncRead`/`AsyncWrite` traits.
    ///
    /// It implements the futures-io traits when `T` implements the tokio ones,
    /// and the tokio traits when `T` implements the futures-io ones.
    #[derive(Debug, Clone)]
    pub struct TokioAdapter<T> {
        // Structurally pinned: the trait impls project `Pin<&mut Self>` to
        // `Pin<&mut T>` before delegating to the wrapped stream.
        #[pin]
        inner: T,
    }
}
457
458impl<T> TokioAdapter<T> {
459 pub fn new(inner: T) -> Self {
461 Self { inner }
462 }
463
464 pub fn into_inner(self) -> T {
466 self.inner
467 }
468
469 pub fn get_ref(&self) -> &T {
471 &self.inner
472 }
473
474 pub fn get_mut(&mut self) -> &mut T {
476 &mut self.inner
477 }
478}
479
480impl<T: tokio::io::AsyncRead> AsyncRead for TokioAdapter<T> {
481 fn poll_read(
482 self: Pin<&mut Self>,
483 cx: &mut Context<'_>,
484 buf: &mut [u8],
485 ) -> Poll<std::io::Result<usize>> {
486 let mut buf = tokio::io::ReadBuf::new(buf);
487 match self.project().inner.poll_read(cx, &mut buf)? {
488 Poll::Pending => Poll::Pending,
489 Poll::Ready(_) => Poll::Ready(Ok(buf.filled().len())),
490 }
491 }
492}
493
494impl<T: tokio::io::AsyncWrite> AsyncWrite for TokioAdapter<T> {
495 fn poll_write(
496 self: Pin<&mut Self>,
497 cx: &mut Context<'_>,
498 buf: &[u8],
499 ) -> Poll<Result<usize, std::io::Error>> {
500 self.project().inner.poll_write(cx, buf)
501 }
502
503 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
504 self.project().inner.poll_flush(cx)
505 }
506
507 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
508 self.project().inner.poll_shutdown(cx)
509 }
510}
511
512impl<T: AsyncRead> tokio::io::AsyncRead for TokioAdapter<T> {
513 fn poll_read(
514 self: Pin<&mut Self>,
515 cx: &mut Context<'_>,
516 buf: &mut tokio::io::ReadBuf<'_>,
517 ) -> Poll<std::io::Result<()>> {
518 let slice = buf.initialize_unfilled();
519 let n = match self.project().inner.poll_read(cx, slice)? {
520 Poll::Pending => return Poll::Pending,
521 Poll::Ready(n) => n,
522 };
523 buf.advance(n);
524 Poll::Ready(Ok(()))
525 }
526}
527
528impl<T: AsyncWrite> tokio::io::AsyncWrite for TokioAdapter<T> {
529 fn poll_write(
530 self: Pin<&mut Self>,
531 cx: &mut Context<'_>,
532 buf: &[u8],
533 ) -> Poll<Result<usize, std::io::Error>> {
534 self.project().inner.poll_write(cx, buf)
535 }
536
537 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
538 self.project().inner.poll_flush(cx)
539 }
540
541 fn poll_shutdown(
542 self: Pin<&mut Self>,
543 cx: &mut Context<'_>,
544 ) -> Poll<Result<(), std::io::Error>> {
545 self.project().inner.poll_close(cx)
546 }
547}