1#![deny(
38    missing_docs,
39    unused_must_use,
40    unused_mut,
41    unused_imports,
42    unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51    feature = "async-tls",
52    feature = "async-native-tls",
53    feature = "tokio-native-tls",
54    feature = "tokio-rustls-manual-roots",
55    feature = "tokio-rustls-native-certs",
56    feature = "tokio-rustls-webpki-roots",
57    feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::{
62    io::{Read, Write},
63    pin::Pin,
64    sync::{Arc, Mutex, MutexGuard},
65    task::{ready, Context, Poll},
66};
67
68use compat::{cvt, AllowStd, ContextWaker};
69use futures_core::stream::{FusedStream, Stream};
70use futures_io::{AsyncRead, AsyncWrite};
71use log::*;
72
73#[cfg(feature = "handshake")]
74use tungstenite::{
75    client::IntoClientRequest,
76    handshake::{
77        client::{ClientHandshake, Response},
78        server::{Callback, NoCallback},
79        HandshakeError,
80    },
81};
82use tungstenite::{
83    error::Error as WsError,
84    protocol::{Message, Role, WebSocket, WebSocketConfig},
85};
86
87#[cfg(feature = "async-std-runtime")]
88pub mod async_std;
89#[cfg(feature = "async-tls")]
90pub mod async_tls;
91#[cfg(feature = "gio-runtime")]
92pub mod gio;
93#[cfg(feature = "tokio-runtime")]
94pub mod tokio;
95
96pub mod bytes;
97pub use bytes::ByteReader;
98pub use bytes::ByteWriter;
99
100use tungstenite::protocol::CloseFrame;
101
/// Performs a WebSocket client handshake over `stream` for `request`.
///
/// This is a convenience wrapper around [`client_async_with_config`] that
/// forwards `None` as the configuration.
///
/// # Errors
///
/// Returns whatever error [`client_async_with_config`] reports.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // No explicit configuration is supplied by this entry point.
    let no_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, no_config).await
}
125
/// Performs a WebSocket client handshake over `stream` for `request`, using
/// the optional `config`.
///
/// The request is converted with [`IntoClientRequest::into_client_request`],
/// a [`ClientHandshake`] is started on the wrapped stream and then driven to
/// completion.
///
/// # Errors
///
/// A [`HandshakeError::Failure`] is unwrapped and returned as-is. Any other
/// handshake error is reported as [`WsError::Io`] with kind
/// [`std::io::ErrorKind::Other`], carrying the error's display text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |allow_std| {
        let client_request = request.into_client_request()?;
        ClientHandshake::start(allow_std, client_request, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(established) => Ok(established),
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        // Anything that is not a plain failure is surfaced as an I/O error.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
151
/// Accepts a WebSocket connection on `stream`.
///
/// Delegates to [`accept_hdr_async`] with [`NoCallback`] as the callback.
///
/// # Errors
///
/// Returns whatever error [`accept_hdr_async`] reports.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async(stream, callback).await
}
170
/// Accepts a WebSocket connection on `stream`, using the optional `config`.
///
/// Delegates to [`accept_hdr_async_with_config`] with [`NoCallback`] as the
/// callback.
///
/// # Errors
///
/// Returns whatever error [`accept_hdr_async_with_config`] reports.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
183
/// Accepts a WebSocket connection on `stream`, passing `callback` to the
/// server handshake.
///
/// Delegates to [`accept_hdr_async_with_config`], forwarding `None` as the
/// configuration.
///
/// # Errors
///
/// Returns whatever error [`accept_hdr_async_with_config`] reports.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    // No explicit configuration is supplied by this entry point.
    let no_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, no_config).await
}
197
/// Accepts a WebSocket connection on `stream`, passing `callback` and the
/// optional `config` to `tungstenite::accept_hdr_with_config`.
///
/// # Errors
///
/// A [`HandshakeError::Failure`] is unwrapped and returned as-is. Any other
/// handshake error is reported as [`WsError::Io`] with kind
/// [`std::io::ErrorKind::Other`], carrying the error's display text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await;

    match outcome {
        Ok(accepted) => Ok(accepted),
        Err(HandshakeError::Failure(ws_err)) => Err(ws_err),
        // Anything that is not a plain failure is surfaced as an I/O error.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
221
222#[derive(Debug)]
232pub struct WebSocketStream<S> {
233    inner: WebSocket<AllowStd<S>>,
234    #[cfg(feature = "futures-03-sink")]
235    closing: bool,
236    ended: bool,
237    ready: bool,
242}
243
244impl<S> WebSocketStream<S> {
245    pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
248    where
249        S: AsyncRead + AsyncWrite + Unpin,
250    {
251        handshake::without_handshake(stream, move |allow_std| {
252            WebSocket::from_raw_socket(allow_std, role, config)
253        })
254        .await
255    }
256
257    pub async fn from_partially_read(
260        stream: S,
261        part: Vec<u8>,
262        role: Role,
263        config: Option<WebSocketConfig>,
264    ) -> Self
265    where
266        S: AsyncRead + AsyncWrite + Unpin,
267    {
268        handshake::without_handshake(stream, move |allow_std| {
269            WebSocket::from_partially_read(allow_std, part, role, config)
270        })
271        .await
272    }
273
274    pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
275        Self {
276            inner: ws,
277            #[cfg(feature = "futures-03-sink")]
278            closing: false,
279            ended: false,
280            ready: true,
281        }
282    }
283
284    fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
285    where
286        F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
287        AllowStd<S>: Read + Write,
288    {
289        #[cfg(feature = "verbose-logging")]
290        trace!("{}:{} WebSocketStream.with_context", file!(), line!());
291        if let Some((kind, ctx)) = ctx {
292            self.inner.get_mut().set_waker(kind, ctx.waker());
293        }
294        f(&mut self.inner)
295    }
296
297    pub fn into_inner(self) -> S {
299        self.inner.into_inner().into_inner()
300    }
301
302    pub fn get_ref(&self) -> &S
304    where
305        S: AsyncRead + AsyncWrite + Unpin,
306    {
307        self.inner.get_ref().get_ref()
308    }
309
310    pub fn get_mut(&mut self) -> &mut S
312    where
313        S: AsyncRead + AsyncWrite + Unpin,
314    {
315        self.inner.get_mut().get_mut()
316    }
317
318    pub fn get_config(&self) -> &WebSocketConfig {
320        self.inner.get_config()
321    }
322
323    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
325    where
326        S: AsyncRead + AsyncWrite + Unpin,
327    {
328        self.send(Message::Close(msg)).await
329    }
330
331    pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
334        let shared = Arc::new(Shared(Mutex::new(self)));
335        let sender = WebSocketSender {
336            shared: shared.clone(),
337        };
338
339        let receiver = WebSocketReceiver { shared };
340        (sender, receiver)
341    }
342
343    pub fn reunite(
348        sender: WebSocketSender<S>,
349        receiver: WebSocketReceiver<S>,
350    ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
351        if sender.is_pair_of(&receiver) {
352            drop(receiver);
353            let stream = Arc::try_unwrap(sender.shared)
354                .ok()
355                .expect("reunite the stream")
356                .into_inner();
357
358            Ok(stream)
359        } else {
360            Err((sender, receiver))
361        }
362    }
363}
364
365impl<S> WebSocketStream<S>
366where
367    S: AsyncRead + AsyncWrite + Unpin,
368{
369    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
370        #[cfg(feature = "verbose-logging")]
371        trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
372
373        if self.ended {
377            return Poll::Ready(None);
378        }
379
380        match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
381            #[cfg(feature = "verbose-logging")]
382            trace!(
383                "{}:{} WebSocketStream.with_context poll_next -> read()",
384                file!(),
385                line!()
386            );
387            cvt(s.read())
388        })) {
389            Ok(v) => Poll::Ready(Some(Ok(v))),
390            Err(e) => {
391                self.ended = true;
392                if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
393                    Poll::Ready(None)
394                } else {
395                    Poll::Ready(Some(Err(e)))
396                }
397            }
398        }
399    }
400
401    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
402        if self.ready {
403            return Poll::Ready(Ok(()));
404        }
405
406        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
408            .map(|r| {
409                self.ready = true;
410                r
411            })
412    }
413
414    fn start_send(&mut self, item: Message) -> Result<(), WsError> {
415        match self.with_context(None, |s| s.write(item)) {
416            Ok(()) => {
417                self.ready = true;
418                Ok(())
419            }
420            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
421                self.ready = false;
424                Ok(())
425            }
426            Err(e) => {
427                self.ready = true;
428                debug!("websocket start_send error: {}", e);
429                Err(e)
430            }
431        }
432    }
433
434    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
435        self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
436            .map(|r| {
437                self.ready = true;
438                match r {
439                    Err(WsError::ConnectionClosed) => Ok(()),
441                    other => other,
442                }
443            })
444    }
445
446    #[cfg(feature = "futures-03-sink")]
447    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
448        self.ready = true;
449        let res = if self.closing {
450            self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
452        } else {
453            self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
454        };
455
456        match res {
457            Ok(()) => Poll::Ready(Ok(())),
458            Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
459            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
460                trace!("WouldBlock");
461                self.closing = true;
462                Poll::Pending
463            }
464            Err(err) => {
465                debug!("websocket close error: {}", err);
466                Poll::Ready(Err(err))
467            }
468        }
469    }
470}
471
472impl<S> Stream for WebSocketStream<S>
473where
474    S: AsyncRead + AsyncWrite + Unpin,
475{
476    type Item = Result<Message, WsError>;
477
478    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
479        self.get_mut().poll_next(cx)
480    }
481}
482
483impl<S> FusedStream for WebSocketStream<S>
484where
485    S: AsyncRead + AsyncWrite + Unpin,
486{
487    fn is_terminated(&self) -> bool {
488        self.ended
489    }
490}
491
/// `Sink` implementation that forwards every operation to the inherent
/// method of the same name.
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        this.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let this = self.get_mut();
        this.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        this.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        this.poll_close(cx)
    }
}
515
516#[cfg(not(feature = "futures-03-sink"))]
517impl<S> bytes::private::SealedSender for WebSocketStream<S>
518where
519    S: AsyncRead + AsyncWrite + Unpin,
520{
521    fn poll_write(
522        self: Pin<&mut Self>,
523        cx: &mut Context<'_>,
524        buf: &[u8],
525    ) -> Poll<Result<usize, WsError>> {
526        let me = self.get_mut();
527        ready!(me.poll_ready(cx))?;
528        let len = buf.len();
529        me.start_send(Message::binary(buf.to_owned()))?;
530        Poll::Ready(Ok(len))
531    }
532
533    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
534        self.get_mut().poll_flush(cx)
535    }
536
537    fn poll_close(
538        self: Pin<&mut Self>,
539        cx: &mut Context<'_>,
540        msg: &mut Option<Message>,
541    ) -> Poll<Result<(), WsError>> {
542        let me = self.get_mut();
543        send_helper(me, msg, cx)
544    }
545}
546
547impl<S> WebSocketStream<S> {
548    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
550    where
551        S: AsyncRead + AsyncWrite + Unpin,
552    {
553        Send {
554            ws: self,
555            msg: Some(msg),
556        }
557        .await
558    }
559}
560
561struct Send<W> {
562    ws: W,
563    msg: Option<Message>,
564}
565
566fn send_helper<S>(
568    ws: &mut WebSocketStream<S>,
569    msg: &mut Option<Message>,
570    cx: &mut Context<'_>,
571) -> Poll<Result<(), WsError>>
572where
573    S: AsyncRead + AsyncWrite + Unpin,
574{
575    if msg.is_some() {
576        ready!(ws.poll_ready(cx))?;
577        let msg = msg.take().expect("unreachable");
578        ws.start_send(msg)?;
579    }
580
581    ws.poll_flush(cx)
582}
583
584impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
585where
586    S: AsyncRead + AsyncWrite + Unpin,
587{
588    type Output = Result<(), WsError>;
589
590    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
591        let me = self.get_mut();
592        send_helper(me.ws, &mut me.msg, cx)
593    }
594}
595
596impl<S> std::future::Future for Send<&Shared<S>>
597where
598    S: AsyncRead + AsyncWrite + Unpin,
599{
600    type Output = Result<(), WsError>;
601
602    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
603        let me = self.get_mut();
604        let mut ws = me.ws.lock();
605        send_helper(&mut ws, &mut me.msg, cx)
606    }
607}
608
609#[derive(Debug)]
611pub struct WebSocketSender<S> {
612    shared: Arc<Shared<S>>,
613}
614
615impl<S> WebSocketSender<S> {
616    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
618    where
619        S: AsyncRead + AsyncWrite + Unpin,
620    {
621        Send {
622            ws: &*self.shared,
623            msg: Some(msg),
624        }
625        .await
626    }
627
628    pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
630    where
631        S: AsyncRead + AsyncWrite + Unpin,
632    {
633        self.send(Message::Close(msg)).await
634    }
635
636    pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
639        Arc::ptr_eq(&self.shared, &other.shared)
640    }
641}
642
/// `Sink` implementation that locks the shared stream for each call and
/// forwards to the stream's method of the same name.
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut ws = self.shared.lock();
        ws.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut ws = self.shared.lock();
        ws.poll_close(cx)
    }
}
666
667#[cfg(not(feature = "futures-03-sink"))]
668impl<S> bytes::private::SealedSender for WebSocketSender<S>
669where
670    S: AsyncRead + AsyncWrite + Unpin,
671{
672    fn poll_write(
673        self: Pin<&mut Self>,
674        cx: &mut Context<'_>,
675        buf: &[u8],
676    ) -> Poll<Result<usize, WsError>> {
677        let me = self.get_mut();
678        let mut ws = me.shared.lock();
679        ready!(ws.poll_ready(cx))?;
680        let len = buf.len();
681        ws.start_send(Message::binary(buf.to_owned()))?;
682        Poll::Ready(Ok(len))
683    }
684
685    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
686        self.shared.lock().poll_flush(cx)
687    }
688
689    fn poll_close(
690        self: Pin<&mut Self>,
691        cx: &mut Context<'_>,
692        msg: &mut Option<Message>,
693    ) -> Poll<Result<(), WsError>> {
694        let me = self.get_mut();
695        let mut ws = me.shared.lock();
696        send_helper(&mut ws, msg, cx)
697    }
698}
699
700#[derive(Debug)]
702pub struct WebSocketReceiver<S> {
703    shared: Arc<Shared<S>>,
704}
705
706impl<S> WebSocketReceiver<S> {
707    pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
710        Arc::ptr_eq(&self.shared, &other.shared)
711    }
712}
713
714impl<S> Stream for WebSocketReceiver<S>
715where
716    S: AsyncRead + AsyncWrite + Unpin,
717{
718    type Item = Result<Message, WsError>;
719
720    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
721        self.shared.lock().poll_next(cx)
722    }
723}
724
725impl<S> FusedStream for WebSocketReceiver<S>
726where
727    S: AsyncRead + AsyncWrite + Unpin,
728{
729    fn is_terminated(&self) -> bool {
730        self.shared.lock().ended
731    }
732}
733
734#[derive(Debug)]
735struct Shared<S>(Mutex<WebSocketStream<S>>);
736
737impl<S> Shared<S> {
738    fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
739        self.0.lock().expect("lock shared stream")
740    }
741
742    fn into_inner(self) -> WebSocketStream<S> {
743        self.0.into_inner().expect("get shared stream")
744    }
745}
746
/// Extracts the host name of `request`'s URI as an owned string.
///
/// When the host is enclosed in square brackets (as with a bracketed IPv6
/// literal such as `[::1]`), the brackets are removed.
///
/// # Errors
///
/// Returns `UrlError::NoHostName` when the URI has no host component.
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request
        .uri()
        .host()
        .ok_or(tungstenite::Error::Url(tungstenite::error::UrlError::NoHostName))?;

    // Only strip when both brackets are present. Unlike slicing by index,
    // this cannot panic and never drops a character that is not a bracket.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
779
/// Determines the port `request` should connect to.
///
/// An explicit port in the URI wins. Otherwise the scheme decides:
/// `wss` maps to 443 and `ws` maps to 80.
///
/// # Errors
///
/// Returns `UrlError::UnsupportedUrlScheme` when the URI has no explicit
/// port and its scheme is neither `ws` nor `wss`.
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(tungstenite::error::UrlError::UnsupportedUrlScheme)),
    }
}
802
#[cfg(test)]
mod tests {
    /// The brackets around an IPv6 literal must not end up in the domain.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        let domain = crate::domain(&request).unwrap();
        assert_eq!(domain, "::1");
    }

    /// URIs with an unterminated bracketed host must be rejected when the
    /// request is built.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        for uri in ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(uri.into_client_request().is_err(), "accepted {uri:?}");
        }
    }
}