Skip to main content

async_tungstenite/
tokio.rs

1//! `tokio` integration.
2use tungstenite::client::IntoClientRequest;
3use tungstenite::handshake::client::{Request, Response};
4use tungstenite::handshake::server::{Callback, NoCallback};
5use tungstenite::protocol::WebSocketConfig;
6use tungstenite::Error;
7
8use tokio::net::TcpStream;
9
10use super::{domain, port, WebSocketStream};
11
12use futures_io::{AsyncRead, AsyncWrite};
13
14#[cfg(feature = "tokio-native-tls")]
15#[path = "tokio/native_tls.rs"]
16mod tls;
17
18#[cfg(all(
19    any(
20        feature = "tokio-rustls-manual-roots",
21        feature = "tokio-rustls-native-certs",
22        feature = "tokio-rustls-platform-verifier",
23        feature = "tokio-rustls-webpki-roots"
24    ),
25    not(feature = "tokio-native-tls")
26))]
27#[path = "tokio/rustls.rs"]
28mod tls;
29
30#[cfg(all(
31    feature = "tokio-openssl",
32    not(any(
33        feature = "tokio-native-tls",
34        feature = "tokio-rustls-manual-roots",
35        feature = "tokio-rustls-native-certs",
36        feature = "tokio-rustls-platform-verifier",
37        feature = "tokio-rustls-webpki-roots"
38    ))
39))]
40#[path = "tokio/openssl.rs"]
41mod tls;
42
43#[cfg(all(
44    feature = "async-tls",
45    not(any(
46        feature = "tokio-native-tls",
47        feature = "tokio-rustls-manual-roots",
48        feature = "tokio-rustls-native-certs",
49        feature = "tokio-rustls-platform-verifier",
50        feature = "tokio-rustls-webpki-roots",
51        feature = "tokio-openssl"
52    ))
53))]
54#[path = "tokio/async_tls.rs"]
55mod tls;
56
57#[cfg(not(any(
58    feature = "tokio-native-tls",
59    feature = "tokio-rustls-manual-roots",
60    feature = "tokio-rustls-native-certs",
61    feature = "tokio-rustls-platform-verifier",
62    feature = "tokio-rustls-webpki-roots",
63    feature = "tokio-openssl",
64    feature = "async-tls"
65)))]
66#[path = "tokio/dummy_tls.rs"]
67mod tls;
68
69pub use self::tls::client_async_tls_with_connector_and_config;
70
71#[cfg(any(
72    feature = "tokio-native-tls",
73    feature = "tokio-rustls-manual-roots",
74    feature = "tokio-rustls-native-certs",
75    feature = "tokio-rustls-platform-verifier",
76    feature = "tokio-rustls-webpki-roots",
77    feature = "tokio-openssl",
78    feature = "async-tls"
79))]
80use self::tls::Connector;
81
82use self::tls::AutoStream;
83
84/// Creates a WebSocket handshake from a request and a stream.
85/// For convenience, the user may call this with a url string, a URL,
86/// or a `Request`. Calling with `Request` allows the user to add
87/// a WebSocket protocol or other custom headers.
88///
89/// Internally, this custom creates a handshake representation and returns
90/// a future representing the resolution of the WebSocket handshake. The
91/// returned future will resolve to either `WebSocketStream<S>` or `Error`
92/// depending on whether the handshake is successful.
93///
94/// This is typically used for clients who have already established, for
95/// example, a TCP connection to the remote server.
96pub async fn client_async<'a, R, S>(
97    request: R,
98    stream: S,
99) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
100where
101    R: IntoClientRequest + Unpin,
102    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
103{
104    client_async_with_config(request, stream, None).await
105}
106
107/// The same as `client_async()` but the one can specify a websocket configuration.
108/// Please refer to `client_async()` for more details.
109pub async fn client_async_with_config<'a, R, S>(
110    request: R,
111    stream: S,
112    config: Option<WebSocketConfig>,
113) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
114where
115    R: IntoClientRequest + Unpin,
116    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
117{
118    crate::client_async_with_config(request, TokioAdapter::new(stream), config).await
119}
120
121/// Accepts a new WebSocket connection with the provided stream.
122///
123/// This function will internally call `server::accept` to create a
124/// handshake representation and returns a future representing the
125/// resolution of the WebSocket handshake. The returned future will resolve
126/// to either `WebSocketStream<S>` or `Error` depending if it's successful
127/// or not.
128///
129/// This is typically used after a socket has been accepted from a
130/// `TcpListener`. That socket is then passed to this function to perform
131/// the server half of the accepting a client's websocket connection.
132pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
133where
134    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
135{
136    accept_hdr_async(stream, NoCallback).await
137}
138
139/// The same as `accept_async()` but the one can specify a websocket configuration.
140/// Please refer to `accept_async()` for more details.
141pub async fn accept_async_with_config<S>(
142    stream: S,
143    config: Option<WebSocketConfig>,
144) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
145where
146    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
147{
148    accept_hdr_async_with_config(stream, NoCallback, config).await
149}
150
151/// Accepts a new WebSocket connection with the provided stream.
152///
153/// This function does the same as `accept_async()` but accepts an extra callback
154/// for header processing. The callback receives headers of the incoming
155/// requests and is able to add extra headers to the reply.
156pub async fn accept_hdr_async<S, C>(
157    stream: S,
158    callback: C,
159) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
160where
161    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
162    C: Callback + Unpin,
163{
164    accept_hdr_async_with_config(stream, callback, None).await
165}
166
167/// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
168/// Please refer to `accept_hdr_async()` for more details.
169pub async fn accept_hdr_async_with_config<S, C>(
170    stream: S,
171    callback: C,
172    config: Option<WebSocketConfig>,
173) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
174where
175    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
176    C: Callback + Unpin,
177{
178    crate::accept_hdr_async_with_config(TokioAdapter::new(stream), callback, config).await
179}
180
181/// Type alias for the stream type of the `client_async_tls()` functions.
182pub type ClientStream<S> = AutoStream<S>;
183
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")), // No roots will be available
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
/// Performs a client WebSocket handshake over the given stream, upgrading
/// the stream to TLS if required.
///
/// Neither a TLS connector nor a WebSocket configuration is supplied: both
/// are forwarded as `None` to `client_async_tls_with_connector_and_config`.
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let connector = None;
    let config = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
205
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")), // No roots will be available
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
/// Performs a client WebSocket handshake over the given stream, upgrading
/// the stream to TLS if required and applying the given WebSocket
/// configuration.
///
/// No TLS connector is supplied: `None` is forwarded for it.
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let connector = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
229
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
/// Performs a client WebSocket handshake over the given stream, upgrading
/// the stream to TLS if required and using the given connector.
///
/// No WebSocket configuration is supplied: `None` is forwarded for it.
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
253
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
/// Performs a client WebSocket handshake over the given stream, upgrading
/// the stream to TLS if required.
///
/// This variant is compiled when `tokio-openssl` is the selected TLS backend
/// and additionally requires the stream to be `Debug + Send + Sync`.
/// Neither a TLS connector nor a WebSocket configuration is supplied: both
/// are forwarded as `None`.
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let connector = None;
    let config = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
283
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
/// Performs a client WebSocket handshake over the given stream, upgrading
/// the stream to TLS if required and applying the given WebSocket
/// configuration.
///
/// This variant is compiled when `tokio-openssl` is the selected TLS backend
/// and additionally requires the stream to be `Debug + Send + Sync`.
/// No TLS connector is supplied: `None` is forwarded for it.
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let connector = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
315
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
/// Performs a client WebSocket handshake over the given stream, upgrading
/// the stream to TLS if required and using the given connector.
///
/// This variant is compiled when `tokio-openssl` is the selected TLS backend
/// and additionally requires the stream to be `Debug + Send + Sync`.
/// No WebSocket configuration is supplied: `None` is forwarded for it.
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let config: Option<WebSocketConfig> = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
347
348/// Type alias for the stream type of the `connect_async()` functions.
349pub type ConnectStream = ClientStream<TcpStream>;
350
351/// Connect to a given URL.
352///
353/// Accepts any request that implements [`IntoClientRequest`], which is often just `&str`, but can
354/// be a variety of types such as `httparse::Request` or [`tungstenite::http::Request`] for more
355/// complex uses.
356///
357/// ```no_run
358/// # use tungstenite::client::IntoClientRequest;
359///
360/// # async fn test() {
361/// use tungstenite::http::{Method, Request};
362/// use async_tungstenite::tokio::connect_async;
363///
364/// let mut request = "wss://api.example.com".into_client_request().unwrap();
365/// request.headers_mut().insert("api-key", "42".parse().unwrap());
366///
367/// let (stream, response) = connect_async(request).await.unwrap();
368/// # }
369/// ```
370pub async fn connect_async<R>(
371    request: R,
372) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
373where
374    R: IntoClientRequest + Unpin,
375{
376    connect_async_with_config(request, None).await
377}
378
379/// Connect to a given URL with a given WebSocket configuration.
380pub async fn connect_async_with_config<R>(
381    request: R,
382    config: Option<WebSocketConfig>,
383) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
384where
385    R: IntoClientRequest + Unpin,
386{
387    let request: Request = request.into_client_request()?;
388
389    let domain = domain(&request)?;
390    let port = port(&request)?;
391
392    let try_socket = TcpStream::connect((domain.as_str(), port)).await;
393    let socket = try_socket.map_err(Error::Io)?;
394    client_async_tls_with_connector_and_config(request, socket, None, config).await
395}
396
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
/// Connect to a given URL using the provided TLS connector.
///
/// Equivalent to [`connect_async_with_tls_connector_and_config`] with no
/// WebSocket configuration.
pub async fn connect_async_with_tls_connector<R>(
    request: R,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    connect_async_with_tls_connector_and_config(request, connector, no_config).await
}
416
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
/// Connect to a given URL using the provided TLS connector and WebSocket
/// configuration.
///
/// The request is converted into a client `Request`, a `TcpStream` is opened
/// to the domain and port derived from it, and the handshake is then run over
/// that socket with the given connector and configuration.
///
/// # Errors
///
/// Returns an `Error` if the request conversion fails, if the domain or port
/// cannot be determined, if the TCP connection fails (reported as
/// `Error::Io`), or if the handshake fails.
pub async fn connect_async_with_tls_connector_and_config<R>(
    request: R,
    connector: Option<Connector>,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let request: Request = request.into_client_request()?;

    let host = domain(&request)?;
    let port_number = port(&request)?;

    let socket = TcpStream::connect((host.as_str(), port_number))
        .await
        .map_err(Error::Io)?;
    client_async_tls_with_connector_and_config(request, socket, connector, config).await
}
444
445use std::pin::Pin;
446use std::task::{Context, Poll};
447
448pin_project_lite::pin_project! {
449    /// Adapter for `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` to provide
450    /// the variants from the `futures` crate and the other way around.
451    #[derive(Debug, Clone)]
452    pub struct TokioAdapter<T> {
453        #[pin]
454        inner: T,
455    }
456}
457
458impl<T> TokioAdapter<T> {
459    /// Creates a new `TokioAdapter` wrapping the provided value.
460    pub fn new(inner: T) -> Self {
461        Self { inner }
462    }
463
464    /// Consumes this `TokioAdapter`, returning the underlying value.
465    pub fn into_inner(self) -> T {
466        self.inner
467    }
468
469    /// Get a reference to the underlying value.
470    pub fn get_ref(&self) -> &T {
471        &self.inner
472    }
473
474    /// Get a mutable reference to the underlying value.
475    pub fn get_mut(&mut self) -> &mut T {
476        &mut self.inner
477    }
478}
479
480impl<T: tokio::io::AsyncRead> AsyncRead for TokioAdapter<T> {
481    fn poll_read(
482        self: Pin<&mut Self>,
483        cx: &mut Context<'_>,
484        buf: &mut [u8],
485    ) -> Poll<std::io::Result<usize>> {
486        let mut buf = tokio::io::ReadBuf::new(buf);
487        match self.project().inner.poll_read(cx, &mut buf)? {
488            Poll::Pending => Poll::Pending,
489            Poll::Ready(_) => Poll::Ready(Ok(buf.filled().len())),
490        }
491    }
492}
493
494impl<T: tokio::io::AsyncWrite> AsyncWrite for TokioAdapter<T> {
495    fn poll_write(
496        self: Pin<&mut Self>,
497        cx: &mut Context<'_>,
498        buf: &[u8],
499    ) -> Poll<Result<usize, std::io::Error>> {
500        self.project().inner.poll_write(cx, buf)
501    }
502
503    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
504        self.project().inner.poll_flush(cx)
505    }
506
507    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
508        self.project().inner.poll_shutdown(cx)
509    }
510}
511
512impl<T: AsyncRead> tokio::io::AsyncRead for TokioAdapter<T> {
513    fn poll_read(
514        self: Pin<&mut Self>,
515        cx: &mut Context<'_>,
516        buf: &mut tokio::io::ReadBuf<'_>,
517    ) -> Poll<std::io::Result<()>> {
518        let slice = buf.initialize_unfilled();
519        let n = match self.project().inner.poll_read(cx, slice)? {
520            Poll::Pending => return Poll::Pending,
521            Poll::Ready(n) => n,
522        };
523        buf.advance(n);
524        Poll::Ready(Ok(()))
525    }
526}
527
528impl<T: AsyncWrite> tokio::io::AsyncWrite for TokioAdapter<T> {
529    fn poll_write(
530        self: Pin<&mut Self>,
531        cx: &mut Context<'_>,
532        buf: &[u8],
533    ) -> Poll<Result<usize, std::io::Error>> {
534        self.project().inner.poll_write(cx, buf)
535    }
536
537    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
538        self.project().inner.poll_flush(cx)
539    }
540
541    fn poll_shutdown(
542        self: Pin<&mut Self>,
543        cx: &mut Context<'_>,
544    ) -> Poll<Result<(), std::io::Error>> {
545        self.project().inner.poll_close(cx)
546    }
547}