1#![deny(
42 missing_docs,
43 unused_must_use,
44 unused_mut,
45 unused_imports,
46 unused_import_braces
47)]
48
49pub use tungstenite;
50
51mod compat;
52mod handshake;
53
54#[cfg(any(
55 feature = "async-tls",
56 feature = "async-native-tls",
57 feature = "tokio-native-tls",
58 feature = "tokio-rustls-manual-roots",
59 feature = "tokio-rustls-native-certs",
60 feature = "tokio-rustls-platform-verifier",
61 feature = "tokio-rustls-webpki-roots",
62 feature = "tokio-openssl",
63))]
64pub mod stream;
65
66use std::{
67 io::{Read, Write},
68 pin::Pin,
69 sync::{Arc, Mutex, MutexGuard},
70 task::{ready, Context, Poll},
71};
72
73use compat::{cvt, AllowStd, ContextWaker};
74use futures_core::stream::{FusedStream, Stream};
75use futures_io::{AsyncRead, AsyncWrite};
76use log::*;
77
78#[cfg(feature = "handshake")]
79use tungstenite::{
80 client::IntoClientRequest,
81 handshake::{
82 client::{ClientHandshake, Response},
83 server::{Callback, NoCallback},
84 HandshakeError,
85 },
86};
87use tungstenite::{
88 error::Error as WsError,
89 protocol::{Message, Role, WebSocket, WebSocketConfig},
90};
91
92#[cfg(feature = "async-std-runtime")]
93pub mod async_std;
94#[cfg(feature = "async-tls")]
95pub mod async_tls;
96#[cfg(feature = "gio-runtime")]
97pub mod gio;
98#[cfg(feature = "tokio-runtime")]
99pub mod tokio;
100
101pub mod bytes;
102pub use bytes::ByteReader;
103pub use bytes::ByteWriter;
104
105use tungstenite::protocol::CloseFrame;
106
#[cfg(feature = "handshake")]
/// Creates a WebSocket handshake from a request and a stream.
///
/// Runs the client side of the WebSocket handshake over `stream`. This function does
/// not open any connection itself; `stream` is used exactly as given. It behaves the
/// same as [`client_async_with_config`] called with a `None` configuration.
///
/// On success, returns the established [`WebSocketStream`] together with the handshake
/// `Response`.
///
/// # Errors
///
/// Returns a `WsError` if the request is invalid or the handshake fails; see
/// [`client_async_with_config`] for how handshake errors are mapped.
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, default_config).await
}
130
#[cfg(feature = "handshake")]
/// The same as [`client_async`] but the one can specify a websocket configuration.
///
/// `request` is converted with `IntoClientRequest` and the handshake is started through
/// tungstenite's `ClientHandshake::start`, with `config` forwarded to it unchanged.
///
/// # Errors
///
/// A `HandshakeError::Failure` is returned as its inner `WsError`. Any other handshake
/// error is converted into `WsError::Io` with `ErrorKind::Other`, carrying the error's
/// `Display` text.
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let pending = handshake::client_handshake(stream, move |io| {
        let req = request.into_client_request()?;
        ClientHandshake::start(io, req, config)?.handshake()
    });

    match pending.await {
        Ok(established) => Ok(established),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
156
#[cfg(feature = "handshake")]
/// Accepts a new WebSocket connection with the provided stream.
///
/// Runs the server side of the handshake over `stream`, which is used as-is. This
/// function does not listen for or accept sockets itself. No header callback is
/// installed and no custom `WebSocketConfig` is passed.
///
/// # Errors
///
/// Returns a `WsError` if the handshake fails; see [`accept_hdr_async_with_config`]
/// for how handshake errors are mapped.
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Equivalent to `accept_hdr_async(stream, NoCallback)`: both end up calling
    // `accept_hdr_async_with_config(stream, NoCallback, None)`.
    accept_async_with_config(stream, None).await
}
175
#[cfg(feature = "handshake")]
/// The same as [`accept_async`] but the one can specify a websocket configuration.
///
/// `config` is forwarded to [`accept_hdr_async_with_config`] together with a
/// `NoCallback` header callback.
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_callback = NoCallback;
    accept_hdr_async_with_config(stream, no_callback, config).await
}
188
#[cfg(feature = "handshake")]
/// Accepts a new WebSocket connection with the provided stream, running `callback`
/// during the handshake.
///
/// `callback` is forwarded to tungstenite's `accept_hdr_with_config`. See
/// `tungstenite::handshake::server::Callback` for what it is allowed to inspect or
/// change. No custom `WebSocketConfig` is passed.
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, default_config).await
}
202
#[cfg(feature = "handshake")]
/// The same as [`accept_hdr_async`] but the one can specify a websocket configuration.
///
/// # Errors
///
/// A `HandshakeError::Failure` is returned as its inner `WsError`. Any other handshake
/// error is converted into `WsError::Io` with `ErrorKind::Other`, carrying the error's
/// `Display` text.
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let pending = handshake::server_handshake(stream, move |io| {
        tungstenite::accept_hdr_with_config(io, callback, config)
    });

    match pending.await {
        Ok(ws) => Ok(ws),
        Err(HandshakeError::Failure(err)) => Err(err),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
226
#[derive(Debug)]
/// A wrapper around an underlying raw stream which implements the WebSocket
/// protocol.
///
/// Incoming messages are read through its `Stream` implementation. Outgoing messages
/// are sent with [`WebSocketStream::send`], or through `futures_util::Sink<Message>`
/// when the `futures-03-sink` feature is enabled. Use [`WebSocketStream::split`] to
/// obtain independent sender and receiver halves.
pub struct WebSocketStream<S> {
    // The tungstenite protocol state, driving `S` through the `AllowStd` adapter.
    inner: WebSocket<AllowStd<S>>,
    // Set by `poll_close` after a close attempt returned `WouldBlock`. Later polls then
    // only call `flush` instead of calling `close(None)` again.
    #[cfg(feature = "futures-03-sink")]
    closing: bool,
    // Set by `poll_next` on the first read error. From then on the stream yields `None`.
    ended: bool,
    // Cleared by `start_send` when a write returned `WouldBlock`. While false,
    // `poll_ready` must flush before another message can be sent.
    ready: bool,
}
248
249impl<S> WebSocketStream<S> {
250 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
253 where
254 S: AsyncRead + AsyncWrite + Unpin,
255 {
256 handshake::without_handshake(stream, move |allow_std| {
257 WebSocket::from_raw_socket(allow_std, role, config)
258 })
259 .await
260 }
261
262 pub async fn from_partially_read(
265 stream: S,
266 part: Vec<u8>,
267 role: Role,
268 config: Option<WebSocketConfig>,
269 ) -> Self
270 where
271 S: AsyncRead + AsyncWrite + Unpin,
272 {
273 handshake::without_handshake(stream, move |allow_std| {
274 WebSocket::from_partially_read(allow_std, part, role, config)
275 })
276 .await
277 }
278
279 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
280 Self {
281 inner: ws,
282 #[cfg(feature = "futures-03-sink")]
283 closing: false,
284 ended: false,
285 ready: true,
286 }
287 }
288
289 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
290 where
291 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
292 AllowStd<S>: Read + Write,
293 {
294 #[cfg(feature = "verbose-logging")]
295 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
296 if let Some((kind, ctx)) = ctx {
297 self.inner.get_mut().set_waker(kind, ctx.waker());
298 }
299 f(&mut self.inner)
300 }
301
302 pub fn into_inner(self) -> S {
304 self.inner.into_inner().into_inner()
305 }
306
307 pub fn get_ref(&self) -> &S
309 where
310 S: AsyncRead + AsyncWrite + Unpin,
311 {
312 self.inner.get_ref().get_ref()
313 }
314
315 pub fn get_mut(&mut self) -> &mut S
317 where
318 S: AsyncRead + AsyncWrite + Unpin,
319 {
320 self.inner.get_mut().get_mut()
321 }
322
323 pub fn get_config(&self) -> &WebSocketConfig {
325 self.inner.get_config()
326 }
327
328 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
330 where
331 S: AsyncRead + AsyncWrite + Unpin,
332 {
333 self.send(Message::Close(msg)).await
334 }
335
336 pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
339 let shared = Arc::new(Shared(Mutex::new(self)));
340 let sender = WebSocketSender {
341 shared: shared.clone(),
342 };
343
344 let receiver = WebSocketReceiver { shared };
345 (sender, receiver)
346 }
347
348 pub fn reunite(
353 sender: WebSocketSender<S>,
354 receiver: WebSocketReceiver<S>,
355 ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
356 if sender.is_pair_of(&receiver) {
357 drop(receiver);
358 let stream = Arc::try_unwrap(sender.shared)
359 .ok()
360 .expect("reunite the stream")
361 .into_inner();
362
363 Ok(stream)
364 } else {
365 Err((sender, receiver))
366 }
367 }
368}
369
370impl<S> WebSocketStream<S>
371where
372 S: AsyncRead + AsyncWrite + Unpin,
373{
374 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
375 #[cfg(feature = "verbose-logging")]
376 trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
377
378 if self.ended {
382 return Poll::Ready(None);
383 }
384
385 match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
386 #[cfg(feature = "verbose-logging")]
387 trace!(
388 "{}:{} WebSocketStream.with_context poll_next -> read()",
389 file!(),
390 line!()
391 );
392 cvt(s.read())
393 })) {
394 Ok(v) => Poll::Ready(Some(Ok(v))),
395 Err(e) => {
396 self.ended = true;
397 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
398 Poll::Ready(None)
399 } else {
400 Poll::Ready(Some(Err(e)))
401 }
402 }
403 }
404 }
405
406 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
407 if self.ready {
408 return Poll::Ready(Ok(()));
409 }
410
411 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
413 .map(|r| {
414 self.ready = true;
415 r
416 })
417 }
418
419 fn start_send(&mut self, item: Message) -> Result<(), WsError> {
420 match self.with_context(None, |s| s.write(item)) {
421 Ok(()) => {
422 self.ready = true;
423 Ok(())
424 }
425 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
426 self.ready = false;
429 Ok(())
430 }
431 Err(e) => {
432 self.ready = true;
433 debug!("websocket start_send error: {}", e);
434 Err(e)
435 }
436 }
437 }
438
439 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
440 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
441 .map(|r| {
442 self.ready = true;
443 match r {
444 Err(WsError::ConnectionClosed) => Ok(()),
446 other => other,
447 }
448 })
449 }
450
451 #[cfg(feature = "futures-03-sink")]
452 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
453 self.ready = true;
454 let res = if self.closing {
455 self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
457 } else {
458 self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
459 };
460
461 match res {
462 Ok(()) => Poll::Ready(Ok(())),
463 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
464 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
465 trace!("WouldBlock");
466 self.closing = true;
467 Poll::Pending
468 }
469 Err(err) => {
470 debug!("websocket close error: {}", err);
471 Poll::Ready(Err(err))
472 }
473 }
474 }
475}
476
477impl<S> Stream for WebSocketStream<S>
478where
479 S: AsyncRead + AsyncWrite + Unpin,
480{
481 type Item = Result<Message, WsError>;
482
483 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
484 self.get_mut().poll_next(cx)
485 }
486}
487
488impl<S> FusedStream for WebSocketStream<S>
489where
490 S: AsyncRead + AsyncWrite + Unpin,
491{
492 fn is_terminated(&self) -> bool {
493 self.ended
494 }
495}
496
#[cfg(feature = "futures-03-sink")]
/// `Sink` adapter: each method unpins `self` and forwards to the inherent method of
/// the same name.
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        Pin::into_inner(self).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_close(cx)
    }
}
520
521#[cfg(not(feature = "futures-03-sink"))]
522impl<S> bytes::private::SealedSender for WebSocketStream<S>
523where
524 S: AsyncRead + AsyncWrite + Unpin,
525{
526 fn poll_write(
527 self: Pin<&mut Self>,
528 cx: &mut Context<'_>,
529 buf: &[u8],
530 ) -> Poll<Result<usize, WsError>> {
531 let me = self.get_mut();
532 ready!(me.poll_ready(cx))?;
533 let len = buf.len();
534 me.start_send(Message::binary(buf.to_owned()))?;
535 Poll::Ready(Ok(len))
536 }
537
538 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
539 self.get_mut().poll_flush(cx)
540 }
541
542 fn poll_close(
543 self: Pin<&mut Self>,
544 cx: &mut Context<'_>,
545 msg: &mut Option<Message>,
546 ) -> Poll<Result<(), WsError>> {
547 let me = self.get_mut();
548 send_helper(me, msg, cx)
549 }
550}
551
552impl<S> WebSocketStream<S> {
553 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
555 where
556 S: AsyncRead + AsyncWrite + Unpin,
557 {
558 Send {
559 ws: self,
560 msg: Some(msg),
561 }
562 .await
563 }
564}
565
/// Future behind the `send` methods: it hands `msg` to `ws` once, then flushes.
///
/// NOTE(review): this type's name shadows the prelude `Send` marker trait inside this
/// module. Spell it `std::marker::Send` if that trait is ever needed here.
struct Send<W> {
    // Target stream. The `Future` impls cover `&mut WebSocketStream<S>` and `&Shared<S>`.
    ws: W,
    // Message still waiting for `start_send`. `send_helper` takes it, leaving `None`.
    msg: Option<Message>,
}
570
571fn send_helper<S>(
573 ws: &mut WebSocketStream<S>,
574 msg: &mut Option<Message>,
575 cx: &mut Context<'_>,
576) -> Poll<Result<(), WsError>>
577where
578 S: AsyncRead + AsyncWrite + Unpin,
579{
580 if msg.is_some() {
581 ready!(ws.poll_ready(cx))?;
582 let msg = msg.take().expect("unreachable");
583 ws.start_send(msg)?;
584 }
585
586 ws.poll_flush(cx)
587}
588
589impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
590where
591 S: AsyncRead + AsyncWrite + Unpin,
592{
593 type Output = Result<(), WsError>;
594
595 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
596 let me = self.get_mut();
597 send_helper(me.ws, &mut me.msg, cx)
598 }
599}
600
601impl<S> std::future::Future for Send<&Shared<S>>
602where
603 S: AsyncRead + AsyncWrite + Unpin,
604{
605 type Output = Result<(), WsError>;
606
607 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
608 let me = self.get_mut();
609 let mut ws = me.ws.lock();
610 send_helper(&mut ws, &mut me.msg, cx)
611 }
612}
613
#[derive(Debug)]
/// The sender half of a [`WebSocketStream`], created by [`WebSocketStream::split`].
pub struct WebSocketSender<S> {
    // The stream shared with the matching `WebSocketReceiver`.
    shared: Arc<Shared<S>>,
}
619
620impl<S> WebSocketSender<S> {
621 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
623 where
624 S: AsyncRead + AsyncWrite + Unpin,
625 {
626 Send {
627 ws: &*self.shared,
628 msg: Some(msg),
629 }
630 .await
631 }
632
633 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
635 where
636 S: AsyncRead + AsyncWrite + Unpin,
637 {
638 self.send(Message::Close(msg)).await
639 }
640
641 pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
644 Arc::ptr_eq(&self.shared, &other.shared)
645 }
646}
647
#[cfg(feature = "futures-03-sink")]
/// `Sink` adapter: each method locks the shared stream for that one call and forwards
/// to the stream's inherent method of the same name.
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
671
672#[cfg(not(feature = "futures-03-sink"))]
673impl<S> bytes::private::SealedSender for WebSocketSender<S>
674where
675 S: AsyncRead + AsyncWrite + Unpin,
676{
677 fn poll_write(
678 self: Pin<&mut Self>,
679 cx: &mut Context<'_>,
680 buf: &[u8],
681 ) -> Poll<Result<usize, WsError>> {
682 let me = self.get_mut();
683 let mut ws = me.shared.lock();
684 ready!(ws.poll_ready(cx))?;
685 let len = buf.len();
686 ws.start_send(Message::binary(buf.to_owned()))?;
687 Poll::Ready(Ok(len))
688 }
689
690 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
691 self.shared.lock().poll_flush(cx)
692 }
693
694 fn poll_close(
695 self: Pin<&mut Self>,
696 cx: &mut Context<'_>,
697 msg: &mut Option<Message>,
698 ) -> Poll<Result<(), WsError>> {
699 let me = self.get_mut();
700 let mut ws = me.shared.lock();
701 send_helper(&mut ws, msg, cx)
702 }
703}
704
#[derive(Debug)]
/// The receiver half of a [`WebSocketStream`], created by [`WebSocketStream::split`].
pub struct WebSocketReceiver<S> {
    // The stream shared with the matching `WebSocketSender`.
    shared: Arc<Shared<S>>,
}
710
711impl<S> WebSocketReceiver<S> {
712 pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
715 Arc::ptr_eq(&self.shared, &other.shared)
716 }
717}
718
719impl<S> Stream for WebSocketReceiver<S>
720where
721 S: AsyncRead + AsyncWrite + Unpin,
722{
723 type Item = Result<Message, WsError>;
724
725 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
726 self.shared.lock().poll_next(cx)
727 }
728}
729
730impl<S> FusedStream for WebSocketReceiver<S>
731where
732 S: AsyncRead + AsyncWrite + Unpin,
733{
734 fn is_terminated(&self) -> bool {
735 self.shared.lock().ended
736 }
737}
738
#[derive(Debug)]
/// A `WebSocketStream` behind a `Mutex`, shared by the sender and receiver halves
/// created by `WebSocketStream::split`.
struct Shared<S>(Mutex<WebSocketStream<S>>);
741
742impl<S> Shared<S> {
743 fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
744 self.0.lock().expect("lock shared stream")
745 }
746
747 fn into_inner(self) -> WebSocketStream<S> {
748 self.0.into_inner().expect("get shared stream")
749 }
750}
751
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
/// Returns the host of `request`'s URI as an owned string.
///
/// The URI host of an IPv6 literal keeps its brackets (`[::1]`), so a balanced pair of
/// surrounding brackets is stripped (`[::1]` -> `::1`). Any other host is returned
/// unchanged.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::NoHostName)` if the URI has no host.
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request
        .uri()
        .host()
        .ok_or(tungstenite::Error::Url(
            tungstenite::error::UrlError::NoHostName,
        ))?;

    // Strip only a matching `[`...`]` pair. The previous `&host[1..host.len() - 1]`
    // slice panicked on a lone `[` and dropped a real character when `]` was missing.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
784
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
/// Returns the port to connect to for `request`.
///
/// Uses the explicit port from the URI when present. Otherwise it falls back to 443
/// for the `wss` scheme and 80 for `ws`.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::UnsupportedUrlScheme)` when there is no
/// explicit port and the scheme is neither `ws` nor `wss`.
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }
    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
807
#[cfg(test)]
mod tests {
    // The `use` lines stay inside each cfg-gated test so that
    // `#![deny(unused_imports)]` is not tripped when a test is compiled out.

    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let req = "ws://[::1]:80".into_client_request().unwrap();
        let host = crate::domain(&req).unwrap();
        assert_eq!(host, "::1");
    }

    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        let malformed = ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"];
        for uri in malformed {
            assert!(uri.into_client_request().is_err());
        }
    }
}