1use std::borrow::ToOwned;
6use std::cell::Cell;
7use std::ptr::{self, NonNull};
8
9use constellation_traits::BlobImpl;
10use dom_struct::dom_struct;
11use ipc_channel::ipc::{self, IpcReceiver, IpcSender};
12use ipc_channel::router::ROUTER;
13use js::jsapi::JSObject;
14use js::jsval::UndefinedValue;
15use js::realm::AutoRealm;
16use js::rust::{CustomAutoRooterGuard, HandleObject};
17use js::typedarray::{ArrayBuffer, ArrayBufferView, CreateWith};
18use net_traits::request::{
19 CacheMode, CredentialsMode, RedirectMode, Referrer, RequestBuilder, RequestMode,
20 ServiceWorkersMode,
21};
22use net_traits::{
23 CoreResourceMsg, FetchChannels, MessageData, WebSocketDomAction, WebSocketNetworkEvent,
24};
25use profile_traits::ipc as ProfiledIpc;
26use script_bindings::conversions::SafeToJSValConvertible;
27use servo_url::{ImmutableOrigin, ServoUrl};
28
29use crate::dom::bindings::cell::DomRefCell;
30use crate::dom::bindings::codegen::Bindings::BlobBinding::BlobMethods;
31use crate::dom::bindings::codegen::Bindings::WebSocketBinding::{BinaryType, WebSocketMethods};
32use crate::dom::bindings::codegen::Bindings::WindowBinding::WindowMethods;
33use crate::dom::bindings::codegen::UnionTypes::StringOrStringSequence;
34use crate::dom::bindings::error::{Error, ErrorResult, Fallible};
35use crate::dom::bindings::inheritance::Castable;
36use crate::dom::bindings::refcounted::Trusted;
37use crate::dom::bindings::reflector::{DomGlobal, DomObject, reflect_dom_object_with_proto};
38use crate::dom::bindings::root::DomRoot;
39use crate::dom::bindings::str::{DOMString, USVString, is_token};
40use crate::dom::blob::Blob;
41use crate::dom::closeevent::CloseEvent;
42use crate::dom::csp::{GlobalCspReporting, Violation};
43use crate::dom::event::{Event, EventBubbles, EventCancelable};
44use crate::dom::eventtarget::EventTarget;
45use crate::dom::globalscope::GlobalScope;
46use crate::dom::messageevent::MessageEvent;
47use crate::dom::window::Window;
48use crate::fetch::RequestWithGlobalScope;
49use crate::script_runtime::CanGc;
50use crate::task::TaskOnce;
51use crate::task_source::SendableTaskSource;
52
53#[derive(Clone, Copy, Debug, JSTraceable, MallocSizeOf, PartialEq)]
54enum WebSocketRequestState {
55 Connecting = 0,
56 Open = 1,
57 Closing = 2,
58 Closed = 3,
59}
60
// Several of these codes are listed for completeness only, hence the dead-code expectation.
#[expect(dead_code)]
/// WebSocket close status codes (RFC 6455, section 7.4.1).
mod close_code {
    /// Normal closure.
    pub(crate) const NORMAL: u16 = 1000;
    /// The endpoint is going away (e.g. page navigation).
    pub(crate) const GOING_AWAY: u16 = 1001;
    /// Protocol error.
    pub(crate) const PROTOCOL_ERROR: u16 = 1002;
    /// Received a data type the endpoint cannot accept.
    pub(crate) const UNSUPPORTED_DATATYPE: u16 = 1003;
    /// No status code was present in the close frame.
    pub(crate) const NO_STATUS: u16 = 1005;
    /// The connection closed without a close frame.
    pub(crate) const ABNORMAL: u16 = 1006;
    /// Message payload was inconsistent with its type.
    pub(crate) const INVALID_PAYLOAD: u16 = 1007;
    /// Message violated the endpoint's policy.
    pub(crate) const POLICY_VIOLATION: u16 = 1008;
    /// Message was too large to process.
    pub(crate) const TOO_LARGE: u16 = 1009;
    /// Client expected an extension the server did not negotiate.
    pub(crate) const EXTENSION_MISSING: u16 = 1010;
    /// Server hit an unexpected condition.
    pub(crate) const INTERNAL_ERROR: u16 = 1011;
    /// TLS handshake failure.
    pub(crate) const TLS_FAILED: u16 = 1015;
}
78
79fn close_the_websocket_connection(
80 address: Trusted<WebSocket>,
81 task_source: &SendableTaskSource,
82 code: Option<u16>,
83 reason: String,
84) {
85 task_source.queue(CloseTask {
86 address,
87 failed: false,
88 code,
89 reason: Some(reason),
90 });
91}
92
93fn fail_the_websocket_connection(address: Trusted<WebSocket>, task_source: &SendableTaskSource) {
94 task_source.queue(CloseTask {
95 address,
96 failed: true,
97 code: Some(close_code::ABNORMAL),
98 reason: None,
99 });
100}
101
102#[dom_struct]
103pub(crate) struct WebSocket {
104 eventtarget: EventTarget,
105 #[no_trace]
106 url: ServoUrl,
107 ready_state: Cell<WebSocketRequestState>,
108 buffered_amount: Cell<u64>,
109 clearing_buffer: Cell<bool>, #[ignore_malloc_size_of = "Defined in std"]
111 #[no_trace]
112 sender: IpcSender<WebSocketDomAction>,
113 binary_type: Cell<BinaryType>,
114 protocol: DomRefCell<String>, }
116
117impl WebSocket {
118 fn new_inherited(url: ServoUrl, sender: IpcSender<WebSocketDomAction>) -> WebSocket {
119 WebSocket {
120 eventtarget: EventTarget::new_inherited(),
121 url,
122 ready_state: Cell::new(WebSocketRequestState::Connecting),
123 buffered_amount: Cell::new(0),
124 clearing_buffer: Cell::new(false),
125 sender,
126 binary_type: Cell::new(BinaryType::Blob),
127 protocol: DomRefCell::new("".to_owned()),
128 }
129 }
130
131 fn new(
132 global: &GlobalScope,
133 proto: Option<HandleObject>,
134 url: ServoUrl,
135 sender: IpcSender<WebSocketDomAction>,
136 can_gc: CanGc,
137 ) -> DomRoot<WebSocket> {
138 let websocket = reflect_dom_object_with_proto(
139 Box::new(WebSocket::new_inherited(url, sender)),
140 global,
141 proto,
142 can_gc,
143 );
144 if let Some(window) = global.downcast::<Window>() {
145 window.Document().track_websocket(&websocket);
146 }
147 websocket
148 }
149
150 fn send_impl(&self, data_byte_len: u64) -> Fallible<bool> {
152 let return_after_buffer = match self.ready_state.get() {
153 WebSocketRequestState::Connecting => {
154 return Err(Error::InvalidState(None));
155 },
156 WebSocketRequestState::Open => false,
157 WebSocketRequestState::Closing | WebSocketRequestState::Closed => true,
158 };
159
160 let address = Trusted::new(self);
161
162 match data_byte_len.checked_add(self.buffered_amount.get()) {
163 None => panic!(),
164 Some(new_amount) => self.buffered_amount.set(new_amount),
165 };
166
167 if return_after_buffer {
168 return Ok(false);
169 }
170
171 if !self.clearing_buffer.get() && self.ready_state.get() == WebSocketRequestState::Open {
172 self.clearing_buffer.set(true);
173
174 self.global()
176 .task_manager()
177 .websocket_task_source()
178 .queue_unconditionally(BufferedAmountTask { address });
179 }
180
181 Ok(true)
182 }
183
184 pub(crate) fn origin(&self) -> ImmutableOrigin {
185 self.url.origin()
186 }
187
188 pub(crate) fn make_disappear(&self) -> bool {
191 let result = self.ready_state.get() != WebSocketRequestState::Closed;
192 let _ = self.Close(Some(1001), None);
193 result
194 }
195}
196
197impl WebSocketMethods<crate::DomTypeHolder> for WebSocket {
198 fn Constructor(
200 global: &GlobalScope,
201 proto: Option<HandleObject>,
202 can_gc: CanGc,
203 url: DOMString,
204 protocols: Option<StringOrStringSequence>,
205 ) -> Fallible<DomRoot<WebSocket>> {
206 let mut url_record = ServoUrl::parse(&url.str()).or(Err(Error::Syntax(None)))?;
210
211 match url_record.scheme() {
215 "http" => {
216 url_record
217 .as_mut_url()
218 .set_scheme("ws")
219 .expect("Can't set scheme from http to ws");
220 },
221 "https" => {
222 url_record
223 .as_mut_url()
224 .set_scheme("wss")
225 .expect("Can't set scheme from https to wss");
226 },
227 "ws" | "wss" => {},
228 _ => return Err(Error::Syntax(None)),
229 }
230
231 if url_record.fragment().is_some() {
233 return Err(Error::Syntax(None));
234 }
235
236 let protocols = protocols.map_or(vec![], |p| match p {
238 StringOrStringSequence::String(string) => vec![string.into()],
239 StringOrStringSequence::StringSequence(seq) => {
240 seq.into_iter().map(String::from).collect()
241 },
242 });
243
244 for (i, protocol) in protocols.iter().enumerate() {
248 if protocols[i + 1..]
252 .iter()
253 .any(|p| p.eq_ignore_ascii_case(protocol))
254 {
255 return Err(Error::Syntax(None));
256 }
257
258 if !is_token(protocol.as_bytes()) {
260 return Err(Error::Syntax(None));
261 }
262 }
263
264 let (dom_action_sender, resource_action_receiver): (
266 IpcSender<WebSocketDomAction>,
267 IpcReceiver<WebSocketDomAction>,
268 ) = ipc::channel().unwrap();
269 let (resource_event_sender, dom_event_receiver): (
270 IpcSender<WebSocketNetworkEvent>,
271 ProfiledIpc::IpcReceiver<WebSocketNetworkEvent>,
272 ) = ProfiledIpc::channel(global.time_profiler_chan().clone()).unwrap();
273
274 let ws = WebSocket::new(global, proto, url_record.clone(), dom_action_sender, can_gc);
276 let address = Trusted::new(&*ws);
277
278 let request = RequestBuilder::new(
284 global.webview_id(),
285 url_record.clone(),
286 Referrer::NoReferrer,
287 )
288 .with_global_scope(global)
289 .mode(RequestMode::WebSocket {
290 protocols,
291 original_url: url_record,
292 })
293 .service_workers_mode(ServiceWorkersMode::None)
294 .credentials_mode(CredentialsMode::Include)
295 .cache_mode(CacheMode::NoCache)
296 .redirect_mode(RedirectMode::Error);
297
298 let channels = FetchChannels::WebSocket {
299 event_sender: resource_event_sender,
300 action_receiver: resource_action_receiver,
301 };
302 let _ = global
303 .core_resource_thread()
304 .send(CoreResourceMsg::Fetch(request, channels));
305
306 let task_source = global.task_manager().websocket_task_source().to_sendable();
307 ROUTER.add_typed_route(
308 dom_event_receiver.to_ipc_receiver(),
309 Box::new(move |message| match message.unwrap() {
310 WebSocketNetworkEvent::ReportCSPViolations(violations) => {
311 let task = ReportCSPViolationTask {
312 websocket: address.clone(),
313 violations,
314 };
315 task_source.queue(task);
316 },
317 WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use } => {
318 let open_thread = ConnectionEstablishedTask {
319 address: address.clone(),
320 protocol_in_use,
321 };
322 task_source.queue(open_thread);
323 },
324 WebSocketNetworkEvent::MessageReceived(message) => {
325 let message_thread = MessageReceivedTask {
326 address: address.clone(),
327 message,
328 };
329 task_source.queue(message_thread);
330 },
331 WebSocketNetworkEvent::Fail => {
332 fail_the_websocket_connection(address.clone(), &task_source);
333 },
334 WebSocketNetworkEvent::Close(code, reason) => {
335 close_the_websocket_connection(address.clone(), &task_source, code, reason);
336 },
337 }),
338 );
339
340 Ok(ws)
341 }
342
343 event_handler!(open, GetOnopen, SetOnopen);
345
346 event_handler!(close, GetOnclose, SetOnclose);
348
349 event_handler!(error, GetOnerror, SetOnerror);
351
352 event_handler!(message, GetOnmessage, SetOnmessage);
354
355 fn Url(&self) -> DOMString {
357 DOMString::from(self.url.as_str())
358 }
359
360 fn ReadyState(&self) -> u16 {
362 self.ready_state.get() as u16
363 }
364
365 fn BufferedAmount(&self) -> u64 {
367 self.buffered_amount.get()
368 }
369
370 fn BinaryType(&self) -> BinaryType {
372 self.binary_type.get()
373 }
374
375 fn SetBinaryType(&self, btype: BinaryType) {
377 self.binary_type.set(btype)
378 }
379
380 fn Protocol(&self) -> DOMString {
382 DOMString::from(self.protocol.borrow().clone())
383 }
384
385 fn Send(&self, data: USVString) -> ErrorResult {
387 let data_byte_len = data.0.len() as u64;
388 let send_data = self.send_impl(data_byte_len)?;
389
390 if send_data {
391 let _ = self
392 .sender
393 .send(WebSocketDomAction::SendMessage(MessageData::Text(data.0)));
394 }
395
396 Ok(())
397 }
398
399 fn Send_(&self, blob: &Blob) -> ErrorResult {
401 let data_byte_len = blob.Size();
406 let send_data = self.send_impl(data_byte_len)?;
407
408 if send_data {
409 let bytes = blob.get_bytes().unwrap_or_default();
410 let _ = self
411 .sender
412 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
413 }
414
415 Ok(())
416 }
417
418 fn Send__(&self, array: CustomAutoRooterGuard<ArrayBuffer>) -> ErrorResult {
420 let bytes = array.to_vec();
421 let data_byte_len = bytes.len();
422 let send_data = self.send_impl(data_byte_len as u64)?;
423
424 if send_data {
425 let _ = self
426 .sender
427 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
428 }
429 Ok(())
430 }
431
432 fn Send___(&self, array: CustomAutoRooterGuard<ArrayBufferView>) -> ErrorResult {
434 let bytes = array.to_vec();
435 let data_byte_len = bytes.len();
436 let send_data = self.send_impl(data_byte_len as u64)?;
437
438 if send_data {
439 let _ = self
440 .sender
441 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
442 }
443 Ok(())
444 }
445
446 fn Close(&self, code: Option<u16>, reason: Option<USVString>) -> ErrorResult {
448 if let Some(code) = code {
449 if code != close_code::NORMAL && !(3000..=4999).contains(&code) {
451 return Err(Error::InvalidAccess(None));
452 }
453 }
454 if let Some(ref reason) = reason {
455 if reason.0.len() > 123 {
456 return Err(Error::Syntax(Some("Reason too long".to_string())));
458 }
459 }
460
461 match self.ready_state.get() {
462 WebSocketRequestState::Closing | WebSocketRequestState::Closed => {}, WebSocketRequestState::Connecting => {
464 self.ready_state.set(WebSocketRequestState::Closing);
468
469 fail_the_websocket_connection(
470 Trusted::new(self),
471 &self
472 .global()
473 .task_manager()
474 .websocket_task_source()
475 .to_sendable(),
476 );
477 },
478 WebSocketRequestState::Open => {
479 self.ready_state.set(WebSocketRequestState::Closing);
480
481 let reason = reason.map(|reason| reason.0);
484 let _ = self.sender.send(WebSocketDomAction::Close(code, reason));
485 },
486 }
487 Ok(()) }
489}
490
491struct ReportCSPViolationTask {
492 websocket: Trusted<WebSocket>,
493 violations: Vec<Violation>,
494}
495
496impl TaskOnce for ReportCSPViolationTask {
497 fn run_once(self, _cx: &mut js::context::JSContext) {
498 let global = self.websocket.root().global();
499 global.report_csp_violations(self.violations, None, None);
500 }
501}
502
503struct ConnectionEstablishedTask {
506 address: Trusted<WebSocket>,
507 protocol_in_use: Option<String>,
508}
509
510impl TaskOnce for ConnectionEstablishedTask {
511 fn run_once(self, cx: &mut js::context::JSContext) {
513 let ws = self.address.root();
514
515 ws.ready_state.set(WebSocketRequestState::Open);
517
518 if let Some(protocol_name) = self.protocol_in_use {
523 *ws.protocol.borrow_mut() = protocol_name;
524 };
525
526 ws.upcast().fire_event(atom!("open"), CanGc::from_cx(cx));
528 }
529}
530
531struct BufferedAmountTask {
532 address: Trusted<WebSocket>,
533}
534
535impl TaskOnce for BufferedAmountTask {
536 fn run_once(self, _cx: &mut js::context::JSContext) {
542 let ws = self.address.root();
543
544 ws.buffered_amount.set(0);
545 ws.clearing_buffer.set(false);
546 }
547}
548
549struct CloseTask {
550 address: Trusted<WebSocket>,
551 failed: bool,
552 code: Option<u16>,
553 reason: Option<String>,
554}
555
556impl TaskOnce for CloseTask {
557 fn run_once(self, cx: &mut js::context::JSContext) {
558 let ws = self.address.root();
559
560 if ws.ready_state.get() == WebSocketRequestState::Closed {
561 return;
563 }
564
565 ws.ready_state.set(WebSocketRequestState::Closed);
570
571 if self.failed {
573 ws.upcast().fire_event(atom!("error"), CanGc::from_cx(cx));
574 }
575
576 let clean_close = !self.failed;
578 let code = self.code.unwrap_or(close_code::NO_STATUS);
579 let reason = DOMString::from(self.reason.unwrap_or("".to_owned()));
580 let close_event = CloseEvent::new(
581 &ws.global(),
582 atom!("close"),
583 EventBubbles::DoesNotBubble,
584 EventCancelable::NotCancelable,
585 clean_close,
586 code,
587 reason,
588 CanGc::from_cx(cx),
589 );
590 close_event
591 .upcast::<Event>()
592 .fire(ws.upcast(), CanGc::from_cx(cx));
593 }
594}
595
596struct MessageReceivedTask {
597 address: Trusted<WebSocket>,
598 message: MessageData,
599}
600
601impl TaskOnce for MessageReceivedTask {
602 #[expect(unsafe_code)]
603 fn run_once(self, cx: &mut js::context::JSContext) {
604 let ws = self.address.root();
605 debug!(
606 "MessageReceivedTask::handler({:p}): readyState={:?}",
607 &*ws,
608 ws.ready_state.get()
609 );
610
611 if ws.ready_state.get() != WebSocketRequestState::Open {
613 return;
614 }
615
616 let global = ws.global();
618 let mut realm = AutoRealm::new(
619 cx,
620 NonNull::new(ws.reflector().get_jsobject().get()).unwrap(),
621 );
622 let cx = &mut *realm;
623 rooted!(&in(cx) let mut message = UndefinedValue());
624 match self.message {
625 MessageData::Text(text) => {
626 text.safe_to_jsval(cx.into(), message.handle_mut(), CanGc::from_cx(cx))
627 },
628 MessageData::Binary(data) => match ws.binary_type.get() {
629 BinaryType::Blob => {
630 let blob = Blob::new(
631 &global,
632 BlobImpl::new_from_bytes(data, "".to_owned()),
633 CanGc::from_cx(cx),
634 );
635 blob.safe_to_jsval(cx.into(), message.handle_mut(), CanGc::from_cx(cx));
636 },
637 BinaryType::Arraybuffer => {
638 rooted!(&in(cx) let mut array_buffer = ptr::null_mut::<JSObject>());
639 unsafe {
641 assert!(
642 ArrayBuffer::create(
643 cx.raw_cx(),
644 CreateWith::Slice(&data),
645 array_buffer.handle_mut()
646 )
647 .is_ok()
648 )
649 };
650
651 (*array_buffer).safe_to_jsval(
652 cx.into(),
653 message.handle_mut(),
654 CanGc::from_cx(cx),
655 );
656 },
657 },
658 }
659 MessageEvent::dispatch_jsval(
660 ws.upcast(),
661 &global,
662 message.handle(),
663 Some(&ws.origin().ascii_serialization()),
664 None,
665 vec![],
666 CanGc::from_cx(cx),
667 );
668 }
669}