async_tungstenite/
lib.rs

1//! Async WebSockets.
2//!
//! This crate is based on the [tungstenite](https://crates.io/crates/tungstenite)
//! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s and couple it
//! together with other crates from the async stack. In addition, optional
//! integration with various other crates can be enabled via feature flags:
8//!
//!  * `async-tls`: Enables the `async_tls` module, which provides integration
//!    with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
//!    be used independently of any async runtime.
12//!  * `async-std-runtime`: Enables the `async_std` module, which provides
13//!    integration with the [async-std](https://async.rs) runtime.
14//!  * `async-native-tls`: Enables the additional functions in the `async_std`
15//!    module to implement TLS via
16//!    [async-native-tls](https://crates.io/crates/async-native-tls).
17//!  * `tokio-runtime`: Enables the `tokio` module, which provides integration
18//!    with the [tokio](https://tokio.rs) runtime.
19//!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20//!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21//!  * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23//!    and uses native system certificates found with
24//!    [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
25//!  * `tokio-rustls-platform-verifier`: Enables the additional functions in the `tokio`
26//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
27//!    and uses native system certificates via the platform verifier found with
28//!    [rustls-platform-verifier](https://github.com/rustls/rustls-platform-verifier).
29//!  * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
30//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
31//!    and uses the certificates [webpki-roots](https://github.com/rustls/webpki-roots)
32//!    provides.
33//!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
34//!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
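//!  * `gio-runtime`: Enables the `gio` module, which provides integration with
//!    the [gio](https://www.gtk-rs.org) runtime.
//!  * `futures-03-sink`: Implements `futures_util::Sink` for `WebSocketStream`
//!    and `WebSocketSender`.
//!
//! Each WebSocket stream implements the `Stream` trait and, when the
//! `futures-03-sink` feature is enabled, the `Sink` trait, making the socket a
//! stream of WebSocket messages coming in and going out.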
40
41#![deny(
42    missing_docs,
43    unused_must_use,
44    unused_mut,
45    unused_imports,
46    unused_import_braces
47)]
48
49pub use tungstenite;
50
51mod compat;
52mod handshake;
53
54#[cfg(any(
55    feature = "async-tls",
56    feature = "async-native-tls",
57    feature = "tokio-native-tls",
58    feature = "tokio-rustls-manual-roots",
59    feature = "tokio-rustls-native-certs",
60    feature = "tokio-rustls-platform-verifier",
61    feature = "tokio-rustls-webpki-roots",
62    feature = "tokio-openssl",
63))]
64pub mod stream;
65
66use std::{
67    io::{Read, Write},
68    pin::Pin,
69    sync::{Arc, Mutex, MutexGuard},
70    task::{ready, Context, Poll},
71};
72
73use compat::{cvt, AllowStd, ContextWaker};
74use futures_core::stream::{FusedStream, Stream};
75use futures_io::{AsyncRead, AsyncWrite};
76use log::*;
77
78#[cfg(feature = "handshake")]
79use tungstenite::{
80    client::IntoClientRequest,
81    handshake::{
82        client::{ClientHandshake, Response},
83        server::{Callback, NoCallback},
84        HandshakeError,
85    },
86};
87use tungstenite::{
88    error::Error as WsError,
89    protocol::{Message, Role, WebSocket, WebSocketConfig},
90};
91
92#[cfg(feature = "async-std-runtime")]
93pub mod async_std;
94#[cfg(feature = "async-tls")]
95pub mod async_tls;
96#[cfg(feature = "gio-runtime")]
97pub mod gio;
98#[cfg(feature = "tokio-runtime")]
99pub mod tokio;
100
101pub mod bytes;
102pub use bytes::ByteReader;
103pub use bytes::ByteWriter;
104
105use tungstenite::protocol::CloseFrame;
106
/// Creates a WebSocket handshake from a request and a stream.
/// For convenience, the user may call this with a url string, a URL,
/// or a `Request`. Calling with `Request` allows the user to add
/// a WebSocket protocol or other custom headers.
///
/// Internally, this creates a handshake representation and returns
/// a future representing the resolution of the WebSocket handshake. The
/// returned future resolves to the `WebSocketStream<S>` together with the
/// handshake `Response`, or to an `Error` if the handshake fails.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to the remote server.
#[cfg(feature = "handshake")]
pub async fn client_async<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Same as `client_async_with_config` with the default tungstenite configuration.
    client_async_with_config(request, stream, None).await
}
130
/// The same as `client_async()`, but one can also specify a websocket configuration.
/// Please refer to `client_async()` for more details.
///
/// # Errors
///
/// A handshake failure (`HandshakeError::Failure`) is returned unchanged; any
/// other `HandshakeError` is converted into a `WsError::Io` of kind `Other`
/// carrying the error's text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let f = handshake::client_handshake(stream, move |allow_std| {
        let request = request.into_client_request()?;
        let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
        cli_handshake.handshake()
    });
    f.await.map_err(|e| match e {
        // A genuine handshake failure already carries a `WsError`; pass it through.
        HandshakeError::Failure(e) => e,
        // Every other `HandshakeError` variant is reported as an opaque I/O error.
        e => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            e.to_string(),
        )),
    })
}
156
/// Accepts a new WebSocket connection with the provided stream.
///
/// Runs the server side of the WebSocket handshake over `stream` without a
/// header callback and with the default configuration. The returned future
/// resolves to a `WebSocketStream<S>` on success, or to an `Error` otherwise.
///
/// A typical caller has just accepted a socket from a `TcpListener` and hands
/// it over here to complete the server half of the client's websocket
/// connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S: AsyncRead + AsyncWrite + Unpin>(
    stream: S,
) -> Result<WebSocketStream<S>, WsError> {
    // No header callback and no custom configuration.
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
175
/// Like `accept_async()`, but with an explicit websocket configuration.
/// See `accept_async()` for the full description.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S: AsyncRead + AsyncWrite + Unpin>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError> {
    // Without a header callback, the no-op `NoCallback` is used.
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
188
/// Accepts a new WebSocket connection with the provided stream.
///
/// Behaves like `accept_async()`, except that `callback` is invoked with the
/// incoming request's headers and may add extra headers to the reply.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S: AsyncRead + AsyncWrite + Unpin, C: Callback + Unpin>(
    stream: S,
    callback: C,
) -> Result<WebSocketStream<S>, WsError> {
    let default_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, default_config).await
}
202
/// Like `accept_hdr_async()`, but with an explicit websocket configuration.
/// See `accept_hdr_async()` for the full description.
///
/// # Errors
///
/// A handshake failure is returned unchanged; any other handshake error is
/// converted into a `WsError::Io` of kind `Other` carrying the error's text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let server = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });

    match server.await {
        Ok(ws) => Ok(ws),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
226
227/// A wrapper around an underlying raw stream which implements the WebSocket
228/// protocol.
229///
230/// A `WebSocketStream<S>` represents a handshake that has been completed
231/// successfully and both the server and the client are ready for receiving
232/// and sending data. Message from a `WebSocketStream<S>` are accessible
233/// through the respective `Stream` and `Sink`. Check more information about
234/// them in `futures-rs` crate documentation or have a look on the examples
235/// and unit tests for this crate.
236#[derive(Debug)]
237pub struct WebSocketStream<S> {
238    inner: WebSocket<AllowStd<S>>,
239    #[cfg(feature = "futures-03-sink")]
240    closing: bool,
241    ended: bool,
242    /// Tungstenite is probably ready to receive more data.
243    ///
244    /// `false` once start_send hits `WouldBlock` errors.
245    /// `true` initially and after `flush`ing.
246    ready: bool,
247}
248
249impl<S> WebSocketStream<S> {
250    /// Convert a raw socket into a WebSocketStream without performing a
251    /// handshake.
252    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
253    where
254        S: AsyncRead + AsyncWrite + Unpin,
255    {
256        handshake::without_handshake(stream, move |allow_std| {
257            WebSocket::from_raw_socket(allow_std, role, config)
258        })
259        .await
260    }
261
262    /// Convert a raw socket into a WebSocketStream without performing a
263    /// handshake.
264    pub async fn from_partially_read(
265        stream: S,
266        part: Vec<u8>,
267        role: Role,
268        config: Option<WebSocketConfig>,
269    ) -> Self
270    where
271        S: AsyncRead + AsyncWrite + Unpin,
272    {
273        handshake::without_handshake(stream, move |allow_std| {
274            WebSocket::from_partially_read(allow_std, part, role, config)
275        })
276        .await
277    }
278
279    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
280        Self {
281            inner: ws,
282            #[cfg(feature = "futures-03-sink")]
283            closing: false,
284            ended: false,
285            ready: true,
286        }
287    }
288
289    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
290    where
291        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
292        AllowStd<S>: Read + Write,
293    {
294        #[cfg(feature = "verbose-logging")]
295        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
296        if let Some((kind, ctx)) = ctx {
297            self.inner.get_mut().set_waker(kind, ctx.waker());
298        }
299        f(&mut self.inner)
300    }
301
302    /// Consumes the `WebSocketStream` and returns the underlying stream.
303    pub fn into_inner(self) -> S {
304        self.inner.into_inner().into_inner()
305    }
306
307    /// Returns a shared reference to the inner stream.
308    pub fn get_ref(&self) -> &S
309    where
310        S: AsyncRead + AsyncWrite + Unpin,
311    {
312        self.inner.get_ref().get_ref()
313    }
314
315    /// Returns a mutable reference to the inner stream.
316    pub fn get_mut(&mut self) -> &mut S
317    where
318        S: AsyncRead + AsyncWrite + Unpin,
319    {
320        self.inner.get_mut().get_mut()
321    }
322
323    /// Returns a reference to the configuration of the tungstenite stream.
324    pub fn get_config(&self) -> &WebSocketConfig {
325        self.inner.get_config()
326    }
327
328    /// Close the underlying web socket
329    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
330    where
331        S: AsyncRead + AsyncWrite + Unpin,
332    {
333        self.send(Message::Close(msg)).await
334    }
335
336    /// Splits the websocket stream into separate
337    /// [sender](WebSocketSender) and [receiver](WebSocketReceiver) parts.
338    pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
339        let shared = Arc::new(Shared(Mutex::new(self)));
340        let sender = WebSocketSender {
341            shared: shared.clone(),
342        };
343
344        let receiver = WebSocketReceiver { shared };
345        (sender, receiver)
346    }
347
348    /// Attempts to reunite the [sender](WebSocketSender) and [receiver](WebSocketReceiver)
349    /// parts back into a single stream. If both parts originate from the same
350    /// [`split`](WebSocketStream::split) call, returns `Ok` with the original stream.
351    /// Otherwise, returns `Err` containing the provided parts.
352    pub fn reunite(
353        sender: WebSocketSender<S>,
354        receiver: WebSocketReceiver<S>,
355    ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
356        if sender.is_pair_of(&receiver) {
357            drop(receiver);
358            let stream = Arc::try_unwrap(sender.shared)
359                .ok()
360                .expect("reunite the stream")
361                .into_inner();
362
363            Ok(stream)
364        } else {
365            Err((sender, receiver))
366        }
367    }
368}
369
370impl<S> WebSocketStream<S>
371where
372    S: AsyncRead + AsyncWrite + Unpin,
373{
374    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
375        #[cfg(feature = "verbose-logging")]
376        trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
377
378        // The connection has been closed or a critical error has occurred.
379        // We have already returned the error to the user, the `Stream` is unusable,
380        // so we assume that the stream has been "fused".
381        if self.ended {
382            return Poll::Ready(None);
383        }
384
385        match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
386            #[cfg(feature = "verbose-logging")]
387            trace!(
388                "{}:{} WebSocketStream.with_context poll_next -> read()",
389                file!(),
390                line!()
391            );
392            cvt(s.read())
393        })) {
394            Ok(v) => Poll::Ready(Some(Ok(v))),
395            Err(e) => {
396                self.ended = true;
397                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
398                    Poll::Ready(None)
399                } else {
400                    Poll::Ready(Some(Err(e)))
401                }
402            }
403        }
404    }
405
406    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
407        if self.ready {
408            return Poll::Ready(Ok(()));
409        }
410
411        // Currently blocked so try to flush the blockage away
412        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
413            .map(|r| {
414                self.ready = true;
415                r
416            })
417    }
418
419    fn start_send(&mut self, item: Message) -> Result<(), WsError> {
420        match self.with_context(None, |s| s.write(item)) {
421            Ok(()) => {
422                self.ready = true;
423                Ok(())
424            }
425            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
426                // the message was accepted and queued so not an error
427                // but `poll_ready` will now start trying to flush the block
428                self.ready = false;
429                Ok(())
430            }
431            Err(e) => {
432                self.ready = true;
433                debug!("websocket start_send error: {}", e);
434                Err(e)
435            }
436        }
437    }
438
439    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
440        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
441            .map(|r| {
442                self.ready = true;
443                match r {
444                    // WebSocket connection has just been closed. Flushing completed, not an error.
445                    Err(WsError::ConnectionClosed) => Ok(()),
446                    other => other,
447                }
448            })
449    }
450
451    #[cfg(feature = "futures-03-sink")]
452    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
453        self.ready = true;
454        let res = if self.closing {
455            // After queueing it, we call `flush` to drive the close handshake to completion.
456            self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
457        } else {
458            self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
459        };
460
461        match res {
462            Ok(()) => Poll::Ready(Ok(())),
463            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
464            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
465                trace!("WouldBlock");
466                self.closing = true;
467                Poll::Pending
468            }
469            Err(err) => {
470                debug!("websocket close error: {}", err);
471                Poll::Ready(Err(err))
472            }
473        }
474    }
475}
476
477impl<S> Stream for WebSocketStream<S>
478where
479    S: AsyncRead + AsyncWrite + Unpin,
480{
481    type Item = Result<Message, WsError>;
482
483    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
484        self.get_mut().poll_next(cx)
485    }
486}
487
488impl<S> FusedStream for WebSocketStream<S>
489where
490    S: AsyncRead + AsyncWrite + Unpin,
491{
492    fn is_terminated(&self) -> bool {
493        self.ended
494    }
495}
496
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Each method unwraps the pin and forwards to the inherent method of the same name.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        Pin::into_inner(self).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_close(cx)
    }
}
520
521#[cfg(not(feature = "futures-03-sink"))]
522impl<S> bytes::private::SealedSender for WebSocketStream<S>
523where
524    S: AsyncRead + AsyncWrite + Unpin,
525{
526    fn poll_write(
527        self: Pin<&mut Self>,
528        cx: &mut Context<'_>,
529        buf: &[u8],
530    ) -> Poll<Result<usize, WsError>> {
531        let me = self.get_mut();
532        ready!(me.poll_ready(cx))?;
533        let len = buf.len();
534        me.start_send(Message::binary(buf.to_owned()))?;
535        Poll::Ready(Ok(len))
536    }
537
538    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
539        self.get_mut().poll_flush(cx)
540    }
541
542    fn poll_close(
543        self: Pin<&mut Self>,
544        cx: &mut Context<'_>,
545        msg: &mut Option<Message>,
546    ) -> Poll<Result<(), WsError>> {
547        let me = self.get_mut();
548        send_helper(me, msg, cx)
549    }
550}
551
552impl<S> WebSocketStream<S> {
553    /// Simple send method to replace `futures_sink::Sink` (till v0.3).
554    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
555    where
556        S: AsyncRead + AsyncWrite + Unpin,
557    {
558        Send {
559            ws: self,
560            msg: Some(msg),
561        }
562        .await
563    }
564}
565
566struct Send<W> {
567    ws: W,
568    msg: Option<Message>,
569}
570
571/// Performs an asynchronous message send to the websocket.
572fn send_helper<S>(
573    ws: &mut WebSocketStream<S>,
574    msg: &mut Option<Message>,
575    cx: &mut Context<'_>,
576) -> Poll<Result<(), WsError>>
577where
578    S: AsyncRead + AsyncWrite + Unpin,
579{
580    if msg.is_some() {
581        ready!(ws.poll_ready(cx))?;
582        let msg = msg.take().expect("unreachable");
583        ws.start_send(msg)?;
584    }
585
586    ws.poll_flush(cx)
587}
588
589impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
590where
591    S: AsyncRead + AsyncWrite + Unpin,
592{
593    type Output = Result<(), WsError>;
594
595    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
596        let me = self.get_mut();
597        send_helper(me.ws, &mut me.msg, cx)
598    }
599}
600
601impl<S> std::future::Future for Send<&Shared<S>>
602where
603    S: AsyncRead + AsyncWrite + Unpin,
604{
605    type Output = Result<(), WsError>;
606
607    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
608        let me = self.get_mut();
609        let mut ws = me.ws.lock();
610        send_helper(&mut ws, &mut me.msg, cx)
611    }
612}
613
614/// The sender part of a [websocket](WebSocketStream) stream.
615#[derive(Debug)]
616pub struct WebSocketSender<S> {
617    shared: Arc<Shared<S>>,
618}
619
620impl<S> WebSocketSender<S> {
621    /// Send a message via [websocket](WebSocketStream).
622    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
623    where
624        S: AsyncRead + AsyncWrite + Unpin,
625    {
626        Send {
627            ws: &*self.shared,
628            msg: Some(msg),
629        }
630        .await
631    }
632
633    /// Close the underlying [websocket](WebSocketStream).
634    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
635    where
636        S: AsyncRead + AsyncWrite + Unpin,
637    {
638        self.send(Message::Close(msg)).await
639    }
640
641    /// Checks if this [sender](WebSocketSender) and some [receiver](WebSocketReceiver)
642    /// were split from the same [websocket](WebSocketStream) stream.
643    pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
644        Arc::ptr_eq(&self.shared, &other.shared)
645    }
646}
647
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    // Each method locks the shared stream for the duration of the call and
    // forwards to its inherent method of the same name.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut ws = self.shared.lock();
        ws.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_close(cx)
    }
}
671
672#[cfg(not(feature = "futures-03-sink"))]
673impl<S> bytes::private::SealedSender for WebSocketSender<S>
674where
675    S: AsyncRead + AsyncWrite + Unpin,
676{
677    fn poll_write(
678        self: Pin<&mut Self>,
679        cx: &mut Context<'_>,
680        buf: &[u8],
681    ) -> Poll<Result<usize, WsError>> {
682        let me = self.get_mut();
683        let mut ws = me.shared.lock();
684        ready!(ws.poll_ready(cx))?;
685        let len = buf.len();
686        ws.start_send(Message::binary(buf.to_owned()))?;
687        Poll::Ready(Ok(len))
688    }
689
690    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
691        self.shared.lock().poll_flush(cx)
692    }
693
694    fn poll_close(
695        self: Pin<&mut Self>,
696        cx: &mut Context<'_>,
697        msg: &mut Option<Message>,
698    ) -> Poll<Result<(), WsError>> {
699        let me = self.get_mut();
700        let mut ws = me.shared.lock();
701        send_helper(&mut ws, msg, cx)
702    }
703}
704
705/// The receiver part of a [websocket](WebSocketStream) stream.
706#[derive(Debug)]
707pub struct WebSocketReceiver<S> {
708    shared: Arc<Shared<S>>,
709}
710
711impl<S> WebSocketReceiver<S> {
712    /// Checks if this [receiver](WebSocketReceiver) and some [sender](WebSocketSender)
713    /// were split from the same [websocket](WebSocketStream) stream.
714    pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
715        Arc::ptr_eq(&self.shared, &other.shared)
716    }
717}
718
719impl<S> Stream for WebSocketReceiver<S>
720where
721    S: AsyncRead + AsyncWrite + Unpin,
722{
723    type Item = Result<Message, WsError>;
724
725    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
726        self.shared.lock().poll_next(cx)
727    }
728}
729
730impl<S> FusedStream for WebSocketReceiver<S>
731where
732    S: AsyncRead + AsyncWrite + Unpin,
733{
734    fn is_terminated(&self) -> bool {
735        self.shared.lock().ended
736    }
737}
738
739#[derive(Debug)]
740struct Shared<S>(Mutex<WebSocketStream<S>>);
741
742impl<S> Shared<S> {
743    fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
744        self.0.lock().expect("lock shared stream")
745    }
746
747    fn into_inner(self) -> WebSocketStream<S> {
748        self.0.into_inner().expect("get shared stream")
749    }
750}
751
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the host (domain name or IP address) from a request's URL.
///
/// IPv6 literals are returned without their surrounding brackets, e.g.
/// `ws://[::1]:80` yields `::1`.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::NoHostName)` if the URL has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // If host is an IPv6 address, it is surrounded by brackets. These brackets are
    // *not* part of a valid IP, so they must be stripped out. `strip_prefix` /
    // `strip_suffix` avoid the panicking slice a manual `[1..len - 1]` would have
    // if the brackets were ever unbalanced; such a host is returned unchanged.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
784
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Get the port from a request's URL.
///
/// An explicit port wins; otherwise the scheme's default is used (443 for
/// `wss`, 80 for `ws`). Any other scheme without an explicit port is rejected
/// as unsupported.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
807
#[cfg(test)]
mod tests {
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "::1");
    }

    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_keeps_plain_hosts() {
        use tungstenite::client::IntoClientRequest;

        let request = "wss://example.com/path".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "example.com");

        let request = "ws://127.0.0.1:9001".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "127.0.0.1");
    }

    #[cfg(any(
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn port_prefers_explicit_port_then_scheme_default() {
        use tungstenite::client::IntoClientRequest;

        let explicit = "ws://example.com:9001".into_client_request().unwrap();
        assert_eq!(crate::port(&explicit).unwrap(), 9001);

        let plain = "ws://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&plain).unwrap(), 80);

        let secure = "wss://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&secure).unwrap(), 443);
    }

    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        assert!("ws://[".into_client_request().is_err());
        assert!("ws://[blabla/bla".into_client_request().is_err());
        assert!("ws://[::1/bla".into_client_request().is_err());
    }
}