Skip to main content

async_tungstenite/
lib.rs

1//! Async WebSockets.
2//!
//! This crate is based on the [tungstenite](https://crates.io/crates/tungstenite)
//! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s and combine it with
//! other crates from the async stack. In addition, optional integration with
//! various other crates can be enabled via feature flags:
8//!
9//!  * `async-tls`: Enables the `async_tls` module, which provides integration
10//!    with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
11//!    be used independent of any async runtime.
//!  * `async-std-runtime`: Enables the `async_std` module, which provides
//!    integration with the [async-std](https://async.rs) runtime. This module is
//!    deprecated because async-std is unmaintained upstream; prefer `smol-runtime`.
14//!  * `async-native-tls`: Enables the additional functions in the `async_std`
15//!    module to implement TLS via
16//!    [async-native-tls](https://crates.io/crates/async-native-tls).
17//!  * `tokio-runtime`: Enables the `tokio` module, which provides integration
18//!    with the [tokio](https://tokio.rs) runtime.
19//!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20//!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21//!  * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23//!    and uses native system certificates found with
24//!    [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
25//!  * `tokio-rustls-platform-verifier`: Enables the additional functions in the `tokio`
26//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
27//!    and uses native system certificates via the platform verifier found with
28//!    [rustls-platform-verifier](https://github.com/rustls/rustls-platform-verifier).
29//!  * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
30//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
31//!    and uses the certificates [webpki-roots](https://github.com/rustls/webpki-roots)
32//!    provides.
33//!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
34//!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
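//!  * `gio-runtime`: Enables the `gio` module, which provides integration with
//!    the [gio](https://www.gtk-rs.org) runtime.
//!  * `smol-runtime`: Enables the `smol` module, which provides integration with
//!    the [smol](https://crates.io/crates/smol) runtime.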
37//!
//! Each `WebSocketStream` implements `Stream` (and, when the `futures-03-sink`
//! feature is enabled, `Sink`), making the socket a stream of WebSocket messages
//! coming in and going out.
40
41#![deny(
42    missing_docs,
43    unused_must_use,
44    unused_mut,
45    unused_imports,
46    unused_import_braces
47)]
48
49pub use tungstenite;
50
51mod compat;
52mod handshake;
53
54#[cfg(any(
55    feature = "async-tls",
56    feature = "async-native-tls",
57    feature = "smol-native-tls",
58    feature = "futures-rustls-manual-roots",
59    feature = "futures-rustls-webpki-roots",
60    feature = "futures-rustls-native-certs",
61    feature = "futures-rustls-platform-verifier",
62    feature = "tokio-native-tls",
63    feature = "tokio-rustls-manual-roots",
64    feature = "tokio-rustls-native-certs",
65    feature = "tokio-rustls-platform-verifier",
66    feature = "tokio-rustls-webpki-roots",
67    feature = "tokio-openssl",
68))]
69pub mod stream;
70
71use std::{
72    io::{Read, Write},
73    pin::Pin,
74    sync::{Arc, Mutex, MutexGuard},
75    task::{ready, Context, Poll},
76};
77
78use compat::{cvt, AllowStd, ContextWaker};
79use futures_core::stream::{FusedStream, Stream};
80use futures_io::{AsyncRead, AsyncWrite};
81use log::*;
82
83#[cfg(feature = "handshake")]
84use tungstenite::{
85    client::IntoClientRequest,
86    handshake::{
87        client::{ClientHandshake, Response},
88        server::{Callback, NoCallback},
89        HandshakeError,
90    },
91};
92use tungstenite::{
93    error::Error as WsError,
94    protocol::{Message, Role, WebSocket, WebSocketConfig},
95};
96
97#[cfg(feature = "async-std-runtime")]
98#[deprecated = "async-std is unmaintained upstream. Please use the smol runtime instead."]
99pub mod async_std;
100#[cfg(feature = "async-tls")]
101pub mod async_tls;
102#[cfg(feature = "gio-runtime")]
103pub mod gio;
104#[cfg(feature = "smol-runtime")]
105pub mod smol;
106#[cfg(feature = "tokio-runtime")]
107pub mod tokio;
108
109pub mod bytes;
110pub use bytes::ByteReader;
111pub use bytes::ByteWriter;
112
113use tungstenite::protocol::CloseFrame;
114
/// Performs the client side of a WebSocket handshake over an already-connected
/// `stream`, using the default WebSocket configuration.
///
/// `request` can be anything implementing `IntoClientRequest`, such as a URL
/// string, a parsed URL, or a full `Request`. Passing a `Request` lets the caller
/// add a WebSocket sub-protocol or other custom headers.
///
/// The returned future resolves to the established `WebSocketStream<S>` together
/// with the server's handshake `Response`, or to an error if the handshake fails.
/// This is `client_async_with_config` with no explicit configuration.
///
/// Typically used by clients that have already established, for example, a TCP
/// connection to the remote server.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, default_config).await
}
138
/// Like `client_async()`, but lets the caller supply a WebSocket configuration.
/// See `client_async()` for the remaining details.
///
/// A handshake `Failure` is returned as its inner error. Any other handshake error
/// is converted into an `Io` error of kind `Other` that carries its message.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let handshake_future = handshake::client_handshake(stream, move |allow_std| {
        let client_request = request.into_client_request()?;
        ClientHandshake::start(allow_std, client_request, config)?.handshake()
    });

    match handshake_future.await {
        Ok(established) => Ok(established),
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
164
/// Accepts an incoming WebSocket connection on `stream`, acting as the server.
///
/// The server handshake runs with no header callback (`NoCallback`) and the
/// default configuration. The returned future resolves to the established
/// `WebSocketStream<S>`, or to an error if the handshake fails.
///
/// Typically used right after a socket has been accepted from a `TcpListener`.
/// Passing that socket here completes the server half of the client's WebSocket
/// connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
183
/// Like `accept_async()`, but lets the caller supply a WebSocket configuration.
/// See `accept_async()` for the remaining details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_header_callback = NoCallback;
    accept_hdr_async_with_config(stream, no_header_callback, config).await
}
196
/// Accepts an incoming WebSocket connection on `stream`, with a header callback.
///
/// Behaves like `accept_async()`, but `callback` is given the headers of the
/// incoming request and may add extra headers to the reply. The default
/// configuration is used.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let default_config = None;
    accept_hdr_async_with_config(stream, callback, default_config).await
}
210
/// Like `accept_hdr_async()`, but lets the caller supply a WebSocket configuration.
/// See `accept_hdr_async()` for the remaining details.
///
/// A handshake `Failure` is returned as its inner error. Any other handshake error
/// is converted into an `Io` error of kind `Other` that carries its message.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let handshake_future = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });

    match handshake_future.await {
        Ok(established) => Ok(established),
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
234
/// A wrapper around an underlying raw stream which implements the WebSocket
/// protocol.
///
/// A `WebSocketStream<S>` represents a handshake that has been completed
/// successfully and both the server and the client are ready for receiving
/// and sending data. Messages from a `WebSocketStream<S>` are accessible
/// through the respective `Stream` and `Sink` implementations (`Sink` only with
/// the `futures-03-sink` feature). Check more information about them in the
/// `futures-rs` crate documentation or have a look at the examples and unit
/// tests for this crate.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The tungstenite WebSocket, driving `S` through the `AllowStd` wrapper from
    /// the `compat` module.
    inner: WebSocket<AllowStd<S>>,
    /// Set by `poll_close` when starting the close hits `WouldBlock`; later
    /// `poll_close` calls then only flush instead of calling `close` again.
    #[cfg(feature = "futures-03-sink")]
    closing: bool,
    /// Set once `poll_next` has seen a read error (including the connection being
    /// closed); from then on the stream yields `None` and reports itself terminated.
    ended: bool,
    /// Tungstenite is probably ready to receive more data.
    ///
    /// `false` once start_send hits `WouldBlock` errors.
    /// `true` initially and after `flush`ing.
    ready: bool,
}
256
257impl<S> WebSocketStream<S> {
258    /// Convert a raw socket into a WebSocketStream without performing a
259    /// handshake.
260    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
261    where
262        S: AsyncRead + AsyncWrite + Unpin,
263    {
264        handshake::without_handshake(stream, move |allow_std| {
265            WebSocket::from_raw_socket(allow_std, role, config)
266        })
267        .await
268    }
269
270    /// Convert a raw socket into a WebSocketStream without performing a
271    /// handshake.
272    pub async fn from_partially_read(
273        stream: S,
274        part: Vec<u8>,
275        role: Role,
276        config: Option<WebSocketConfig>,
277    ) -> Self
278    where
279        S: AsyncRead + AsyncWrite + Unpin,
280    {
281        handshake::without_handshake(stream, move |allow_std| {
282            WebSocket::from_partially_read(allow_std, part, role, config)
283        })
284        .await
285    }
286
287    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
288        Self {
289            inner: ws,
290            #[cfg(feature = "futures-03-sink")]
291            closing: false,
292            ended: false,
293            ready: true,
294        }
295    }
296
297    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
298    where
299        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
300        AllowStd<S>: Read + Write,
301    {
302        #[cfg(feature = "verbose-logging")]
303        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
304        if let Some((kind, ctx)) = ctx {
305            self.inner.get_mut().set_waker(kind, ctx.waker());
306        }
307        f(&mut self.inner)
308    }
309
310    /// Consumes the `WebSocketStream` and returns the underlying stream.
311    pub fn into_inner(self) -> S {
312        self.inner.into_inner().into_inner()
313    }
314
315    /// Returns a shared reference to the inner stream.
316    pub fn get_ref(&self) -> &S
317    where
318        S: AsyncRead + AsyncWrite + Unpin,
319    {
320        self.inner.get_ref().get_ref()
321    }
322
323    /// Returns a mutable reference to the inner stream.
324    pub fn get_mut(&mut self) -> &mut S
325    where
326        S: AsyncRead + AsyncWrite + Unpin,
327    {
328        self.inner.get_mut().get_mut()
329    }
330
331    /// Returns a reference to the configuration of the tungstenite stream.
332    pub fn get_config(&self) -> &WebSocketConfig {
333        self.inner.get_config()
334    }
335
336    /// Close the underlying web socket
337    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
338    where
339        S: AsyncRead + AsyncWrite + Unpin,
340    {
341        self.send(Message::Close(msg)).await
342    }
343
344    /// Splits the websocket stream into separate
345    /// [sender](WebSocketSender) and [receiver](WebSocketReceiver) parts.
346    pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
347        let shared = Arc::new(Shared(Mutex::new(self)));
348        let sender = WebSocketSender {
349            shared: shared.clone(),
350        };
351
352        let receiver = WebSocketReceiver { shared };
353        (sender, receiver)
354    }
355
356    /// Attempts to reunite the [sender](WebSocketSender) and [receiver](WebSocketReceiver)
357    /// parts back into a single stream. If both parts originate from the same
358    /// [`split`](WebSocketStream::split) call, returns `Ok` with the original stream.
359    /// Otherwise, returns `Err` containing the provided parts.
360    pub fn reunite(
361        sender: WebSocketSender<S>,
362        receiver: WebSocketReceiver<S>,
363    ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
364        if sender.is_pair_of(&receiver) {
365            drop(receiver);
366            let stream = Arc::try_unwrap(sender.shared)
367                .ok()
368                .expect("reunite the stream")
369                .into_inner();
370
371            Ok(stream)
372        } else {
373            Err((sender, receiver))
374        }
375    }
376}
377
378impl<S> WebSocketStream<S>
379where
380    S: AsyncRead + AsyncWrite + Unpin,
381{
382    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
383        #[cfg(feature = "verbose-logging")]
384        trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
385
386        // The connection has been closed or a critical error has occurred.
387        // We have already returned the error to the user, the `Stream` is unusable,
388        // so we assume that the stream has been "fused".
389        if self.ended {
390            return Poll::Ready(None);
391        }
392
393        match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
394            #[cfg(feature = "verbose-logging")]
395            trace!(
396                "{}:{} WebSocketStream.with_context poll_next -> read()",
397                file!(),
398                line!()
399            );
400            cvt(s.read())
401        })) {
402            Ok(v) => Poll::Ready(Some(Ok(v))),
403            Err(e) => {
404                self.ended = true;
405                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
406                    Poll::Ready(None)
407                } else {
408                    Poll::Ready(Some(Err(e)))
409                }
410            }
411        }
412    }
413
414    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
415        if self.ready {
416            return Poll::Ready(Ok(()));
417        }
418
419        // Currently blocked so try to flush the blockage away
420        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
421            .map(|r| {
422                self.ready = true;
423                r
424            })
425    }
426
427    fn start_send(&mut self, item: Message) -> Result<(), WsError> {
428        match self.with_context(None, |s| s.write(item)) {
429            Ok(()) => {
430                self.ready = true;
431                Ok(())
432            }
433            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
434                // the message was accepted and queued so not an error
435                // but `poll_ready` will now start trying to flush the block
436                self.ready = false;
437                Ok(())
438            }
439            Err(e) => {
440                self.ready = true;
441                debug!("websocket start_send error: {}", e);
442                Err(e)
443            }
444        }
445    }
446
447    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
448        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
449            .map(|r| {
450                self.ready = true;
451                match r {
452                    // WebSocket connection has just been closed. Flushing completed, not an error.
453                    Err(WsError::ConnectionClosed) => Ok(()),
454                    other => other,
455                }
456            })
457    }
458
459    #[cfg(feature = "futures-03-sink")]
460    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
461        self.ready = true;
462        let res = if self.closing {
463            // After queueing it, we call `flush` to drive the close handshake to completion.
464            self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
465        } else {
466            self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
467        };
468
469        match res {
470            Ok(()) => Poll::Ready(Ok(())),
471            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
472            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
473                trace!("WouldBlock");
474                self.closing = true;
475                Poll::Pending
476            }
477            Err(err) => {
478                debug!("websocket close error: {}", err);
479                Poll::Ready(Err(err))
480            }
481        }
482    }
483}
484
485impl<S> Stream for WebSocketStream<S>
486where
487    S: AsyncRead + AsyncWrite + Unpin,
488{
489    type Item = Result<Message, WsError>;
490
491    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
492        self.get_mut().poll_next(cx)
493    }
494}
495
496impl<S> FusedStream for WebSocketStream<S>
497where
498    S: AsyncRead + AsyncWrite + Unpin,
499{
500    fn is_terminated(&self) -> bool {
501        self.ended
502    }
503}
504
// Every method forwards to the inherent implementation of the same name.
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        Pin::into_inner(self).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_close(cx)
    }
}
528
529#[cfg(not(feature = "futures-03-sink"))]
530impl<S> bytes::private::SealedSender for WebSocketStream<S>
531where
532    S: AsyncRead + AsyncWrite + Unpin,
533{
534    fn poll_write(
535        self: Pin<&mut Self>,
536        cx: &mut Context<'_>,
537        buf: &[u8],
538    ) -> Poll<Result<usize, WsError>> {
539        let me = self.get_mut();
540        ready!(me.poll_ready(cx))?;
541        let len = buf.len();
542        me.start_send(Message::binary(buf.to_owned()))?;
543        Poll::Ready(Ok(len))
544    }
545
546    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
547        self.get_mut().poll_flush(cx)
548    }
549
550    fn poll_close(
551        self: Pin<&mut Self>,
552        cx: &mut Context<'_>,
553        msg: &mut Option<Message>,
554    ) -> Poll<Result<(), WsError>> {
555        let me = self.get_mut();
556        send_helper(me, msg, cx)
557    }
558}
559
560impl<S> WebSocketStream<S> {
561    /// Simple send method to replace `futures_sink::Sink` (till v0.3).
562    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
563    where
564        S: AsyncRead + AsyncWrite + Unpin,
565    {
566        Send {
567            ws: self,
568            msg: Some(msg),
569        }
570        .await
571    }
572}
573
/// Future that sends one message through `ws` and then flushes.
///
/// `ws` is either `&mut WebSocketStream<S>` or `&Shared<S>`, matching the two
/// `Future` implementations below. `msg` stays `Some` until the message has been
/// handed to `start_send`; after that only flushing remains.
struct Send<W> {
    ws: W,
    msg: Option<Message>,
}
578
579/// Performs an asynchronous message send to the websocket.
580fn send_helper<S>(
581    ws: &mut WebSocketStream<S>,
582    msg: &mut Option<Message>,
583    cx: &mut Context<'_>,
584) -> Poll<Result<(), WsError>>
585where
586    S: AsyncRead + AsyncWrite + Unpin,
587{
588    if msg.is_some() {
589        ready!(ws.poll_ready(cx))?;
590        let msg = msg.take().expect("unreachable");
591        ws.start_send(msg)?;
592    }
593
594    ws.poll_flush(cx)
595}
596
597impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
598where
599    S: AsyncRead + AsyncWrite + Unpin,
600{
601    type Output = Result<(), WsError>;
602
603    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
604        let me = self.get_mut();
605        send_helper(me.ws, &mut me.msg, cx)
606    }
607}
608
609impl<S> std::future::Future for Send<&Shared<S>>
610where
611    S: AsyncRead + AsyncWrite + Unpin,
612{
613    type Output = Result<(), WsError>;
614
615    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
616        let me = self.get_mut();
617        let mut ws = me.ws.lock();
618        send_helper(&mut ws, &mut me.msg, cx)
619    }
620}
621
/// The sender part of a [websocket](WebSocketStream) stream.
///
/// Created by [`WebSocketStream::split`]. It shares the underlying stream with its
/// paired [`WebSocketReceiver`] through a mutex.
#[derive(Debug)]
pub struct WebSocketSender<S> {
    /// Shared with the paired receiver; `Arc::ptr_eq` on this identifies the pair.
    shared: Arc<Shared<S>>,
}
627
628impl<S> WebSocketSender<S> {
629    /// Send a message via [websocket](WebSocketStream).
630    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
631    where
632        S: AsyncRead + AsyncWrite + Unpin,
633    {
634        Send {
635            ws: &*self.shared,
636            msg: Some(msg),
637        }
638        .await
639    }
640
641    /// Close the underlying [websocket](WebSocketStream).
642    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
643    where
644        S: AsyncRead + AsyncWrite + Unpin,
645    {
646        self.send(Message::Close(msg)).await
647    }
648
649    /// Checks if this [sender](WebSocketSender) and some [receiver](WebSocketReceiver)
650    /// were split from the same [websocket](WebSocketStream) stream.
651    pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
652        Arc::ptr_eq(&self.shared, &other.shared)
653    }
654}
655
// Each method locks the shared stream for the duration of the call and
// forwards to the stream's inherent implementation.
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
679
680#[cfg(not(feature = "futures-03-sink"))]
681impl<S> bytes::private::SealedSender for WebSocketSender<S>
682where
683    S: AsyncRead + AsyncWrite + Unpin,
684{
685    fn poll_write(
686        self: Pin<&mut Self>,
687        cx: &mut Context<'_>,
688        buf: &[u8],
689    ) -> Poll<Result<usize, WsError>> {
690        let me = self.get_mut();
691        let mut ws = me.shared.lock();
692        ready!(ws.poll_ready(cx))?;
693        let len = buf.len();
694        ws.start_send(Message::binary(buf.to_owned()))?;
695        Poll::Ready(Ok(len))
696    }
697
698    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
699        self.shared.lock().poll_flush(cx)
700    }
701
702    fn poll_close(
703        self: Pin<&mut Self>,
704        cx: &mut Context<'_>,
705        msg: &mut Option<Message>,
706    ) -> Poll<Result<(), WsError>> {
707        let me = self.get_mut();
708        let mut ws = me.shared.lock();
709        send_helper(&mut ws, msg, cx)
710    }
711}
712
/// The receiver part of a [websocket](WebSocketStream) stream.
///
/// Created by [`WebSocketStream::split`]. It shares the underlying stream with its
/// paired [`WebSocketSender`] through a mutex.
#[derive(Debug)]
pub struct WebSocketReceiver<S> {
    /// Shared with the paired sender; `Arc::ptr_eq` on this identifies the pair.
    shared: Arc<Shared<S>>,
}
718
719impl<S> WebSocketReceiver<S> {
720    /// Checks if this [receiver](WebSocketReceiver) and some [sender](WebSocketSender)
721    /// were split from the same [websocket](WebSocketStream) stream.
722    pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
723        Arc::ptr_eq(&self.shared, &other.shared)
724    }
725}
726
727impl<S> Stream for WebSocketReceiver<S>
728where
729    S: AsyncRead + AsyncWrite + Unpin,
730{
731    type Item = Result<Message, WsError>;
732
733    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
734        self.shared.lock().poll_next(cx)
735    }
736}
737
738impl<S> FusedStream for WebSocketReceiver<S>
739where
740    S: AsyncRead + AsyncWrite + Unpin,
741{
742    fn is_terminated(&self) -> bool {
743        self.shared.lock().ended
744    }
745}
746
/// The `WebSocketStream` behind a mutex, shared by the sender/receiver pair
/// produced by `WebSocketStream::split`.
#[derive(Debug)]
struct Shared<S>(Mutex<WebSocketStream<S>>);
749
750impl<S> Shared<S> {
751    fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
752        self.0.lock().expect("lock shared stream")
753    }
754
755    fn into_inner(self) -> WebSocketStream<S> {
756        self.0.into_inner().expect("get shared stream")
757    }
758}
759
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get a domain from an URL.
///
/// Returns the host component of `request`'s URI as an owned string. For a
/// bracketed IPv6 literal, the surrounding `[` and `]` are removed (e.g.
/// `ws://[::1]:80` yields `"::1"`), because the brackets are URI syntax and not
/// part of the address.
///
/// # Errors
/// Returns `UrlError::NoHostName` when the URI has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Strip the brackets only when both are actually present. This removes the
    // panicking slice path and never drops a character that is not a bracket.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
793
#[cfg(any(
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the port from an URL.
///
/// Uses the explicit port in `request`'s URI when there is one. Otherwise it falls
/// back to the scheme default: 443 for `wss`, 80 for `ws`.
///
/// # Errors
/// Returns `UrlError::UnsupportedUrlScheme` when there is no explicit port and the
/// scheme is neither `ws` nor `wss`.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
817
#[cfg(test)]
mod tests {
    // Imports stay inside each test so feature-less test builds do not trip
    // `deny(unused_imports)`.

    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80"
            .into_client_request()
            .expect("valid IPv6 WebSocket URL");
        let host = crate::domain(&request).expect("URL has a host");
        assert_eq!(host, "::1");
    }

    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        for bad in ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(
                bad.into_client_request().is_err(),
                "{} should be rejected",
                bad
            );
        }
    }
}