1use std::borrow::ToOwned;
6use std::cell::Cell;
7use std::ptr::{self, NonNull};
8
9use dom_struct::dom_struct;
10use ipc_channel::router::ROUTER;
11use js::jsapi::JSObject;
12use js::jsval::UndefinedValue;
13use js::realm::AutoRealm;
14use js::rust::{CustomAutoRooterGuard, HandleObject};
15use js::typedarray::{ArrayBuffer, ArrayBufferView, CreateWith};
16use net_traits::blob_url_store::UrlWithBlobClaim;
17use net_traits::request::{
18 CacheMode, CredentialsMode, RedirectMode, Referrer, RequestBuilder, RequestMode,
19 ServiceWorkersMode,
20};
21use net_traits::{
22 CoreResourceMsg, FetchChannels, MessageData, WebSocketDomAction, WebSocketNetworkEvent,
23};
24use profile_traits::ipc as ProfiledIpc;
25use script_bindings::cell::DomRefCell;
26use script_bindings::conversions::SafeToJSValConvertible;
27use script_bindings::reflector::{DomObject, reflect_dom_object_with_proto};
28use servo_base::generic_channel::{LazyCallback, lazy_callback};
29use servo_constellation_traits::BlobImpl;
30use servo_url::{ImmutableOrigin, ServoUrl};
31
32use crate::dom::bindings::codegen::Bindings::BlobBinding::BlobMethods;
33use crate::dom::bindings::codegen::Bindings::WebSocketBinding::{BinaryType, WebSocketMethods};
34use crate::dom::bindings::codegen::Bindings::WindowBinding::WindowMethods;
35use crate::dom::bindings::codegen::UnionTypes::StringOrStringSequence;
36use crate::dom::bindings::error::{Error, ErrorResult, Fallible};
37use crate::dom::bindings::inheritance::Castable;
38use crate::dom::bindings::refcounted::Trusted;
39use crate::dom::bindings::reflector::DomGlobal;
40use crate::dom::bindings::root::DomRoot;
41use crate::dom::bindings::str::{DOMString, USVString, is_token};
42use crate::dom::blob::Blob;
43use crate::dom::closeevent::CloseEvent;
44use crate::dom::csp::{GlobalCspReporting, Violation};
45use crate::dom::event::{Event, EventBubbles, EventCancelable};
46use crate::dom::eventtarget::EventTarget;
47use crate::dom::globalscope::GlobalScope;
48use crate::dom::messageevent::MessageEvent;
49use crate::dom::window::Window;
50use crate::fetch::RequestWithGlobalScope;
51use crate::script_runtime::CanGc;
52use crate::task::TaskOnce;
53use crate::task_source::SendableTaskSource;
54
55#[derive(Clone, Copy, Debug, JSTraceable, MallocSizeOf, PartialEq)]
56enum WebSocketRequestState {
57 Connecting = 0,
58 Open = 1,
59 Closing = 2,
60 Closed = 3,
61}
62
/// WebSocket close status codes (names follow RFC 6455, section 7.4.1).
///
/// Only some of these constants are referenced in this file, hence the
/// blanket `dead_code` expectation on the module.
#[expect(dead_code)]
mod close_code {
    /// Normal closure; the only code below 3000 that `Close` accepts.
    pub(crate) const NORMAL: u16 = 1000;
    /// Endpoint is going away.
    pub(crate) const GOING_AWAY: u16 = 1001;
    /// Protocol error.
    pub(crate) const PROTOCOL_ERROR: u16 = 1002;
    /// Unsupported data type.
    pub(crate) const UNSUPPORTED_DATATYPE: u16 = 1003;
    /// No status code was present; used by `CloseTask` when none is given.
    pub(crate) const NO_STATUS: u16 = 1005;
    /// Abnormal closure; reported when the connection is failed.
    pub(crate) const ABNORMAL: u16 = 1006;
    /// Payload inconsistent with the message type.
    pub(crate) const INVALID_PAYLOAD: u16 = 1007;
    /// Policy violation.
    pub(crate) const POLICY_VIOLATION: u16 = 1008;
    /// Message too large.
    pub(crate) const TOO_LARGE: u16 = 1009;
    /// A required extension was not negotiated.
    pub(crate) const EXTENSION_MISSING: u16 = 1010;
    /// Internal error.
    pub(crate) const INTERNAL_ERROR: u16 = 1011;
    /// TLS handshake failure.
    pub(crate) const TLS_FAILED: u16 = 1015;
}
80
81fn close_the_websocket_connection(
82 address: Trusted<WebSocket>,
83 task_source: &SendableTaskSource,
84 code: Option<u16>,
85 reason: String,
86) {
87 task_source.queue(CloseTask {
88 address,
89 failed: false,
90 code,
91 reason: Some(reason),
92 });
93}
94
95fn fail_the_websocket_connection(address: Trusted<WebSocket>, task_source: &SendableTaskSource) {
96 task_source.queue(CloseTask {
97 address,
98 failed: true,
99 code: Some(close_code::ABNORMAL),
100 reason: None,
101 });
102}
103
104#[dom_struct]
105pub(crate) struct WebSocket {
106 eventtarget: EventTarget,
107 #[no_trace]
108 url: ServoUrl,
109 ready_state: Cell<WebSocketRequestState>,
110 buffered_amount: Cell<u64>,
111 clearing_buffer: Cell<bool>, #[no_trace]
113 callback: LazyCallback<WebSocketDomAction>,
114 binary_type: Cell<BinaryType>,
115 protocol: DomRefCell<String>, }
117
118impl WebSocket {
119 fn new_inherited(url: ServoUrl, callback: LazyCallback<WebSocketDomAction>) -> WebSocket {
120 WebSocket {
121 eventtarget: EventTarget::new_inherited(),
122 url,
123 ready_state: Cell::new(WebSocketRequestState::Connecting),
124 buffered_amount: Cell::new(0),
125 clearing_buffer: Cell::new(false),
126 callback,
127 binary_type: Cell::new(BinaryType::Blob),
128 protocol: DomRefCell::new("".to_owned()),
129 }
130 }
131
132 fn new(
133 global: &GlobalScope,
134 proto: Option<HandleObject>,
135 url: ServoUrl,
136 sender: LazyCallback<WebSocketDomAction>,
137 can_gc: CanGc,
138 ) -> DomRoot<WebSocket> {
139 let websocket = reflect_dom_object_with_proto(
140 Box::new(WebSocket::new_inherited(url, sender)),
141 global,
142 proto,
143 can_gc,
144 );
145 if let Some(window) = global.downcast::<Window>() {
146 window.Document().track_websocket(&websocket);
147 }
148 websocket
149 }
150
151 fn send_impl(&self, data_byte_len: u64) -> Fallible<bool> {
153 let return_after_buffer = match self.ready_state.get() {
154 WebSocketRequestState::Connecting => {
155 return Err(Error::InvalidState(None));
156 },
157 WebSocketRequestState::Open => false,
158 WebSocketRequestState::Closing | WebSocketRequestState::Closed => true,
159 };
160
161 let address = Trusted::new(self);
162
163 match data_byte_len.checked_add(self.buffered_amount.get()) {
164 None => panic!(),
165 Some(new_amount) => self.buffered_amount.set(new_amount),
166 };
167
168 if return_after_buffer {
169 return Ok(false);
170 }
171
172 if !self.clearing_buffer.get() && self.ready_state.get() == WebSocketRequestState::Open {
173 self.clearing_buffer.set(true);
174
175 self.global()
177 .task_manager()
178 .websocket_task_source()
179 .queue_unconditionally(BufferedAmountTask { address });
180 }
181
182 Ok(true)
183 }
184
185 pub(crate) fn origin(&self) -> ImmutableOrigin {
186 self.url.origin()
187 }
188
189 pub(crate) fn make_disappear(&self) -> bool {
192 let result = self.ready_state.get() != WebSocketRequestState::Closed;
193 let _ = self.Close(Some(1001), None);
194 result
195 }
196}
197
198impl WebSocketMethods<crate::DomTypeHolder> for WebSocket {
199 fn Constructor(
201 global: &GlobalScope,
202 proto: Option<HandleObject>,
203 can_gc: CanGc,
204 url: DOMString,
205 protocols: Option<StringOrStringSequence>,
206 ) -> Fallible<DomRoot<WebSocket>> {
207 let base_url = global.api_base_url();
209 let mut url_record =
212 ServoUrl::parse_with_base(Some(&base_url), &url.str()).or(Err(Error::Syntax(None)))?;
213
214 match url_record.scheme() {
218 "http" => {
219 url_record
220 .as_mut_url()
221 .set_scheme("ws")
222 .expect("Can't set scheme from http to ws");
223 },
224 "https" => {
225 url_record
226 .as_mut_url()
227 .set_scheme("wss")
228 .expect("Can't set scheme from https to wss");
229 },
230 "ws" | "wss" => {},
231 _ => return Err(Error::Syntax(None)),
232 }
233
234 if url_record.fragment().is_some() {
236 return Err(Error::Syntax(None));
237 }
238
239 let protocols = protocols.map_or(vec![], |p| match p {
241 StringOrStringSequence::String(string) => vec![string.into()],
242 StringOrStringSequence::StringSequence(seq) => {
243 seq.into_iter().map(String::from).collect()
244 },
245 });
246
247 for (i, protocol) in protocols.iter().enumerate() {
251 if protocols[i + 1..]
255 .iter()
256 .any(|p| p.eq_ignore_ascii_case(protocol))
257 {
258 return Err(Error::Syntax(None));
259 }
260
261 if !is_token(protocol.as_bytes()) {
263 return Err(Error::Syntax(None));
264 }
265 }
266
267 let (dom_action_sender, resource_action_receiver) = lazy_callback();
269 let (resource_event_sender, dom_event_receiver) =
270 ProfiledIpc::channel(global.time_profiler_chan().clone()).unwrap();
271
272 let ws = WebSocket::new(global, proto, url_record.clone(), dom_action_sender, can_gc);
274 let address = Trusted::new(&*ws);
275
276 let request = RequestBuilder::new(
282 global.webview_id(),
283 UrlWithBlobClaim::from_url_without_having_claimed_blob(url_record.clone()),
284 Referrer::NoReferrer,
285 )
286 .with_global_scope(global)
287 .mode(RequestMode::WebSocket {
288 protocols,
289 original_url: url_record,
290 })
291 .service_workers_mode(ServiceWorkersMode::None)
292 .credentials_mode(CredentialsMode::Include)
293 .cache_mode(CacheMode::NoCache)
294 .redirect_mode(RedirectMode::Error);
295
296 let channels = FetchChannels::WebSocket {
297 event_sender: resource_event_sender,
298 action_receiver: resource_action_receiver,
299 };
300 let _ = global
301 .core_resource_thread()
302 .send(CoreResourceMsg::Fetch(request, channels));
303
304 let task_source = global.task_manager().websocket_task_source().to_sendable();
305 ROUTER.add_typed_route(
306 dom_event_receiver.to_ipc_receiver(),
307 Box::new(move |message| match message.unwrap() {
308 WebSocketNetworkEvent::ReportCSPViolations(violations) => {
309 let task = ReportCSPViolationTask {
310 websocket: address.clone(),
311 violations,
312 };
313 task_source.queue(task);
314 },
315 WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use } => {
316 let open_thread = ConnectionEstablishedTask {
317 address: address.clone(),
318 protocol_in_use,
319 };
320 task_source.queue(open_thread);
321 },
322 WebSocketNetworkEvent::MessageReceived(message) => {
323 let message_thread = MessageReceivedTask {
324 address: address.clone(),
325 message,
326 };
327 task_source.queue(message_thread);
328 },
329 WebSocketNetworkEvent::Fail => {
330 fail_the_websocket_connection(address.clone(), &task_source);
331 },
332 WebSocketNetworkEvent::Close(code, reason) => {
333 close_the_websocket_connection(address.clone(), &task_source, code, reason);
334 },
335 }),
336 );
337
338 Ok(ws)
339 }
340
341 event_handler!(open, GetOnopen, SetOnopen);
343
344 event_handler!(close, GetOnclose, SetOnclose);
346
347 event_handler!(error, GetOnerror, SetOnerror);
349
350 event_handler!(message, GetOnmessage, SetOnmessage);
352
353 fn Url(&self) -> DOMString {
355 DOMString::from(self.url.as_str())
356 }
357
358 fn ReadyState(&self) -> u16 {
360 self.ready_state.get() as u16
361 }
362
363 fn BufferedAmount(&self) -> u64 {
365 self.buffered_amount.get()
366 }
367
368 fn BinaryType(&self) -> BinaryType {
370 self.binary_type.get()
371 }
372
373 fn SetBinaryType(&self, btype: BinaryType) {
375 self.binary_type.set(btype)
376 }
377
378 fn Protocol(&self) -> DOMString {
380 DOMString::from(self.protocol.borrow().clone())
381 }
382
383 fn Send(&self, data: USVString) -> ErrorResult {
385 let data_byte_len = data.0.len() as u64;
386 let send_data = self.send_impl(data_byte_len)?;
387
388 if send_data {
389 let _ = self
390 .callback
391 .send(WebSocketDomAction::SendMessage(MessageData::Text(data.0)));
392 }
393
394 Ok(())
395 }
396
397 fn Send_(&self, blob: &Blob) -> ErrorResult {
399 let data_byte_len = blob.Size();
404 let send_data = self.send_impl(data_byte_len)?;
405
406 if send_data {
407 let bytes = blob.get_bytes().unwrap_or_default();
408 let _ = self
409 .callback
410 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
411 }
412
413 Ok(())
414 }
415
416 fn Send__(&self, array: CustomAutoRooterGuard<ArrayBuffer>) -> ErrorResult {
418 let bytes = array.to_vec();
419 let data_byte_len = bytes.len();
420 let send_data = self.send_impl(data_byte_len as u64)?;
421
422 if send_data {
423 let _ = self
424 .callback
425 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
426 }
427 Ok(())
428 }
429
430 fn Send___(&self, array: CustomAutoRooterGuard<ArrayBufferView>) -> ErrorResult {
432 let bytes = array.to_vec();
433 let data_byte_len = bytes.len();
434 let send_data = self.send_impl(data_byte_len as u64)?;
435
436 if send_data {
437 let _ = self
438 .callback
439 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
440 }
441 Ok(())
442 }
443
444 fn Close(&self, code: Option<u16>, reason: Option<USVString>) -> ErrorResult {
446 if let Some(code) = code {
447 if code != close_code::NORMAL && !(3000..=4999).contains(&code) {
449 return Err(Error::InvalidAccess(None));
450 }
451 }
452 if let Some(ref reason) = reason &&
453 reason.0.len() > 123
454 {
455 return Err(Error::Syntax(Some("Reason too long".to_string())));
457 }
458
459 match self.ready_state.get() {
460 WebSocketRequestState::Closing | WebSocketRequestState::Closed => {}, WebSocketRequestState::Connecting => {
462 self.ready_state.set(WebSocketRequestState::Closing);
466
467 fail_the_websocket_connection(
468 Trusted::new(self),
469 &self
470 .global()
471 .task_manager()
472 .websocket_task_source()
473 .to_sendable(),
474 );
475 },
476 WebSocketRequestState::Open => {
477 self.ready_state.set(WebSocketRequestState::Closing);
478
479 let reason = reason.map(|reason| reason.0);
482 let _ = self.callback.send(WebSocketDomAction::Close(code, reason));
483 },
484 }
485 Ok(()) }
487}
488
489struct ReportCSPViolationTask {
490 websocket: Trusted<WebSocket>,
491 violations: Vec<Violation>,
492}
493
494impl TaskOnce for ReportCSPViolationTask {
495 fn run_once(self, _cx: &mut js::context::JSContext) {
496 let global = self.websocket.root().global();
497 global.report_csp_violations(self.violations, None, None);
498 }
499}
500
501struct ConnectionEstablishedTask {
504 address: Trusted<WebSocket>,
505 protocol_in_use: Option<String>,
506}
507
508impl TaskOnce for ConnectionEstablishedTask {
509 fn run_once(self, cx: &mut js::context::JSContext) {
511 let ws = self.address.root();
512
513 ws.ready_state.set(WebSocketRequestState::Open);
515
516 if let Some(protocol_name) = self.protocol_in_use {
521 *ws.protocol.borrow_mut() = protocol_name;
522 };
523
524 ws.upcast().fire_event(cx, atom!("open"));
526 }
527}
528
529struct BufferedAmountTask {
530 address: Trusted<WebSocket>,
531}
532
533impl TaskOnce for BufferedAmountTask {
534 fn run_once(self, _cx: &mut js::context::JSContext) {
540 let ws = self.address.root();
541
542 ws.buffered_amount.set(0);
543 ws.clearing_buffer.set(false);
544 }
545}
546
547struct CloseTask {
548 address: Trusted<WebSocket>,
549 failed: bool,
550 code: Option<u16>,
551 reason: Option<String>,
552}
553
554impl TaskOnce for CloseTask {
555 fn run_once(self, cx: &mut js::context::JSContext) {
556 let ws = self.address.root();
557
558 if ws.ready_state.get() == WebSocketRequestState::Closed {
559 return;
561 }
562
563 ws.ready_state.set(WebSocketRequestState::Closed);
568
569 if self.failed {
571 ws.upcast().fire_event(cx, atom!("error"));
572 }
573
574 let clean_close = !self.failed;
576 let code = self.code.unwrap_or(close_code::NO_STATUS);
577 let reason = DOMString::from(self.reason.unwrap_or("".to_owned()));
578 let close_event = CloseEvent::new(
579 &ws.global(),
580 atom!("close"),
581 EventBubbles::DoesNotBubble,
582 EventCancelable::NotCancelable,
583 clean_close,
584 code,
585 reason,
586 CanGc::from_cx(cx),
587 );
588 close_event
589 .upcast::<Event>()
590 .fire(ws.upcast(), CanGc::from_cx(cx));
591 }
592}
593
594struct MessageReceivedTask {
595 address: Trusted<WebSocket>,
596 message: MessageData,
597}
598
599impl TaskOnce for MessageReceivedTask {
600 #[expect(unsafe_code)]
601 fn run_once(self, cx: &mut js::context::JSContext) {
602 let ws = self.address.root();
603 debug!(
604 "MessageReceivedTask::handler({:p}): readyState={:?}",
605 &*ws,
606 ws.ready_state.get()
607 );
608
609 if ws.ready_state.get() != WebSocketRequestState::Open {
611 return;
612 }
613
614 let global = ws.global();
616 let mut realm = AutoRealm::new(
617 cx,
618 NonNull::new(ws.reflector().get_jsobject().get()).unwrap(),
619 );
620 let cx = &mut *realm;
621 rooted!(&in(cx) let mut message = UndefinedValue());
622 match self.message {
623 MessageData::Text(text) => {
624 text.safe_to_jsval(cx.into(), message.handle_mut(), CanGc::from_cx(cx))
625 },
626 MessageData::Binary(data) => match ws.binary_type.get() {
627 BinaryType::Blob => {
628 let blob = Blob::new(
629 &global,
630 BlobImpl::new_from_bytes(data, "".to_owned()),
631 CanGc::from_cx(cx),
632 );
633 blob.safe_to_jsval(cx.into(), message.handle_mut(), CanGc::from_cx(cx));
634 },
635 BinaryType::Arraybuffer => {
636 rooted!(&in(cx) let mut array_buffer = ptr::null_mut::<JSObject>());
637 unsafe {
639 assert!(
640 ArrayBuffer::create(
641 cx.raw_cx(),
642 CreateWith::Slice(&data),
643 array_buffer.handle_mut()
644 )
645 .is_ok()
646 )
647 };
648
649 (*array_buffer).safe_to_jsval(
650 cx.into(),
651 message.handle_mut(),
652 CanGc::from_cx(cx),
653 );
654 },
655 },
656 }
657 MessageEvent::dispatch_jsval(
658 ws.upcast(),
659 &global,
660 message.handle(),
661 Some(&ws.origin().ascii_serialization()),
662 None,
663 vec![],
664 CanGc::from_cx(cx),
665 );
666 }
667}