1use tungstenite::client::IntoClientRequest;
3use tungstenite::handshake::client::{Request, Response};
4use tungstenite::handshake::server::{Callback, NoCallback};
5use tungstenite::protocol::WebSocketConfig;
6use tungstenite::Error;
7
8use tokio::net::TcpStream;
9
10use super::{domain, port, WebSocketStream};
11
12use futures_io::{AsyncRead, AsyncWrite};
13
// Exactly one `tls` backend module is compiled in, selected by feature flags.
// The priority order is:
// 1. tokio-native-tls
// 2. any tokio-rustls-* flavour
// 3. tokio-openssl
// 4. async-tls
// 5. the `dummy_tls` fallback, used when none of the above is enabled.
// Each cfg below excludes every higher-priority backend, so enabling several
// TLS features at once still produces a single `tls` module.

// Priority 1: native-tls.
#[cfg(feature = "tokio-native-tls")]
#[path = "tokio/native_tls.rs"]
mod tls;

// Priority 2: rustls (any root-store flavour), unless native-tls is enabled.
#[cfg(all(
    any(
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ),
    not(feature = "tokio-native-tls")
))]
#[path = "tokio/rustls.rs"]
mod tls;

// Priority 3: OpenSSL, only when neither native-tls nor rustls is enabled.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
#[path = "tokio/openssl.rs"]
mod tls;

// Priority 4: async-tls, only when no tokio-specific TLS backend is enabled.
#[cfg(all(
    feature = "async-tls",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots",
        feature = "tokio-openssl"
    ))
))]
#[path = "tokio/async_tls.rs"]
mod tls;

// Fallback: no TLS feature enabled at all.
#[cfg(not(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl",
    feature = "async-tls"
)))]
#[path = "tokio/dummy_tls.rs"]
mod tls;
68
69#[cfg(any(
70 feature = "tokio-native-tls",
71 feature = "tokio-rustls-manual-roots",
72 feature = "tokio-rustls-native-certs",
73 feature = "tokio-rustls-platform-verifier",
74 feature = "tokio-rustls-webpki-roots",
75 feature = "tokio-openssl",
76 feature = "async-tls",
77))]
78pub use self::tls::client_async_tls_with_connector_and_config;
79#[cfg(any(
80 feature = "tokio-native-tls",
81 feature = "tokio-rustls-manual-roots",
82 feature = "tokio-rustls-native-certs",
83 feature = "tokio-rustls-platform-verifier",
84 feature = "tokio-rustls-webpki-roots",
85 feature = "tokio-openssl",
86 feature = "async-tls"
87))]
88use self::tls::{AutoStream, Connector};
89
90#[cfg(not(any(
91 feature = "tokio-native-tls",
92 feature = "tokio-rustls-manual-roots",
93 feature = "tokio-rustls-native-certs",
94 feature = "tokio-rustls-platform-verifier",
95 feature = "tokio-rustls-webpki-roots",
96 feature = "tokio-openssl",
97 feature = "async-tls"
98)))]
99pub use self::tls::client_async_tls_with_connector_and_config;
100#[cfg(not(any(
101 feature = "tokio-native-tls",
102 feature = "tokio-rustls-manual-roots",
103 feature = "tokio-rustls-native-certs",
104 feature = "tokio-rustls-platform-verifier",
105 feature = "tokio-rustls-webpki-roots",
106 feature = "tokio-openssl",
107 feature = "async-tls"
108)))]
109use self::tls::AutoStream;
110
111pub async fn client_async<'a, R, S>(
124 request: R,
125 stream: S,
126) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
127where
128 R: IntoClientRequest + Unpin,
129 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
130{
131 client_async_with_config(request, stream, None).await
132}
133
134pub async fn client_async_with_config<'a, R, S>(
137 request: R,
138 stream: S,
139 config: Option<WebSocketConfig>,
140) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
141where
142 R: IntoClientRequest + Unpin,
143 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
144{
145 crate::client_async_with_config(request, TokioAdapter::new(stream), config).await
146}
147
148pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
160where
161 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
162{
163 accept_hdr_async(stream, NoCallback).await
164}
165
166pub async fn accept_async_with_config<S>(
169 stream: S,
170 config: Option<WebSocketConfig>,
171) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
172where
173 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
174{
175 accept_hdr_async_with_config(stream, NoCallback, config).await
176}
177
178pub async fn accept_hdr_async<S, C>(
184 stream: S,
185 callback: C,
186) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
187where
188 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
189 C: Callback + Unpin,
190{
191 accept_hdr_async_with_config(stream, callback, None).await
192}
193
194pub async fn accept_hdr_async_with_config<S, C>(
197 stream: S,
198 callback: C,
199 config: Option<WebSocketConfig>,
200) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
201where
202 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
203 C: Callback + Unpin,
204{
205 crate::accept_hdr_async_with_config(TokioAdapter::new(stream), callback, config).await
206}
207
/// Stream type returned by the `client_async_tls*` functions.
///
/// This is an alias for the `AutoStream` of whichever `tls` backend module the
/// enabled features selected.
pub type ClientStream<S> = AutoStream<S>;
210
/// Runs a client handshake over `stream`, using the selected TLS backend's
/// default connector. The decision whether to wrap the stream in TLS is made
/// by the backend's `client_async_tls_with_connector_and_config`.
///
/// No explicit [`WebSocketConfig`] is used.
///
/// # Errors
///
/// Returns any error produced while building the request, setting up TLS, or
/// performing the handshake.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")),
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    // `client_async_tls_with_config` is compiled under exactly the same cfg
    // and bounds, so delegating with no config is equivalent.
    client_async_tls_with_config(request, stream, None).await
}
232
/// Like [`client_async_tls`], but with an optional [`WebSocketConfig`].
///
/// The selected TLS backend's default connector is used.
///
/// # Errors
///
/// Returns any error produced while building the request, setting up TLS, or
/// performing the handshake.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")),
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let default_connector = None;
    client_async_tls_with_connector_and_config(request, stream, default_connector, config).await
}
256
/// Like [`client_async_tls`], but lets the caller supply the TLS
/// [`Connector`]. Passing `None` uses the backend's default connector.
///
/// No explicit [`WebSocketConfig`] is used.
///
/// # Errors
///
/// Returns any error produced while building the request, setting up TLS, or
/// performing the handshake.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, no_config).await
}
280
/// OpenSSL variant of `client_async_tls`: runs a client handshake over
/// `stream` using the backend's default connector and no explicit
/// [`WebSocketConfig`].
///
/// The additional `Debug + Send + Sync` bounds on `S` are presumably required
/// by the OpenSSL backend (TODO confirm against `tokio/openssl.rs`).
///
/// # Errors
///
/// Returns any error produced while building the request, setting up TLS, or
/// performing the handshake.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    // The OpenSSL `client_async_tls_with_config` shares this exact cfg and
    // these bounds, so delegating with no config is equivalent.
    client_async_tls_with_config(request, stream, None).await
}
310
/// OpenSSL variant of `client_async_tls_with_config`.
///
/// Uses the backend's default connector together with the supplied optional
/// [`WebSocketConfig`].
///
/// # Errors
///
/// Returns any error produced while building the request, setting up TLS, or
/// performing the handshake.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let default_connector = None;
    client_async_tls_with_connector_and_config(request, stream, default_connector, config).await
}
342
/// OpenSSL variant of `client_async_tls_with_connector`.
///
/// The caller supplies the TLS [`Connector`]; passing `None` uses the
/// backend's default. No explicit [`WebSocketConfig`] is used.
///
/// # Errors
///
/// Returns any error produced while building the request, setting up TLS, or
/// performing the handshake.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, no_config).await
}
374
/// Stream type returned by the `connect_async*` functions: a [`ClientStream`]
/// over a tokio [`TcpStream`].
pub type ConnectStream = ClientStream<TcpStream>;
377
378pub async fn connect_async<R>(
398 request: R,
399) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
400where
401 R: IntoClientRequest + Unpin,
402{
403 connect_async_with_config(request, None).await
404}
405
406pub async fn connect_async_with_config<R>(
408 request: R,
409 config: Option<WebSocketConfig>,
410) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
411where
412 R: IntoClientRequest + Unpin,
413{
414 let request: Request = request.into_client_request()?;
415
416 let domain = domain(&request)?;
417 let port = port(&request)?;
418
419 let try_socket = TcpStream::connect((domain.as_str(), port)).await;
420 let socket = try_socket.map_err(Error::Io)?;
421 client_async_tls_with_connector_and_config(request, socket, None, config).await
422}
423
/// Like [`connect_async`], but lets the caller supply the TLS [`Connector`].
/// Passing `None` uses the backend's default connector.
///
/// No explicit [`WebSocketConfig`] is used.
///
/// # Errors
///
/// Returns an error when:
/// - the request cannot be built or lacks a usable host or port;
/// - the TCP connection fails (as [`Error::Io`]);
/// - the handshake fails.
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
pub async fn connect_async_with_tls_connector<R>(
    request: R,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    connect_async_with_tls_connector_and_config(request, connector, no_config).await
}
443
/// Opens a TCP connection to the host and port named by `request`, then runs
/// the WebSocket client handshake over it. The caller supplies both the TLS
/// [`Connector`] (`None` = backend default) and an optional
/// [`WebSocketConfig`].
///
/// # Errors
///
/// Returns an error when:
/// - the request cannot be built or lacks a usable host or port;
/// - the TCP connection fails (as [`Error::Io`]);
/// - the TLS setup or handshake fails.
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
pub async fn connect_async_with_tls_connector_and_config<R>(
    request: R,
    connector: Option<Connector>,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let request: Request = request.into_client_request()?;

    // Resolve the endpoint from the request URI before dialing.
    let host = domain(&request)?;
    let port_number = port(&request)?;

    let socket = TcpStream::connect((host.as_str(), port_number))
        .await
        .map_err(Error::Io)?;

    client_async_tls_with_connector_and_config(request, socket, connector, config).await
}
471
472use std::pin::Pin;
473use std::task::{Context, Poll};
474
pin_project_lite::pin_project! {
    /// Adapter between tokio's `AsyncRead`/`AsyncWrite` traits and the
    /// `futures-io` ones, in both directions.
    ///
    /// When `T` implements the tokio traits, the adapter implements the
    /// futures-io traits, and vice versa.
    #[derive(Debug, Clone)]
    pub struct TokioAdapter<T> {
        // Pinned so the trait impls can project to `Pin<&mut T>`.
        #[pin]
        inner: T,
    }
}
484
485impl<T> TokioAdapter<T> {
486 pub fn new(inner: T) -> Self {
488 Self { inner }
489 }
490
491 pub fn into_inner(self) -> T {
493 self.inner
494 }
495
496 pub fn get_ref(&self) -> &T {
498 &self.inner
499 }
500
501 pub fn get_mut(&mut self) -> &mut T {
503 &mut self.inner
504 }
505}
506
507impl<T: tokio::io::AsyncRead> AsyncRead for TokioAdapter<T> {
508 fn poll_read(
509 self: Pin<&mut Self>,
510 cx: &mut Context<'_>,
511 buf: &mut [u8],
512 ) -> Poll<std::io::Result<usize>> {
513 let mut buf = tokio::io::ReadBuf::new(buf);
514 match self.project().inner.poll_read(cx, &mut buf)? {
515 Poll::Pending => Poll::Pending,
516 Poll::Ready(_) => Poll::Ready(Ok(buf.filled().len())),
517 }
518 }
519}
520
521impl<T: tokio::io::AsyncWrite> AsyncWrite for TokioAdapter<T> {
522 fn poll_write(
523 self: Pin<&mut Self>,
524 cx: &mut Context<'_>,
525 buf: &[u8],
526 ) -> Poll<Result<usize, std::io::Error>> {
527 self.project().inner.poll_write(cx, buf)
528 }
529
530 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
531 self.project().inner.poll_flush(cx)
532 }
533
534 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
535 self.project().inner.poll_shutdown(cx)
536 }
537}
538
539impl<T: AsyncRead> tokio::io::AsyncRead for TokioAdapter<T> {
540 fn poll_read(
541 self: Pin<&mut Self>,
542 cx: &mut Context<'_>,
543 buf: &mut tokio::io::ReadBuf<'_>,
544 ) -> Poll<std::io::Result<()>> {
545 let slice = buf.initialize_unfilled();
546 let n = match self.project().inner.poll_read(cx, slice)? {
547 Poll::Pending => return Poll::Pending,
548 Poll::Ready(n) => n,
549 };
550 buf.advance(n);
551 Poll::Ready(Ok(()))
552 }
553}
554
555impl<T: AsyncWrite> tokio::io::AsyncWrite for TokioAdapter<T> {
556 fn poll_write(
557 self: Pin<&mut Self>,
558 cx: &mut Context<'_>,
559 buf: &[u8],
560 ) -> Poll<Result<usize, std::io::Error>> {
561 self.project().inner.poll_write(cx, buf)
562 }
563
564 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
565 self.project().inner.poll_flush(cx)
566 }
567
568 fn poll_shutdown(
569 self: Pin<&mut Self>,
570 cx: &mut Context<'_>,
571 ) -> Poll<Result<(), std::io::Error>> {
572 self.project().inner.poll_close(cx)
573 }
574}