1#![deny(
42 missing_docs,
43 unused_must_use,
44 unused_mut,
45 unused_imports,
46 unused_import_braces
47)]
48
49pub use tungstenite;
50
51mod compat;
52mod handshake;
53
54#[cfg(any(
55 feature = "async-tls",
56 feature = "async-native-tls",
57 feature = "smol-native-tls",
58 feature = "futures-rustls-manual-roots",
59 feature = "futures-rustls-webpki-roots",
60 feature = "futures-rustls-native-certs",
61 feature = "futures-rustls-platform-verifier",
62 feature = "tokio-native-tls",
63 feature = "tokio-rustls-manual-roots",
64 feature = "tokio-rustls-native-certs",
65 feature = "tokio-rustls-platform-verifier",
66 feature = "tokio-rustls-webpki-roots",
67 feature = "tokio-openssl",
68))]
69pub mod stream;
70
71use std::{
72 io::{Read, Write},
73 pin::Pin,
74 sync::{Arc, Mutex, MutexGuard},
75 task::{ready, Context, Poll},
76};
77
78use compat::{cvt, AllowStd, ContextWaker};
79use futures_core::stream::{FusedStream, Stream};
80use futures_io::{AsyncRead, AsyncWrite};
81use log::*;
82
83#[cfg(feature = "handshake")]
84use tungstenite::{
85 client::IntoClientRequest,
86 handshake::{
87 client::{ClientHandshake, Response},
88 server::{Callback, NoCallback},
89 HandshakeError,
90 },
91};
92use tungstenite::{
93 error::Error as WsError,
94 protocol::{Message, Role, WebSocket, WebSocketConfig},
95};
96
97#[cfg(feature = "async-std-runtime")]
98#[deprecated = "async-std is unmaintained upstream. Please use the smol runtime instead."]
99pub mod async_std;
100#[cfg(feature = "async-tls")]
101pub mod async_tls;
102#[cfg(feature = "gio-runtime")]
103pub mod gio;
104#[cfg(feature = "smol-runtime")]
105pub mod smol;
106#[cfg(feature = "tokio-runtime")]
107pub mod tokio;
108
109pub mod bytes;
110pub use bytes::ByteReader;
111pub use bytes::ByteWriter;
112
113use tungstenite::protocol::CloseFrame;
114
/// Runs the client side of the WebSocket handshake over an already connected
/// `stream`, using the default WebSocket configuration.
///
/// `request` may be anything implementing [`IntoClientRequest`], for example a
/// URL string or a fully built request that carries extra headers. This is
/// shorthand for [`client_async_with_config`] without a configuration
/// override.
///
/// # Errors
///
/// Propagates every error produced by [`client_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, default_config).await
}
138
/// Runs the client side of the WebSocket handshake over `stream`, optionally
/// overriding the default [`WebSocketConfig`].
///
/// On success the connected [`WebSocketStream`] is returned together with the
/// server's handshake [`Response`].
///
/// # Errors
///
/// * A `HandshakeError::Failure` is unwrapped and its inner error returned
///   as-is.
/// * Any other handshake error is reported as [`WsError::Io`] with
///   `ErrorKind::Other`, carrying the handshake error's display text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let handshake_future = handshake::client_handshake(stream, move |allow_std| {
        let client_request = request.into_client_request()?;
        ClientHandshake::start(allow_std, client_request, config)?.handshake()
    });

    match handshake_future.await {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(err)) => Err(err),
        // Non-failure handshake errors are flattened into a generic I/O error
        // that keeps their description.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
164
/// Accepts an incoming WebSocket connection on `stream`, acting as the server
/// side of the handshake.
///
/// No header callback is installed and the default configuration is used.
/// This is shorthand for [`accept_hdr_async_with_config`] with `NoCallback`
/// and no configuration override.
///
/// # Errors
///
/// Propagates every error produced by [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
183
/// Accepts an incoming WebSocket connection on `stream` without a header
/// callback, optionally overriding the default [`WebSocketConfig`].
///
/// # Errors
///
/// Propagates every error produced by [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let callback = NoCallback;
    accept_hdr_async_with_config(stream, callback, config).await
}
196
/// Accepts an incoming WebSocket connection on `stream`, invoking `callback`
/// during the handshake (for example to inspect request headers or adjust
/// the response), with the default configuration.
///
/// # Errors
///
/// Propagates every error produced by [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let no_override: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, no_override).await
}
210
/// Accepts an incoming WebSocket connection on `stream`, invoking `callback`
/// during the handshake and optionally overriding the default
/// [`WebSocketConfig`].
///
/// # Errors
///
/// * A `HandshakeError::Failure` is unwrapped and its inner error returned
///   as-is.
/// * Any other handshake error is reported as [`WsError::Io`] with
///   `ErrorKind::Other`, carrying the handshake error's display text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let handshake_future = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    });

    match handshake_future.await {
        Ok(accepted) => Ok(accepted),
        Err(HandshakeError::Failure(err)) => Err(err),
        // Non-failure handshake errors are flattened into a generic I/O error
        // that keeps their description.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
234
235#[derive(Debug)]
245pub struct WebSocketStream<S> {
246 inner: WebSocket<AllowStd<S>>,
247 #[cfg(feature = "futures-03-sink")]
248 closing: bool,
249 ended: bool,
250 ready: bool,
255}
256
257impl<S> WebSocketStream<S> {
258 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
261 where
262 S: AsyncRead + AsyncWrite + Unpin,
263 {
264 handshake::without_handshake(stream, move |allow_std| {
265 WebSocket::from_raw_socket(allow_std, role, config)
266 })
267 .await
268 }
269
270 pub async fn from_partially_read(
273 stream: S,
274 part: Vec<u8>,
275 role: Role,
276 config: Option<WebSocketConfig>,
277 ) -> Self
278 where
279 S: AsyncRead + AsyncWrite + Unpin,
280 {
281 handshake::without_handshake(stream, move |allow_std| {
282 WebSocket::from_partially_read(allow_std, part, role, config)
283 })
284 .await
285 }
286
287 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
288 Self {
289 inner: ws,
290 #[cfg(feature = "futures-03-sink")]
291 closing: false,
292 ended: false,
293 ready: true,
294 }
295 }
296
297 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
298 where
299 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
300 AllowStd<S>: Read + Write,
301 {
302 #[cfg(feature = "verbose-logging")]
303 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
304 if let Some((kind, ctx)) = ctx {
305 self.inner.get_mut().set_waker(kind, ctx.waker());
306 }
307 f(&mut self.inner)
308 }
309
310 pub fn into_inner(self) -> S {
312 self.inner.into_inner().into_inner()
313 }
314
315 pub fn get_ref(&self) -> &S
317 where
318 S: AsyncRead + AsyncWrite + Unpin,
319 {
320 self.inner.get_ref().get_ref()
321 }
322
323 pub fn get_mut(&mut self) -> &mut S
325 where
326 S: AsyncRead + AsyncWrite + Unpin,
327 {
328 self.inner.get_mut().get_mut()
329 }
330
331 pub fn get_config(&self) -> &WebSocketConfig {
333 self.inner.get_config()
334 }
335
336 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
338 where
339 S: AsyncRead + AsyncWrite + Unpin,
340 {
341 self.send(Message::Close(msg)).await
342 }
343
344 pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
347 let shared = Arc::new(Shared(Mutex::new(self)));
348 let sender = WebSocketSender {
349 shared: shared.clone(),
350 };
351
352 let receiver = WebSocketReceiver { shared };
353 (sender, receiver)
354 }
355
356 pub fn reunite(
361 sender: WebSocketSender<S>,
362 receiver: WebSocketReceiver<S>,
363 ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
364 if sender.is_pair_of(&receiver) {
365 drop(receiver);
366 let stream = Arc::try_unwrap(sender.shared)
367 .ok()
368 .expect("reunite the stream")
369 .into_inner();
370
371 Ok(stream)
372 } else {
373 Err((sender, receiver))
374 }
375 }
376}
377
378impl<S> WebSocketStream<S>
379where
380 S: AsyncRead + AsyncWrite + Unpin,
381{
382 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
383 #[cfg(feature = "verbose-logging")]
384 trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
385
386 if self.ended {
390 return Poll::Ready(None);
391 }
392
393 match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
394 #[cfg(feature = "verbose-logging")]
395 trace!(
396 "{}:{} WebSocketStream.with_context poll_next -> read()",
397 file!(),
398 line!()
399 );
400 cvt(s.read())
401 })) {
402 Ok(v) => Poll::Ready(Some(Ok(v))),
403 Err(e) => {
404 self.ended = true;
405 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
406 Poll::Ready(None)
407 } else {
408 Poll::Ready(Some(Err(e)))
409 }
410 }
411 }
412 }
413
414 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
415 if self.ready {
416 return Poll::Ready(Ok(()));
417 }
418
419 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
421 .map(|r| {
422 self.ready = true;
423 r
424 })
425 }
426
427 fn start_send(&mut self, item: Message) -> Result<(), WsError> {
428 match self.with_context(None, |s| s.write(item)) {
429 Ok(()) => {
430 self.ready = true;
431 Ok(())
432 }
433 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
434 self.ready = false;
437 Ok(())
438 }
439 Err(e) => {
440 self.ready = true;
441 debug!("websocket start_send error: {}", e);
442 Err(e)
443 }
444 }
445 }
446
447 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
448 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
449 .map(|r| {
450 self.ready = true;
451 match r {
452 Err(WsError::ConnectionClosed) => Ok(()),
454 other => other,
455 }
456 })
457 }
458
459 #[cfg(feature = "futures-03-sink")]
460 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
461 self.ready = true;
462 let res = if self.closing {
463 self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
465 } else {
466 self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
467 };
468
469 match res {
470 Ok(()) => Poll::Ready(Ok(())),
471 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
472 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
473 trace!("WouldBlock");
474 self.closing = true;
475 Poll::Pending
476 }
477 Err(err) => {
478 debug!("websocket close error: {}", err);
479 Poll::Ready(Err(err))
480 }
481 }
482 }
483}
484
485impl<S> Stream for WebSocketStream<S>
486where
487 S: AsyncRead + AsyncWrite + Unpin,
488{
489 type Item = Result<Message, WsError>;
490
491 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
492 self.get_mut().poll_next(cx)
493 }
494}
495
496impl<S> FusedStream for WebSocketStream<S>
497where
498 S: AsyncRead + AsyncWrite + Unpin,
499{
500 fn is_terminated(&self) -> bool {
501 self.ended
502 }
503}
504
/// Sending side, delegating each `Sink` step to the inherent helper of the
/// same name.
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        Pin::into_inner(self).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_close(cx)
    }
}
528
529#[cfg(not(feature = "futures-03-sink"))]
530impl<S> bytes::private::SealedSender for WebSocketStream<S>
531where
532 S: AsyncRead + AsyncWrite + Unpin,
533{
534 fn poll_write(
535 self: Pin<&mut Self>,
536 cx: &mut Context<'_>,
537 buf: &[u8],
538 ) -> Poll<Result<usize, WsError>> {
539 let me = self.get_mut();
540 ready!(me.poll_ready(cx))?;
541 let len = buf.len();
542 me.start_send(Message::binary(buf.to_owned()))?;
543 Poll::Ready(Ok(len))
544 }
545
546 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
547 self.get_mut().poll_flush(cx)
548 }
549
550 fn poll_close(
551 self: Pin<&mut Self>,
552 cx: &mut Context<'_>,
553 msg: &mut Option<Message>,
554 ) -> Poll<Result<(), WsError>> {
555 let me = self.get_mut();
556 send_helper(me, msg, cx)
557 }
558}
559
560impl<S> WebSocketStream<S> {
561 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
563 where
564 S: AsyncRead + AsyncWrite + Unpin,
565 {
566 Send {
567 ws: self,
568 msg: Some(msg),
569 }
570 .await
571 }
572}
573
574struct Send<W> {
575 ws: W,
576 msg: Option<Message>,
577}
578
579fn send_helper<S>(
581 ws: &mut WebSocketStream<S>,
582 msg: &mut Option<Message>,
583 cx: &mut Context<'_>,
584) -> Poll<Result<(), WsError>>
585where
586 S: AsyncRead + AsyncWrite + Unpin,
587{
588 if msg.is_some() {
589 ready!(ws.poll_ready(cx))?;
590 let msg = msg.take().expect("unreachable");
591 ws.start_send(msg)?;
592 }
593
594 ws.poll_flush(cx)
595}
596
597impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
598where
599 S: AsyncRead + AsyncWrite + Unpin,
600{
601 type Output = Result<(), WsError>;
602
603 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
604 let me = self.get_mut();
605 send_helper(me.ws, &mut me.msg, cx)
606 }
607}
608
609impl<S> std::future::Future for Send<&Shared<S>>
610where
611 S: AsyncRead + AsyncWrite + Unpin,
612{
613 type Output = Result<(), WsError>;
614
615 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
616 let me = self.get_mut();
617 let mut ws = me.ws.lock();
618 send_helper(&mut ws, &mut me.msg, cx)
619 }
620}
621
622#[derive(Debug)]
624pub struct WebSocketSender<S> {
625 shared: Arc<Shared<S>>,
626}
627
628impl<S> WebSocketSender<S> {
629 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
631 where
632 S: AsyncRead + AsyncWrite + Unpin,
633 {
634 Send {
635 ws: &*self.shared,
636 msg: Some(msg),
637 }
638 .await
639 }
640
641 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
643 where
644 S: AsyncRead + AsyncWrite + Unpin,
645 {
646 self.send(Message::Close(msg)).await
647 }
648
649 pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
652 Arc::ptr_eq(&self.shared, &other.shared)
653 }
654}
655
/// `Sink` implementation for the sending half. Each call locks the shared
/// stream only for its own duration.
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
679
680#[cfg(not(feature = "futures-03-sink"))]
681impl<S> bytes::private::SealedSender for WebSocketSender<S>
682where
683 S: AsyncRead + AsyncWrite + Unpin,
684{
685 fn poll_write(
686 self: Pin<&mut Self>,
687 cx: &mut Context<'_>,
688 buf: &[u8],
689 ) -> Poll<Result<usize, WsError>> {
690 let me = self.get_mut();
691 let mut ws = me.shared.lock();
692 ready!(ws.poll_ready(cx))?;
693 let len = buf.len();
694 ws.start_send(Message::binary(buf.to_owned()))?;
695 Poll::Ready(Ok(len))
696 }
697
698 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
699 self.shared.lock().poll_flush(cx)
700 }
701
702 fn poll_close(
703 self: Pin<&mut Self>,
704 cx: &mut Context<'_>,
705 msg: &mut Option<Message>,
706 ) -> Poll<Result<(), WsError>> {
707 let me = self.get_mut();
708 let mut ws = me.shared.lock();
709 send_helper(&mut ws, msg, cx)
710 }
711}
712
713#[derive(Debug)]
715pub struct WebSocketReceiver<S> {
716 shared: Arc<Shared<S>>,
717}
718
719impl<S> WebSocketReceiver<S> {
720 pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
723 Arc::ptr_eq(&self.shared, &other.shared)
724 }
725}
726
727impl<S> Stream for WebSocketReceiver<S>
728where
729 S: AsyncRead + AsyncWrite + Unpin,
730{
731 type Item = Result<Message, WsError>;
732
733 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
734 self.shared.lock().poll_next(cx)
735 }
736}
737
738impl<S> FusedStream for WebSocketReceiver<S>
739where
740 S: AsyncRead + AsyncWrite + Unpin,
741{
742 fn is_terminated(&self) -> bool {
743 self.shared.lock().ended
744 }
745}
746
747#[derive(Debug)]
748struct Shared<S>(Mutex<WebSocketStream<S>>);
749
750impl<S> Shared<S> {
751 fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
752 self.0.lock().expect("lock shared stream")
753 }
754
755 fn into_inner(self) -> WebSocketStream<S> {
756 self.0.into_inner().expect("get shared stream")
757 }
758}
759
/// Extracts the host of `request`'s URI as an owned string.
///
/// IPv6 literals appear in URIs wrapped in brackets (`ws://[::1]:80`), and the
/// URI's `host()` keeps them. The brackets are removed here, so the result is
/// the bare address (for example `::1`).
///
/// # Errors
///
/// Returns `UrlError::NoHostName` if the URI has no host.
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or_else(|| {
        tungstenite::Error::Url(tungstenite::error::UrlError::NoHostName)
    })?;

    // Strip only a matched `[` ... `]` pair. Unlike slicing `[1..len - 1]`,
    // this can never panic or drop an unrelated last character if a
    // malformed host ever gets past URI parsing.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
793
/// Determines the port to connect to for `request`.
///
/// An explicit port in the URI wins. Otherwise the scheme's default is used:
/// 443 for `wss`, 80 for `ws`.
///
/// # Errors
///
/// Returns `UrlError::UnsupportedUrlScheme` when there is no explicit port
/// and the scheme is neither `ws` nor `wss`.
#[cfg(any(
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
817
#[cfg(test)]
mod tests {
    /// IPv6 literal hosts come back without their URI brackets.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "::1");
    }

    /// Ordinary host names are returned unchanged.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_keeps_plain_host_names() {
        use tungstenite::client::IntoClientRequest;

        let request = "wss://example.com/chat".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "example.com");
    }

    /// Scheme defaults apply only when the URI has no explicit port.
    #[cfg(any(
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn port_uses_scheme_default_or_explicit_value() {
        use tungstenite::client::IntoClientRequest;

        let plain = "ws://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&plain).unwrap(), 80);

        let secure = "wss://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&secure).unwrap(), 443);

        let explicit = "ws://example.com:9001".into_client_request().unwrap();
        assert_eq!(crate::port(&explicit).unwrap(), 9001);
    }

    /// Malformed bracketed hosts are rejected before they reach `domain`.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        assert!("ws://[".into_client_request().is_err());
        assert!("ws://[blabla/bla".into_client_request().is_err());
        assert!("ws://[::1/bla".into_client_request().is_err());
    }
}
844}