1#![deny(
42 missing_docs,
43 unused_must_use,
44 unused_mut,
45 unused_imports,
46 unused_import_braces
47)]
48
49pub use tungstenite;
50
51mod compat;
52mod handshake;
53
54#[cfg(any(
55 feature = "async-tls",
56 feature = "async-native-tls",
57 feature = "tokio-native-tls",
58 feature = "tokio-rustls-manual-roots",
59 feature = "tokio-rustls-native-certs",
60 feature = "tokio-rustls-platform-verifier",
61 feature = "tokio-rustls-webpki-roots",
62 feature = "tokio-openssl",
63))]
64pub mod stream;
65
66use std::{
67 io::{Read, Write},
68 pin::Pin,
69 sync::{Arc, Mutex, MutexGuard},
70 task::{ready, Context, Poll},
71};
72
73use compat::{cvt, AllowStd, ContextWaker};
74use futures_core::stream::{FusedStream, Stream};
75use futures_io::{AsyncRead, AsyncWrite};
76use log::*;
77
78#[cfg(feature = "handshake")]
79use tungstenite::{
80 client::IntoClientRequest,
81 handshake::{
82 client::{ClientHandshake, Response},
83 server::{Callback, NoCallback},
84 HandshakeError,
85 },
86};
87use tungstenite::{
88 error::Error as WsError,
89 protocol::{Message, Role, WebSocket, WebSocketConfig},
90};
91
92#[cfg(feature = "async-std-runtime")]
93pub mod async_std;
94#[cfg(feature = "async-tls")]
95pub mod async_tls;
96#[cfg(feature = "gio-runtime")]
97pub mod gio;
98#[cfg(feature = "smol-runtime")]
99pub mod smol;
100#[cfg(feature = "tokio-runtime")]
101pub mod tokio;
102
103pub mod bytes;
104pub use bytes::ByteReader;
105pub use bytes::ByteWriter;
106
107use tungstenite::protocol::CloseFrame;
108
/// Performs the client side of a WebSocket handshake over an already
/// connected `stream`.
///
/// This is shorthand for [`client_async_with_config`] with `config` set to
/// `None`. On success it yields the established [`WebSocketStream`] together
/// with the server's handshake [`Response`].
///
/// # Errors
///
/// Returns the error produced while building the request or while running the
/// handshake; see [`client_async_with_config`] for how handshake errors are
/// mapped.
#[cfg(feature = "handshake")]
pub async fn client_async<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    client_async_with_config(request, stream, None).await
}
132
/// Performs the client side of a WebSocket handshake over `stream`, applying
/// the given protocol `config`.
///
/// `request` is converted with [`IntoClientRequest::into_client_request`]
/// inside the handshake closure. The handshake is then started and driven by
/// [`ClientHandshake`]. On success the established [`WebSocketStream`] is
/// returned together with the server's [`Response`].
///
/// # Errors
///
/// A [`HandshakeError::Failure`] is unwrapped and returned as-is. Any other
/// handshake error is converted into a [`WsError::Io`] whose
/// [`std::io::ErrorKind`] is `Other` and whose message is the original error's
/// `Display` text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let f = handshake::client_handshake(stream, move |allow_std| {
        // `?` lifts request-building / start errors into the handshake error type.
        let request = request.into_client_request()?;
        let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
        cli_handshake.handshake()
    });
    f.await.map_err(|e| match e {
        HandshakeError::Failure(e) => e,
        e => WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            e.to_string(),
        )),
    })
}
158
/// Accepts a new WebSocket connection on `stream` by running the server side
/// of the handshake.
///
/// No request callback is installed ([`NoCallback`]) and no explicit
/// [`WebSocketConfig`] is supplied (`None`).
///
/// # Errors
///
/// See [`accept_hdr_async_with_config`], which this delegates to.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    accept_hdr_async_with_config(stream, NoCallback, None).await
}
177
/// Like [`accept_async`], but applies the supplied `config` to the accepted
/// WebSocket.
///
/// # Errors
///
/// See [`accept_hdr_async_with_config`], which this delegates to with
/// [`NoCallback`].
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_callback = NoCallback;
    accept_hdr_async_with_config::<S, NoCallback>(stream, no_callback, config).await
}
190
/// Accepts a new WebSocket connection on `stream`, passing `callback` to the
/// server handshake (see tungstenite's [`Callback`] trait).
///
/// This is [`accept_hdr_async_with_config`] with `config` set to `None`.
///
/// # Errors
///
/// See [`accept_hdr_async_with_config`].
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, config).await
}
204
/// Accepts a new WebSocket connection on `stream`, running the server
/// handshake through [`tungstenite::accept_hdr_with_config`] with the given
/// `callback` and `config`.
///
/// # Errors
///
/// A [`HandshakeError::Failure`] is unwrapped and returned as-is. Any other
/// handshake error is converted into a [`WsError::Io`] whose
/// [`std::io::ErrorKind`] is `Other` and whose message is the original error's
/// `Display` text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let handshake = handshake::server_handshake(stream, move |wrapped| {
        tungstenite::accept_hdr_with_config(wrapped, callback, config)
    });

    match handshake.await {
        Ok(ws) => Ok(ws),
        Err(HandshakeError::Failure(failure)) => Err(failure),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
228
#[derive(Debug)]
/// An asynchronous WebSocket stream layered over an underlying I/O stream `S`.
///
/// It implements [`Stream`] to yield incoming [`Message`]s. Outgoing messages
/// go through the `futures` `Sink` implementation when the `futures-03-sink`
/// feature is enabled, or through the crate's byte-sender plumbing otherwise,
/// as well as the inherent `send`/`close` methods.
pub struct WebSocketStream<S> {
    /// The synchronous tungstenite WebSocket, driven through the `AllowStd`
    /// adapter.
    inner: WebSocket<AllowStd<S>>,
    /// Set once `poll_close` has hit `WouldBlock`; later polls then only flush
    /// instead of initiating the close again.
    #[cfg(feature = "futures-03-sink")]
    closing: bool,
    /// Set after the first read error; `poll_next` then always yields `None`.
    ended: bool,
    /// `false` while a previous write hit `WouldBlock` and still needs a flush
    /// before the next message may be queued.
    ready: bool,
}
250
251impl<S> WebSocketStream<S> {
252 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
255 where
256 S: AsyncRead + AsyncWrite + Unpin,
257 {
258 handshake::without_handshake(stream, move |allow_std| {
259 WebSocket::from_raw_socket(allow_std, role, config)
260 })
261 .await
262 }
263
264 pub async fn from_partially_read(
267 stream: S,
268 part: Vec<u8>,
269 role: Role,
270 config: Option<WebSocketConfig>,
271 ) -> Self
272 where
273 S: AsyncRead + AsyncWrite + Unpin,
274 {
275 handshake::without_handshake(stream, move |allow_std| {
276 WebSocket::from_partially_read(allow_std, part, role, config)
277 })
278 .await
279 }
280
281 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
282 Self {
283 inner: ws,
284 #[cfg(feature = "futures-03-sink")]
285 closing: false,
286 ended: false,
287 ready: true,
288 }
289 }
290
291 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
292 where
293 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
294 AllowStd<S>: Read + Write,
295 {
296 #[cfg(feature = "verbose-logging")]
297 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
298 if let Some((kind, ctx)) = ctx {
299 self.inner.get_mut().set_waker(kind, ctx.waker());
300 }
301 f(&mut self.inner)
302 }
303
304 pub fn into_inner(self) -> S {
306 self.inner.into_inner().into_inner()
307 }
308
309 pub fn get_ref(&self) -> &S
311 where
312 S: AsyncRead + AsyncWrite + Unpin,
313 {
314 self.inner.get_ref().get_ref()
315 }
316
317 pub fn get_mut(&mut self) -> &mut S
319 where
320 S: AsyncRead + AsyncWrite + Unpin,
321 {
322 self.inner.get_mut().get_mut()
323 }
324
325 pub fn get_config(&self) -> &WebSocketConfig {
327 self.inner.get_config()
328 }
329
330 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
332 where
333 S: AsyncRead + AsyncWrite + Unpin,
334 {
335 self.send(Message::Close(msg)).await
336 }
337
338 pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
341 let shared = Arc::new(Shared(Mutex::new(self)));
342 let sender = WebSocketSender {
343 shared: shared.clone(),
344 };
345
346 let receiver = WebSocketReceiver { shared };
347 (sender, receiver)
348 }
349
350 pub fn reunite(
355 sender: WebSocketSender<S>,
356 receiver: WebSocketReceiver<S>,
357 ) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
358 if sender.is_pair_of(&receiver) {
359 drop(receiver);
360 let stream = Arc::try_unwrap(sender.shared)
361 .ok()
362 .expect("reunite the stream")
363 .into_inner();
364
365 Ok(stream)
366 } else {
367 Err((sender, receiver))
368 }
369 }
370}
371
372impl<S> WebSocketStream<S>
373where
374 S: AsyncRead + AsyncWrite + Unpin,
375{
376 fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
377 #[cfg(feature = "verbose-logging")]
378 trace!("{}:{} WebSocketStream.poll_next", file!(), line!());
379
380 if self.ended {
384 return Poll::Ready(None);
385 }
386
387 match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
388 #[cfg(feature = "verbose-logging")]
389 trace!(
390 "{}:{} WebSocketStream.with_context poll_next -> read()",
391 file!(),
392 line!()
393 );
394 cvt(s.read())
395 })) {
396 Ok(v) => Poll::Ready(Some(Ok(v))),
397 Err(e) => {
398 self.ended = true;
399 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
400 Poll::Ready(None)
401 } else {
402 Poll::Ready(Some(Err(e)))
403 }
404 }
405 }
406 }
407
408 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
409 if self.ready {
410 return Poll::Ready(Ok(()));
411 }
412
413 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
415 .map(|r| {
416 self.ready = true;
417 r
418 })
419 }
420
421 fn start_send(&mut self, item: Message) -> Result<(), WsError> {
422 match self.with_context(None, |s| s.write(item)) {
423 Ok(()) => {
424 self.ready = true;
425 Ok(())
426 }
427 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
428 self.ready = false;
431 Ok(())
432 }
433 Err(e) => {
434 self.ready = true;
435 debug!("websocket start_send error: {}", e);
436 Err(e)
437 }
438 }
439 }
440
441 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
442 self.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
443 .map(|r| {
444 self.ready = true;
445 match r {
446 Err(WsError::ConnectionClosed) => Ok(()),
448 other => other,
449 }
450 })
451 }
452
453 #[cfg(feature = "futures-03-sink")]
454 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
455 self.ready = true;
456 let res = if self.closing {
457 self.with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
459 } else {
460 self.with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
461 };
462
463 match res {
464 Ok(()) => Poll::Ready(Ok(())),
465 Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
466 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
467 trace!("WouldBlock");
468 self.closing = true;
469 Poll::Pending
470 }
471 Err(err) => {
472 debug!("websocket close error: {}", err);
473 Poll::Ready(Err(err))
474 }
475 }
476 }
477}
478
479impl<S> Stream for WebSocketStream<S>
480where
481 S: AsyncRead + AsyncWrite + Unpin,
482{
483 type Item = Result<Message, WsError>;
484
485 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
486 self.get_mut().poll_next(cx)
487 }
488}
489
490impl<S> FusedStream for WebSocketStream<S>
491where
492 S: AsyncRead + AsyncWrite + Unpin,
493{
494 fn is_terminated(&self) -> bool {
495 self.ended
496 }
497}
498
/// `futures` sink for outgoing messages; each method unpins `self` and
/// forwards to the inherent method of the same name.
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let this = Pin::get_mut(self);
        this.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        this.poll_close(cx)
    }
}
522
523#[cfg(not(feature = "futures-03-sink"))]
524impl<S> bytes::private::SealedSender for WebSocketStream<S>
525where
526 S: AsyncRead + AsyncWrite + Unpin,
527{
528 fn poll_write(
529 self: Pin<&mut Self>,
530 cx: &mut Context<'_>,
531 buf: &[u8],
532 ) -> Poll<Result<usize, WsError>> {
533 let me = self.get_mut();
534 ready!(me.poll_ready(cx))?;
535 let len = buf.len();
536 me.start_send(Message::binary(buf.to_owned()))?;
537 Poll::Ready(Ok(len))
538 }
539
540 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
541 self.get_mut().poll_flush(cx)
542 }
543
544 fn poll_close(
545 self: Pin<&mut Self>,
546 cx: &mut Context<'_>,
547 msg: &mut Option<Message>,
548 ) -> Poll<Result<(), WsError>> {
549 let me = self.get_mut();
550 send_helper(me, msg, cx)
551 }
552}
553
554impl<S> WebSocketStream<S> {
555 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
557 where
558 S: AsyncRead + AsyncWrite + Unpin,
559 {
560 Send {
561 ws: self,
562 msg: Some(msg),
563 }
564 .await
565 }
566}
567
/// Future behind the `send` methods: queues `msg` once, then flushes `ws`.
///
/// Note: this private type shadows the prelude's `Send` marker trait inside
/// this module.
struct Send<W> {
    /// Where to send: `&mut WebSocketStream<S>` or `&Shared<S>`.
    ws: W,
    /// The message still waiting to be queued; `None` once handed to
    /// `start_send`.
    msg: Option<Message>,
}
572
573fn send_helper<S>(
575 ws: &mut WebSocketStream<S>,
576 msg: &mut Option<Message>,
577 cx: &mut Context<'_>,
578) -> Poll<Result<(), WsError>>
579where
580 S: AsyncRead + AsyncWrite + Unpin,
581{
582 if msg.is_some() {
583 ready!(ws.poll_ready(cx))?;
584 let msg = msg.take().expect("unreachable");
585 ws.start_send(msg)?;
586 }
587
588 ws.poll_flush(cx)
589}
590
591impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
592where
593 S: AsyncRead + AsyncWrite + Unpin,
594{
595 type Output = Result<(), WsError>;
596
597 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
598 let me = self.get_mut();
599 send_helper(me.ws, &mut me.msg, cx)
600 }
601}
602
603impl<S> std::future::Future for Send<&Shared<S>>
604where
605 S: AsyncRead + AsyncWrite + Unpin,
606{
607 type Output = Result<(), WsError>;
608
609 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
610 let me = self.get_mut();
611 let mut ws = me.ws.lock();
612 send_helper(&mut ws, &mut me.msg, cx)
613 }
614}
615
#[derive(Debug)]
/// The sending half of a [`WebSocketStream`], created by
/// [`WebSocketStream::split`].
pub struct WebSocketSender<S> {
    /// Stream state shared with the paired [`WebSocketReceiver`].
    shared: Arc<Shared<S>>,
}
621
622impl<S> WebSocketSender<S> {
623 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
625 where
626 S: AsyncRead + AsyncWrite + Unpin,
627 {
628 Send {
629 ws: &*self.shared,
630 msg: Some(msg),
631 }
632 .await
633 }
634
635 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
637 where
638 S: AsyncRead + AsyncWrite + Unpin,
639 {
640 self.send(Message::Close(msg)).await
641 }
642
643 pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
646 Arc::ptr_eq(&self.shared, &other.shared)
647 }
648}
649
/// `futures` sink for the sending half; each method locks the shared stream
/// and forwards to the stream's inherent method of the same name.
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
673
674#[cfg(not(feature = "futures-03-sink"))]
675impl<S> bytes::private::SealedSender for WebSocketSender<S>
676where
677 S: AsyncRead + AsyncWrite + Unpin,
678{
679 fn poll_write(
680 self: Pin<&mut Self>,
681 cx: &mut Context<'_>,
682 buf: &[u8],
683 ) -> Poll<Result<usize, WsError>> {
684 let me = self.get_mut();
685 let mut ws = me.shared.lock();
686 ready!(ws.poll_ready(cx))?;
687 let len = buf.len();
688 ws.start_send(Message::binary(buf.to_owned()))?;
689 Poll::Ready(Ok(len))
690 }
691
692 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
693 self.shared.lock().poll_flush(cx)
694 }
695
696 fn poll_close(
697 self: Pin<&mut Self>,
698 cx: &mut Context<'_>,
699 msg: &mut Option<Message>,
700 ) -> Poll<Result<(), WsError>> {
701 let me = self.get_mut();
702 let mut ws = me.shared.lock();
703 send_helper(&mut ws, msg, cx)
704 }
705}
706
#[derive(Debug)]
/// The receiving half of a [`WebSocketStream`], created by
/// [`WebSocketStream::split`].
pub struct WebSocketReceiver<S> {
    /// Stream state shared with the paired [`WebSocketSender`].
    shared: Arc<Shared<S>>,
}
712
713impl<S> WebSocketReceiver<S> {
714 pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
717 Arc::ptr_eq(&self.shared, &other.shared)
718 }
719}
720
721impl<S> Stream for WebSocketReceiver<S>
722where
723 S: AsyncRead + AsyncWrite + Unpin,
724{
725 type Item = Result<Message, WsError>;
726
727 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
728 self.shared.lock().poll_next(cx)
729 }
730}
731
732impl<S> FusedStream for WebSocketReceiver<S>
733where
734 S: AsyncRead + AsyncWrite + Unpin,
735{
736 fn is_terminated(&self) -> bool {
737 self.shared.lock().ended
738 }
739}
740
#[derive(Debug)]
/// Stream state shared between a [`WebSocketSender`] and its paired
/// [`WebSocketReceiver`], guarded by a mutex.
struct Shared<S>(Mutex<WebSocketStream<S>>);
743
744impl<S> Shared<S> {
745 fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
746 self.0.lock().expect("lock shared stream")
747 }
748
749 fn into_inner(self) -> WebSocketStream<S> {
750 self.0.into_inner().expect("get shared stream")
751 }
752}
753
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
/// Extracts the host of `request`'s URI for use as the connection domain.
///
/// A bracketed IPv6 literal such as `[::1]` is returned without its brackets
/// (`::1`). Brackets are only removed when both the opening and the closing
/// one are present.
///
/// # Errors
///
/// Returns [`tungstenite::Error::Url`] with
/// [`tungstenite::error::UrlError::NoHostName`] when the URI has no host.
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Strip `[` and `]` as a pair. Slicing `[1..len - 1]` would drop an
    // unrelated last character when `]` is missing and panic on a bare `[`.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
787
#[cfg(any(
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
/// Returns the TCP port to connect to for `request`.
///
/// An explicit port in the URI wins. Otherwise the scheme decides: `wss`
/// maps to 443 and `ws` maps to 80.
///
/// # Errors
///
/// Returns [`tungstenite::Error::Url`] with
/// [`tungstenite::error::UrlError::UnsupportedUrlScheme`] when there is no
/// explicit port and the scheme is neither `ws` nor `wss`.
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }
    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
811
#[cfg(test)]
mod tests {
    /// IPv6 literals lose their surrounding brackets.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "::1");
    }

    /// Ordinary host names are returned unchanged.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_keeps_plain_hostnames() {
        use tungstenite::client::IntoClientRequest;

        let request = "wss://example.com/chat".into_client_request().unwrap();
        assert_eq!(crate::domain(&request).unwrap(), "example.com");
    }

    /// Explicit ports win; otherwise `ws` -> 80 and `wss` -> 443.
    #[cfg(any(
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn port_uses_explicit_port_or_scheme_default() {
        use tungstenite::client::IntoClientRequest;

        let plain = "ws://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&plain).unwrap(), 80);

        let secure = "wss://example.com".into_client_request().unwrap();
        assert_eq!(crate::port(&secure).unwrap(), 443);

        let explicit = "ws://example.com:9001".into_client_request().unwrap();
        assert_eq!(crate::port(&explicit).unwrap(), 9001);
    }

    /// Malformed bracketed hosts are rejected before they ever reach `domain`.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        assert!("ws://[".into_client_request().is_err());
        assert!("ws://[blabla/bla".into_client_request().is_err());
        assert!("ws://[::1/bla".into_client_request().is_err());
    }
}