async_tungstenite/
tokio.rs

1//! `tokio` integration.
2use tungstenite::client::IntoClientRequest;
3use tungstenite::handshake::client::{Request, Response};
4use tungstenite::handshake::server::{Callback, NoCallback};
5use tungstenite::protocol::WebSocketConfig;
6use tungstenite::Error;
7
8use tokio::net::TcpStream;
9
10use super::{domain, port, WebSocketStream};
11
12use futures_io::{AsyncRead, AsyncWrite};
13
14#[cfg(feature = "tokio-native-tls")]
15#[path = "tokio/native_tls.rs"]
16mod tls;
17
18#[cfg(all(
19    any(
20        feature = "tokio-rustls-manual-roots",
21        feature = "tokio-rustls-native-certs",
22        feature = "tokio-rustls-platform-verifier",
23        feature = "tokio-rustls-webpki-roots"
24    ),
25    not(feature = "tokio-native-tls")
26))]
27#[path = "tokio/rustls.rs"]
28mod tls;
29
30#[cfg(all(
31    feature = "tokio-openssl",
32    not(any(
33        feature = "tokio-native-tls",
34        feature = "tokio-rustls-manual-roots",
35        feature = "tokio-rustls-native-certs",
36        feature = "tokio-rustls-platform-verifier",
37        feature = "tokio-rustls-webpki-roots"
38    ))
39))]
40#[path = "tokio/openssl.rs"]
41mod tls;
42
43#[cfg(all(
44    feature = "async-tls",
45    not(any(
46        feature = "tokio-native-tls",
47        feature = "tokio-rustls-manual-roots",
48        feature = "tokio-rustls-native-certs",
49        feature = "tokio-rustls-platform-verifier",
50        feature = "tokio-rustls-webpki-roots",
51        feature = "tokio-openssl"
52    ))
53))]
54#[path = "tokio/async_tls.rs"]
55mod tls;
56
57#[cfg(not(any(
58    feature = "tokio-native-tls",
59    feature = "tokio-rustls-manual-roots",
60    feature = "tokio-rustls-native-certs",
61    feature = "tokio-rustls-platform-verifier",
62    feature = "tokio-rustls-webpki-roots",
63    feature = "tokio-openssl",
64    feature = "async-tls"
65)))]
66#[path = "tokio/dummy_tls.rs"]
67mod tls;
68
69#[cfg(any(
70    feature = "tokio-native-tls",
71    feature = "tokio-rustls-manual-roots",
72    feature = "tokio-rustls-native-certs",
73    feature = "tokio-rustls-platform-verifier",
74    feature = "tokio-rustls-webpki-roots",
75    feature = "tokio-openssl",
76    feature = "async-tls",
77))]
78pub use self::tls::client_async_tls_with_connector_and_config;
79#[cfg(any(
80    feature = "tokio-native-tls",
81    feature = "tokio-rustls-manual-roots",
82    feature = "tokio-rustls-native-certs",
83    feature = "tokio-rustls-platform-verifier",
84    feature = "tokio-rustls-webpki-roots",
85    feature = "tokio-openssl",
86    feature = "async-tls"
87))]
88use self::tls::{AutoStream, Connector};
89
90#[cfg(not(any(
91    feature = "tokio-native-tls",
92    feature = "tokio-rustls-manual-roots",
93    feature = "tokio-rustls-native-certs",
94    feature = "tokio-rustls-platform-verifier",
95    feature = "tokio-rustls-webpki-roots",
96    feature = "tokio-openssl",
97    feature = "async-tls"
98)))]
99pub use self::tls::client_async_tls_with_connector_and_config;
100#[cfg(not(any(
101    feature = "tokio-native-tls",
102    feature = "tokio-rustls-manual-roots",
103    feature = "tokio-rustls-native-certs",
104    feature = "tokio-rustls-platform-verifier",
105    feature = "tokio-rustls-webpki-roots",
106    feature = "tokio-openssl",
107    feature = "async-tls"
108)))]
109use self::tls::AutoStream;
110
111/// Creates a WebSocket handshake from a request and a stream.
112/// For convenience, the user may call this with a url string, a URL,
113/// or a `Request`. Calling with `Request` allows the user to add
114/// a WebSocket protocol or other custom headers.
115///
116/// Internally, this custom creates a handshake representation and returns
117/// a future representing the resolution of the WebSocket handshake. The
118/// returned future will resolve to either `WebSocketStream<S>` or `Error`
119/// depending on whether the handshake is successful.
120///
121/// This is typically used for clients who have already established, for
122/// example, a TCP connection to the remote server.
123pub async fn client_async<'a, R, S>(
124    request: R,
125    stream: S,
126) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
127where
128    R: IntoClientRequest + Unpin,
129    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
130{
131    client_async_with_config(request, stream, None).await
132}
133
134/// The same as `client_async()` but the one can specify a websocket configuration.
135/// Please refer to `client_async()` for more details.
136pub async fn client_async_with_config<'a, R, S>(
137    request: R,
138    stream: S,
139    config: Option<WebSocketConfig>,
140) -> Result<(WebSocketStream<TokioAdapter<S>>, Response), Error>
141where
142    R: IntoClientRequest + Unpin,
143    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
144{
145    crate::client_async_with_config(request, TokioAdapter::new(stream), config).await
146}
147
148/// Accepts a new WebSocket connection with the provided stream.
149///
150/// This function will internally call `server::accept` to create a
151/// handshake representation and returns a future representing the
152/// resolution of the WebSocket handshake. The returned future will resolve
153/// to either `WebSocketStream<S>` or `Error` depending if it's successful
154/// or not.
155///
156/// This is typically used after a socket has been accepted from a
157/// `TcpListener`. That socket is then passed to this function to perform
158/// the server half of the accepting a client's websocket connection.
159pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
160where
161    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
162{
163    accept_hdr_async(stream, NoCallback).await
164}
165
166/// The same as `accept_async()` but the one can specify a websocket configuration.
167/// Please refer to `accept_async()` for more details.
168pub async fn accept_async_with_config<S>(
169    stream: S,
170    config: Option<WebSocketConfig>,
171) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
172where
173    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
174{
175    accept_hdr_async_with_config(stream, NoCallback, config).await
176}
177
178/// Accepts a new WebSocket connection with the provided stream.
179///
180/// This function does the same as `accept_async()` but accepts an extra callback
181/// for header processing. The callback receives headers of the incoming
182/// requests and is able to add extra headers to the reply.
183pub async fn accept_hdr_async<S, C>(
184    stream: S,
185    callback: C,
186) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
187where
188    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
189    C: Callback + Unpin,
190{
191    accept_hdr_async_with_config(stream, callback, None).await
192}
193
194/// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
195/// Please refer to `accept_hdr_async()` for more details.
196pub async fn accept_hdr_async_with_config<S, C>(
197    stream: S,
198    callback: C,
199    config: Option<WebSocketConfig>,
200) -> Result<WebSocketStream<TokioAdapter<S>>, Error>
201where
202    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
203    C: Callback + Unpin,
204{
205    crate::accept_hdr_async_with_config(TokioAdapter::new(stream), callback, config).await
206}
207
208/// Type alias for the stream type of the `client_async()` functions.
209pub type ClientStream<S> = AutoStream<S>;
210
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
///
/// Neither a custom TLS connector nor a WebSocket configuration is supplied;
/// both are passed on as `None`.
// `tokio-rustls-manual-roots` on its own is deliberately not listed here:
// without a user-supplied connector no root certificates would be available.
// Use `client_async_tls_with_connector` in that configuration instead.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")),
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    client_async_tls_with_connector_and_config(request, stream, None, None).await
}
232
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required and using the given
/// WebSocket configuration.
///
/// No custom TLS connector is supplied (`None` is passed on).
// `tokio-rustls-manual-roots` on its own is deliberately not listed here:
// without a user-supplied connector no root certificates would be available.
// Use `client_async_tls_with_connector_and_config` in that configuration instead.
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "__rustls-tls", not(feature = "tokio-rustls-manual-roots")),
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    client_async_tls_with_connector_and_config(request, stream, None, config).await
}
256
/// Creates a WebSocket handshake from a request and a stream, upgrading the
/// stream to TLS if required and using the supplied TLS `connector`.
///
/// No WebSocket configuration is supplied (`None` is passed on).
#[cfg(any(
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    all(feature = "async-tls", not(feature = "tokio-openssl"))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    AutoStream<S>: Unpin,
{
    let config = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
280
/// Creates a WebSocket handshake from a request and a stream, upgrading the
/// stream to TLS if required (OpenSSL backend).
///
/// Neither a custom connector nor a WebSocket configuration is supplied; both
/// are passed on as `None`. Compared with the other backends, the stream must
/// additionally be `Debug + Send + Sync`.
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let (connector, config) = (None, None);
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
310
/// Creates a WebSocket handshake from a request and a stream, upgrading the
/// stream to TLS if required and using the given WebSocket configuration
/// (OpenSSL backend).
///
/// No custom connector is supplied (`None` is passed on).
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let connector = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
342
/// Creates a WebSocket handshake from a request and a stream, upgrading the
/// stream to TLS if required and using the supplied TLS `connector`
/// (OpenSSL backend).
///
/// No WebSocket configuration is supplied (`None` is passed on).
#[cfg(all(
    feature = "tokio-openssl",
    not(any(
        feature = "tokio-native-tls",
        feature = "tokio-rustls-manual-roots",
        feature = "tokio-rustls-native-certs",
        feature = "tokio-rustls-platform-verifier",
        feature = "tokio-rustls-webpki-roots"
    ))
))]
pub async fn client_async_tls_with_connector<R, S>(
    request: R,
    stream: S,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ClientStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Unpin
        + std::fmt::Debug
        + Send
        + Sync,
    AutoStream<S>: Unpin,
{
    let config = None;
    client_async_tls_with_connector_and_config(request, stream, connector, config).await
}
374
375/// Type alias for the stream type of the `connect_async()` functions.
376pub type ConnectStream = ClientStream<TcpStream>;
377
378/// Connect to a given URL.
379///
380/// Accepts any request that implements [`IntoClientRequest`], which is often just `&str`, but can
381/// be a variety of types such as `httparse::Request` or [`tungstenite::http::Request`] for more
382/// complex uses.
383///
384/// ```no_run
385/// # use tungstenite::client::IntoClientRequest;
386///
387/// # async fn test() {
388/// use tungstenite::http::{Method, Request};
389/// use async_tungstenite::tokio::connect_async;
390///
391/// let mut request = "wss://api.example.com".into_client_request().unwrap();
392/// request.headers_mut().insert("api-key", "42".parse().unwrap());
393///
394/// let (stream, response) = connect_async(request).await.unwrap();
395/// # }
396/// ```
397pub async fn connect_async<R>(
398    request: R,
399) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
400where
401    R: IntoClientRequest + Unpin,
402{
403    connect_async_with_config(request, None).await
404}
405
406/// Connect to a given URL with a given WebSocket configuration.
407pub async fn connect_async_with_config<R>(
408    request: R,
409    config: Option<WebSocketConfig>,
410) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
411where
412    R: IntoClientRequest + Unpin,
413{
414    let request: Request = request.into_client_request()?;
415
416    let domain = domain(&request)?;
417    let port = port(&request)?;
418
419    let try_socket = TcpStream::connect((domain.as_str(), port)).await;
420    let socket = try_socket.map_err(Error::Io)?;
421    client_async_tls_with_connector_and_config(request, socket, None, config).await
422}
423
/// Connect to a given URL using the provided TLS connector.
///
/// No custom WebSocket configuration is used (`None` is passed on).
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
pub async fn connect_async_with_tls_connector<R>(
    request: R,
    connector: Option<Connector>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let config: Option<WebSocketConfig> = None;
    connect_async_with_tls_connector_and_config(request, connector, config).await
}
443
/// Connect to a given URL using the provided TLS connector and WebSocket
/// configuration.
///
/// A TCP connection is opened to the host and port taken from the request.
/// The WebSocket handshake is then performed over it, upgrading the
/// connection to TLS if required.
#[cfg(any(
    feature = "async-tls",
    feature = "tokio-native-tls",
    feature = "tokio-rustls-manual-roots",
    feature = "tokio-rustls-native-certs",
    feature = "tokio-rustls-platform-verifier",
    feature = "tokio-rustls-webpki-roots",
    feature = "tokio-openssl"
))]
pub async fn connect_async_with_tls_connector_and_config<R>(
    request: R,
    connector: Option<Connector>,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<ConnectStream>, Response), Error>
where
    R: IntoClientRequest + Unpin,
{
    let request: Request = request.into_client_request()?;

    // Host first, then port (the same order the errors would surface in).
    let (host, tcp_port) = (domain(&request)?, port(&request)?);

    let socket = TcpStream::connect((host.as_str(), tcp_port))
        .await
        .map_err(Error::Io)?;
    client_async_tls_with_connector_and_config(request, socket, connector, config).await
}
471
472use std::pin::Pin;
473use std::task::{Context, Poll};
474
475pin_project_lite::pin_project! {
476    /// Adapter for `tokio::io::AsyncRead` and `tokio::io::AsyncWrite` to provide
477    /// the variants from the `futures` crate and the other way around.
478    #[derive(Debug, Clone)]
479    pub struct TokioAdapter<T> {
480        #[pin]
481        inner: T,
482    }
483}
484
485impl<T> TokioAdapter<T> {
486    /// Creates a new `TokioAdapter` wrapping the provided value.
487    pub fn new(inner: T) -> Self {
488        Self { inner }
489    }
490
491    /// Consumes this `TokioAdapter`, returning the underlying value.
492    pub fn into_inner(self) -> T {
493        self.inner
494    }
495
496    /// Get a reference to the underlying value.
497    pub fn get_ref(&self) -> &T {
498        &self.inner
499    }
500
501    /// Get a mutable reference to the underlying value.
502    pub fn get_mut(&mut self) -> &mut T {
503        &mut self.inner
504    }
505}
506
507impl<T: tokio::io::AsyncRead> AsyncRead for TokioAdapter<T> {
508    fn poll_read(
509        self: Pin<&mut Self>,
510        cx: &mut Context<'_>,
511        buf: &mut [u8],
512    ) -> Poll<std::io::Result<usize>> {
513        let mut buf = tokio::io::ReadBuf::new(buf);
514        match self.project().inner.poll_read(cx, &mut buf)? {
515            Poll::Pending => Poll::Pending,
516            Poll::Ready(_) => Poll::Ready(Ok(buf.filled().len())),
517        }
518    }
519}
520
521impl<T: tokio::io::AsyncWrite> AsyncWrite for TokioAdapter<T> {
522    fn poll_write(
523        self: Pin<&mut Self>,
524        cx: &mut Context<'_>,
525        buf: &[u8],
526    ) -> Poll<Result<usize, std::io::Error>> {
527        self.project().inner.poll_write(cx, buf)
528    }
529
530    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
531        self.project().inner.poll_flush(cx)
532    }
533
534    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
535        self.project().inner.poll_shutdown(cx)
536    }
537}
538
539impl<T: AsyncRead> tokio::io::AsyncRead for TokioAdapter<T> {
540    fn poll_read(
541        self: Pin<&mut Self>,
542        cx: &mut Context<'_>,
543        buf: &mut tokio::io::ReadBuf<'_>,
544    ) -> Poll<std::io::Result<()>> {
545        let slice = buf.initialize_unfilled();
546        let n = match self.project().inner.poll_read(cx, slice)? {
547            Poll::Pending => return Poll::Pending,
548            Poll::Ready(n) => n,
549        };
550        buf.advance(n);
551        Poll::Ready(Ok(()))
552    }
553}
554
555impl<T: AsyncWrite> tokio::io::AsyncWrite for TokioAdapter<T> {
556    fn poll_write(
557        self: Pin<&mut Self>,
558        cx: &mut Context<'_>,
559        buf: &[u8],
560    ) -> Poll<Result<usize, std::io::Error>> {
561        self.project().inner.poll_write(cx, buf)
562    }
563
564    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
565        self.project().inner.poll_flush(cx)
566    }
567
568    fn poll_shutdown(
569        self: Pin<&mut Self>,
570        cx: &mut Context<'_>,
571    ) -> Poll<Result<(), std::io::Error>> {
572        self.project().inner.poll_close(cx)
573    }
574}