async_tungstenite/
lib.rs

//! Async WebSockets.
//!
//! This crate is based on the [tungstenite](https://crates.io/crates/tungstenite)
//! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s and couple it
//! together with other crates from the async stack. In addition, optional
//! integration with various other crates can be enabled via feature flags:
8//!
//!  * `async-tls`: Enables the `async_tls` module, which provides integration
//!    with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
//!    be used independently of any async runtime.
12//!  * `async-std-runtime`: Enables the `async_std` module, which provides
13//!    integration with the [async-std](https://async.rs) runtime.
14//!  * `async-native-tls`: Enables the additional functions in the `async_std`
15//!    module to implement TLS via
16//!    [async-native-tls](https://crates.io/crates/async-native-tls).
17//!  * `tokio-runtime`: Enables the `tokio` module, which provides integration
18//!    with the [tokio](https://tokio.rs) runtime.
19//!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20//!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21//!  * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23//!    and uses native system certificates found with
24//!    [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
25//!  * `tokio-rustls-platform-verifier`: Enables the additional functions in the `tokio`
26//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
27//!    and uses native system certificates via the platform verifier found with
28//!    [rustls-platform-verifier](https://github.com/rustls/rustls-platform-verifier).
29//!  * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
30//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
31//!    and uses the certificates [webpki-roots](https://github.com/rustls/webpki-roots)
32//!    provides.
33//!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
34//!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
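//!  * `gio-runtime`: Enables the `gio` module, which provides integration with
//!    the [gio](https://www.gtk-rs.org) runtime.
//!  * `smol-runtime`: Enables the `smol` module, which provides integration with
//!    the [smol](https://crates.io/crates/smol) runtime.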
37//!
//! Each WebSocket stream implements the `Stream` trait and, when the
//! `futures-03-sink` feature is enabled, the `Sink` trait, making the socket a
//! stream of WebSocket messages coming in and going out.
40
41#![deny(
42    missing_docs,
43    unused_must_use,
44    unused_mut,
45    unused_imports,
46    unused_import_braces
47)]
48
49pub use tungstenite;
50
51mod compat;
52mod handshake;
53
54#[cfg(any(
55    feature = "async-tls",
56    feature = "async-native-tls",
57    feature = "tokio-native-tls",
58    feature = "tokio-rustls-manual-roots",
59    feature = "tokio-rustls-native-certs",
60    feature = "tokio-rustls-platform-verifier",
61    feature = "tokio-rustls-webpki-roots",
62    feature = "tokio-openssl",
63))]
64pub mod stream;
65
66use std::{
67    io::{Read, Write},
68    pin::Pin,
69    sync::{Arc, Mutex, MutexGuard},
70    task::{ready, Context, Poll},
71};
72
73use compat::{cvt, AllowStd, ContextWaker};
74use futures_core::stream::{FusedStream, Stream};
75use futures_io::{AsyncRead, AsyncWrite};
76use log::*;
77
78#[cfg(feature = "handshake")]
79use tungstenite::{
80    client::IntoClientRequest,
81    handshake::{
82        client::{ClientHandshake, Response},
83        server::{Callback, NoCallback},
84        HandshakeError,
85    },
86};
87use tungstenite::{
88    error::Error as WsError,
89    protocol::{Message, Role, WebSocket, WebSocketConfig},
90};
91
92#[cfg(feature = "async-std-runtime")]
93pub mod async_std;
94#[cfg(feature = "async-tls")]
95pub mod async_tls;
96#[cfg(feature = "gio-runtime")]
97pub mod gio;
98#[cfg(feature = "smol-runtime")]
99pub mod smol;
100#[cfg(feature = "tokio-runtime")]
101pub mod tokio;
102
103pub mod bytes;
104pub use bytes::ByteReader;
105pub use bytes::ByteWriter;
106
107use tungstenite::protocol::CloseFrame;
108
/// Performs the client side of a WebSocket handshake over an already-connected stream.
///
/// `request` may be anything implementing `IntoClientRequest`: a URL string, a parsed
/// URL, or a full `Request`. Passing a `Request` lets the caller attach a WebSocket
/// subprotocol or any other custom headers.
///
/// The returned future drives the handshake and resolves either to the established
/// `WebSocketStream<S>` together with the server's handshake `Response`, or to an error
/// if the handshake fails.
///
/// This is typically used by clients that have already opened, for example, a TCP
/// connection to the remote server. It is equivalent to `client_async_with_config()`
/// with no configuration.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, default_config).await
}
132
/// Like `client_async()`, but also accepts a WebSocket configuration.
///
/// `config` is forwarded to tungstenite's `ClientHandshake::start`. See `client_async()`
/// for the meaning of the other parameters and of the result.
///
/// A handshake `Failure` is returned as the contained tungstenite error; any other
/// handshake error is converted into `WsError::Io` with `ErrorKind::Other` and the
/// error's textual description.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |wrapped| {
        let client_request = request.into_client_request()?;
        ClientHandshake::start(wrapped, client_request, config)?.handshake()
    })
    .await;

    outcome.map_err(|handshake_err| match handshake_err {
        HandshakeError::Failure(inner) => inner,
        other => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        )),
    })
}
158
/// Accepts an incoming WebSocket connection on an already-established stream.
///
/// This performs the server side of the handshake (through tungstenite's
/// `accept_hdr_with_config`, without a header callback and with no configuration). The
/// returned future resolves either to the established `WebSocketStream<S>` or to an
/// error if the handshake fails.
///
/// This is typically used right after a socket has been accepted from a `TcpListener`:
/// the socket is handed to this function to complete the server half of the client's
/// WebSocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, NoCallback, no_config).await
}
177
/// Like `accept_async()`, but also accepts a WebSocket configuration.
///
/// `config` is passed through to tungstenite unchanged; see `accept_async()` for the
/// remaining details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let header_callback = NoCallback;
    accept_hdr_async_with_config(stream, header_callback, config).await
}
190
/// Accepts an incoming WebSocket connection, with a callback for header processing.
///
/// Works like `accept_async()`, but `callback` is invoked with the incoming request's
/// headers and may add extra headers to the handshake reply.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(
    stream: S,
    callback: C,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let no_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, no_config).await
}
204
/// Like `accept_hdr_async()`, but also accepts a WebSocket configuration.
///
/// The handshake is delegated to `tungstenite::accept_hdr_with_config`. A handshake
/// `Failure` is returned as the contained tungstenite error; any other handshake error
/// is converted into `WsError::Io` with `ErrorKind::Other` and the error's textual
/// description.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |wrapped| {
        tungstenite::accept_hdr_with_config(wrapped, callback, config)
    })
    .await;

    match outcome {
        Ok(ws) => Ok(ws),
        Err(HandshakeError::Failure(inner)) => Err(inner),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
228
/// A wrapper around an underlying raw stream which implements the WebSocket
/// protocol.
///
/// A `WebSocketStream<S>` represents a handshake that has been completed
/// successfully; both the server and the client are ready to receive
/// and send data. Messages from a `WebSocketStream<S>` are accessible
/// through the respective `Stream` and `Sink` implementations. See the
/// `futures-rs` crate documentation for more information about them, or have a
/// look at the examples and unit tests for this crate.
///
/// The stream can be split into separate sender and receiver halves with
/// [`split`](WebSocketStream::split) and put back together with
/// [`reunite`](WebSocketStream::reunite).
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The tungstenite WebSocket, operating on the stream wrapped in `AllowStd`;
    /// `with_context` registers the current task's waker on that wrapper before
    /// each operation.
    inner: WebSocket<AllowStd<S>>,
    /// Set when `poll_close` has started closing but hit `WouldBlock`; later
    /// `poll_close` calls then only flush instead of starting the close again.
    #[cfg(feature = "futures-03-sink")]
    closing: bool,
    /// Set once `poll_next` has seen the connection close or a read error;
    /// afterwards `poll_next` always returns `None` (the stream is fused).
    ended: bool,
    /// Tungstenite is probably ready to receive more data.
    ///
    /// `false` once start_send hits `WouldBlock` errors.
    /// `true` initially and after `flush`ing.
    ready: bool,
}
250
251impl<S> WebSocketStream<S> {
252    /// Convert a raw socket into a WebSocketStream without performing a
253    /// handshake.
254    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
255    where
256        S: AsyncRead + AsyncWrite + Unpin,
257    {
258        handshake::without_handshake(stream, move |allow_std| {
259            WebSocket::from_raw_socket(allow_std, role, config)
260        })
261        .await
262    }
263
264    /// Convert a raw socket into a WebSocketStream without performing a
265    /// handshake.
266    pub async fn from_partially_read(
267        stream: S,
268        part: Vec<u8>,
269        role: Role,
270        config: Option<WebSocketConfig>,
271    ) -> Self
272    where
273        S: AsyncRead + AsyncWrite + Unpin,
274    {
275        handshake::without_handshake(stream, move |allow_std| {
276            WebSocket::from_partially_read(allow_std, part, role, config)
277        })
278        .await
279    }
280
281    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
282        Self {
283            inner: ws,
284            #[cfg(feature = "futures-03-sink")]
285            closing: false,
286            ended: false,
287            ready: true,
288        }
289    }
290
291    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
292    where
293        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
294        AllowStd<S>: Read + Write,
295    {
296        #[cfg(feature = "verbose-logging")]
297        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
298        if let Some((kind, ctx)) = ctx {
299            self.inner.get_mut().set_waker(kind, ctx.waker());
300        }
301        f(&mut self.inner)
302    }
303
304    /// Consumes the `WebSocketStream` and returns the underlying stream.
305    pub fn into_inner(self) -> S {
306        self.inner.into_inner().into_inner()
307    }
308
309    /// Returns a shared reference to the inner stream.
310    pub fn get_ref(&self) -> &S
311    where
312        S: AsyncRead + AsyncWrite + Unpin,
313    {
314        self.inner.get_ref().get_ref()
315    }
316
317    /// Returns a mutable reference to the inner stream.
318    pub fn get_mut(&mut self) -> &mut S
319    where
320        S: AsyncRead + AsyncWrite + Unpin,
321    {
322        self.inner.get_mut().get_mut()
323    }
324
325    /// Returns a reference to the configuration of the tungstenite stream.
326    pub fn get_config(&self) -> &WebSocketConfig {
327        self.inner.get_config()
328    }
329
330    /// Close the underlying web socket
331    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
332    where
333        S: AsyncRead + AsyncWrite + Unpin,
334    {
335        self.send(Message::Close(msg)).await
336    }
337
338    /// Splits the websocket stream into separate
339    /// [sender](WebSocketSender) and [receiver](WebSocketReceiver) parts.
340    pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
341        let shared = Arc::new(Shared(Mutex::new(self)));
342        let sender = WebSocketSender {
343            shared: shared.clone(),
344        };
345
346        let receiver = WebSocketReceiver { shared };
347        (sender, receiver)
348    }
349
350    /// Attempts to reunite the [sender](WebSocketSender) and [receiver](WebSocketReceiver)
351    /// parts back into a single stream. If both parts originate from the same
352    /// [`split`](WebSocketStream::split) call, returns `Ok` with the original stream.
353    /// Otherwise, returns `Err` containing the provided parts.
354    pub fn reunite(
355        sender: WebSocketSender<S>,
356        receiver: WebSocketReceiver<S>,
357    ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
358        if sender.is_pair_of(&receiver) {
359            drop(receiver);
360            let stream = Arc::try_unwrap(sender.shared)
361                .ok()
362                .expect("reunite the stream")
363                .into_inner();
364
365            Ok(stream)
366        } else {
367            Err((sender, receiver))
368        }
369    }
370}
371
372impl<S> WebSocketStream<S>
373where
374    S: AsyncRead + AsyncWrite + Unpin,
375{
376    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
377        #[cfg(feature = "verbose-logging")]
378        trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
379
380        // The connection has been closed or a critical error has occurred.
381        // We have already returned the error to the user, the `Stream` is unusable,
382        // so we assume that the stream has been "fused".
383        if self.ended {
384            return Poll::Ready(None);
385        }
386
387        match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
388            #[cfg(feature = "verbose-logging")]
389            trace!(
390                "{}:{} WebSocketStream.with_context poll_next -> read()",
391                file!(),
392                line!()
393            );
394            cvt(s.read())
395        })) {
396            Ok(v) => Poll::Ready(Some(Ok(v))),
397            Err(e) => {
398                self.ended = true;
399                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
400                    Poll::Ready(None)
401                } else {
402                    Poll::Ready(Some(Err(e)))
403                }
404            }
405        }
406    }
407
408    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
409        if self.ready {
410            return Poll::Ready(Ok(()));
411        }
412
413        // Currently blocked so try to flush the blockage away
414        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
415            .map(|r| {
416                self.ready = true;
417                r
418            })
419    }
420
421    fn start_send(&mut self, item: Message) -> Result<(), WsError> {
422        match self.with_context(None, |s| s.write(item)) {
423            Ok(()) => {
424                self.ready = true;
425                Ok(())
426            }
427            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
428                // the message was accepted and queued so not an error
429                // but `poll_ready` will now start trying to flush the block
430                self.ready = false;
431                Ok(())
432            }
433            Err(e) => {
434                self.ready = true;
435                debug!("websocket start_send error: {}", e);
436                Err(e)
437            }
438        }
439    }
440
441    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
442        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
443            .map(|r| {
444                self.ready = true;
445                match r {
446                    // WebSocket connection has just been closed. Flushing completed, not an error.
447                    Err(WsError::ConnectionClosed) => Ok(()),
448                    other => other,
449                }
450            })
451    }
452
453    #[cfg(feature = "futures-03-sink")]
454    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
455        self.ready = true;
456        let res = if self.closing {
457            // After queueing it, we call `flush` to drive the close handshake to completion.
458            self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
459        } else {
460            self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
461        };
462
463        match res {
464            Ok(()) => Poll::Ready(Ok(())),
465            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
466            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
467                trace!("WouldBlock");
468                self.closing = true;
469                Poll::Pending
470            }
471            Err(err) => {
472                debug!("websocket close error: {}", err);
473                Poll::Ready(Err(err))
474            }
475        }
476    }
477}
478
479impl<S> Stream for WebSocketStream<S>
480where
481    S: AsyncRead + AsyncWrite + Unpin,
482{
483    type Item = Result<Message, WsError>;
484
485    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
486        self.get_mut().poll_next(cx)
487    }
488}
489
490impl<S> FusedStream for WebSocketStream<S>
491where
492    S: AsyncRead + AsyncWrite + Unpin,
493{
494    fn is_terminated(&self) -> bool {
495        self.ended
496    }
497}
498
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Every method unpins `self` and forwards to the inherent method of the same name.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        Pin::into_inner(self).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_close(cx)
    }
}
522
523#[cfg(not(feature = "futures-03-sink"))]
524impl<S> bytes::private::SealedSender for WebSocketStream<S>
525where
526    S: AsyncRead + AsyncWrite + Unpin,
527{
528    fn poll_write(
529        self: Pin<&mut Self>,
530        cx: &mut Context<'_>,
531        buf: &[u8],
532    ) -> Poll<Result<usize, WsError>> {
533        let me = self.get_mut();
534        ready!(me.poll_ready(cx))?;
535        let len = buf.len();
536        me.start_send(Message::binary(buf.to_owned()))?;
537        Poll::Ready(Ok(len))
538    }
539
540    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
541        self.get_mut().poll_flush(cx)
542    }
543
544    fn poll_close(
545        self: Pin<&mut Self>,
546        cx: &mut Context<'_>,
547        msg: &mut Option<Message>,
548    ) -> Poll<Result<(), WsError>> {
549        let me = self.get_mut();
550        send_helper(me, msg, cx)
551    }
552}
553
554impl<S> WebSocketStream<S> {
555    /// Simple send method to replace `futures_sink::Sink` (till v0.3).
556    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
557    where
558        S: AsyncRead + AsyncWrite + Unpin,
559    {
560        Send {
561            ws: self,
562            msg: Some(msg),
563        }
564        .await
565    }
566}
567
/// Future that queues one message on a WebSocket and then flushes it.
///
/// `W` is the handle used to reach the stream: either `&mut WebSocketStream<S>` or
/// `&Shared<S>` for a split sender (see the two `Future` impls for `Send`).
///
/// NOTE: within this module the name shadows the prelude `Send` marker trait.
struct Send<W> {
    // Handle through which the `WebSocketStream` is accessed.
    ws: W,
    // The message still waiting to be queued; `None` once it has been handed to
    // `start_send`, after which polling only drives the flush.
    msg: Option<Message>,
}
572
573/// Performs an asynchronous message send to the websocket.
574fn send_helper<S>(
575    ws: &mut WebSocketStream<S>,
576    msg: &mut Option<Message>,
577    cx: &mut Context<'_>,
578) -> Poll<Result<(), WsError>>
579where
580    S: AsyncRead + AsyncWrite + Unpin,
581{
582    if msg.is_some() {
583        ready!(ws.poll_ready(cx))?;
584        let msg = msg.take().expect("unreachable");
585        ws.start_send(msg)?;
586    }
587
588    ws.poll_flush(cx)
589}
590
591impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
592where
593    S: AsyncRead + AsyncWrite + Unpin,
594{
595    type Output = Result<(), WsError>;
596
597    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
598        let me = self.get_mut();
599        send_helper(me.ws, &mut me.msg, cx)
600    }
601}
602
603impl<S> std::future::Future for Send<&Shared<S>>
604where
605    S: AsyncRead + AsyncWrite + Unpin,
606{
607    type Output = Result<(), WsError>;
608
609    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
610        let me = self.get_mut();
611        let mut ws = me.ws.lock();
612        send_helper(&mut ws, &mut me.msg, cx)
613    }
614}
615
/// The sender part of a [websocket](WebSocketStream) stream.
///
/// Created by [`WebSocketStream::split`]. It shares the underlying stream with the
/// matching [`WebSocketReceiver`] behind a mutex, and the two can be recombined with
/// [`WebSocketStream::reunite`].
#[derive(Debug)]
pub struct WebSocketSender<S> {
    // Stream shared with the paired receiver; each operation locks it.
    shared: Arc<Shared<S>>,
}
621
622impl<S> WebSocketSender<S> {
623    /// Send a message via [websocket](WebSocketStream).
624    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
625    where
626        S: AsyncRead + AsyncWrite + Unpin,
627    {
628        Send {
629            ws: &*self.shared,
630            msg: Some(msg),
631        }
632        .await
633    }
634
635    /// Close the underlying [websocket](WebSocketStream).
636    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
637    where
638        S: AsyncRead + AsyncWrite + Unpin,
639    {
640        self.send(Message::Close(msg)).await
641    }
642
643    /// Checks if this [sender](WebSocketSender) and some [receiver](WebSocketReceiver)
644    /// were split from the same [websocket](WebSocketStream) stream.
645    pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
646        Arc::ptr_eq(&self.shared, &other.shared)
647    }
648}
649
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Every method locks the shared stream for the duration of the call and forwards
    // to the stream's inherent method of the same name.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
673
674#[cfg(not(feature = "futures-03-sink"))]
675impl<S> bytes::private::SealedSender for WebSocketSender<S>
676where
677    S: AsyncRead + AsyncWrite + Unpin,
678{
679    fn poll_write(
680        self: Pin<&mut Self>,
681        cx: &mut Context<'_>,
682        buf: &[u8],
683    ) -> Poll<Result<usize, WsError>> {
684        let me = self.get_mut();
685        let mut ws = me.shared.lock();
686        ready!(ws.poll_ready(cx))?;
687        let len = buf.len();
688        ws.start_send(Message::binary(buf.to_owned()))?;
689        Poll::Ready(Ok(len))
690    }
691
692    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
693        self.shared.lock().poll_flush(cx)
694    }
695
696    fn poll_close(
697        self: Pin<&mut Self>,
698        cx: &mut Context<'_>,
699        msg: &mut Option<Message>,
700    ) -> Poll<Result<(), WsError>> {
701        let me = self.get_mut();
702        let mut ws = me.shared.lock();
703        send_helper(&mut ws, msg, cx)
704    }
705}
706
/// The receiver part of a [websocket](WebSocketStream) stream.
///
/// Created by [`WebSocketStream::split`]. It shares the underlying stream with the
/// matching [`WebSocketSender`] behind a mutex, and the two can be recombined with
/// [`WebSocketStream::reunite`].
#[derive(Debug)]
pub struct WebSocketReceiver<S> {
    // Stream shared with the paired sender; each poll locks it.
    shared: Arc<Shared<S>>,
}
712
713impl<S> WebSocketReceiver<S> {
714    /// Checks if this [receiver](WebSocketReceiver) and some [sender](WebSocketSender)
715    /// were split from the same [websocket](WebSocketStream) stream.
716    pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
717        Arc::ptr_eq(&self.shared, &other.shared)
718    }
719}
720
721impl<S> Stream for WebSocketReceiver<S>
722where
723    S: AsyncRead + AsyncWrite + Unpin,
724{
725    type Item = Result<Message, WsError>;
726
727    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
728        self.shared.lock().poll_next(cx)
729    }
730}
731
732impl<S> FusedStream for WebSocketReceiver<S>
733where
734    S: AsyncRead + AsyncWrite + Unpin,
735{
736    fn is_terminated(&self) -> bool {
737        self.shared.lock().ended
738    }
739}
740
/// A `WebSocketStream` behind a mutex, shared by a split sender/receiver pair
/// (see `WebSocketStream::split`).
#[derive(Debug)]
struct Shared<S>(Mutex<WebSocketStream<S>>);
743
744impl<S> Shared<S> {
745    fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
746        self.0.lock().expect("lock shared stream")
747    }
748
749    fn into_inner(self) -> WebSocketStream<S> {
750        self.0.into_inner().expect("get shared stream")
751    }
752}
753
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get a domain (host) from a request's URL.
///
/// Square brackets around an IPv6 literal (as in `ws://[::1]:80`) are removed, so the
/// result is the bare address (`::1`). Any other host is returned unchanged.
///
/// # Errors
///
/// Returns `UrlError::NoHostName` if the request URI has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or_else(|| {
        tungstenite::Error::Url(tungstenite::error::UrlError::NoHostName)
    })?;

    // If host is an IPv6 address, it might be surrounded by brackets. These brackets are
    // *not* part of a valid IP, so they must be stripped out. Stripping the prefix and
    // suffix explicitly (instead of slicing by index) cannot panic, even if the URI
    // validity guarantee were ever violated, and leaves other hosts untouched.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
787
#[cfg(any(
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the port from a request's URL.
///
/// An explicit port in the URI wins; otherwise the scheme's default is used
/// (443 for `wss`, 80 for `ws`).
///
/// # Errors
///
/// Returns `UrlError::UnsupportedUrlScheme` when no port is given and the scheme is
/// neither `ws` nor `wss`.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
811
#[cfg(test)]
mod tests {
    /// IPv6 literals must be returned without their surrounding brackets.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        let host = crate::domain(&request).unwrap();
        assert_eq!(host, "::1");
    }

    /// Malformed bracketed hosts are rejected before a request is ever built.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        for malformed in ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(malformed.into_client_request().is_err());
        }
    }
}