1use std::borrow::ToOwned;
6use std::cell::Cell;
7use std::ptr::{self, NonNull};
8
9use dom_struct::dom_struct;
10use ipc_channel::router::ROUTER;
11use js::jsapi::JSObject;
12use js::jsval::UndefinedValue;
13use js::realm::AutoRealm;
14use js::rust::{CustomAutoRooterGuard, HandleObject};
15use js::typedarray::{ArrayBuffer, ArrayBufferView, CreateWith};
16use net_traits::blob_url_store::UrlWithBlobClaim;
17use net_traits::request::{
18 CacheMode, CredentialsMode, RedirectMode, Referrer, RequestBuilder, RequestMode,
19 ServiceWorkersMode,
20};
21use net_traits::{
22 CoreResourceMsg, FetchChannels, MessageData, WebSocketDomAction, WebSocketNetworkEvent,
23};
24use profile_traits::ipc as ProfiledIpc;
25use script_bindings::conversions::SafeToJSValConvertible;
26use servo_base::generic_channel::{LazyCallback, lazy_callback};
27use servo_constellation_traits::BlobImpl;
28use servo_url::{ImmutableOrigin, ServoUrl};
29
30use crate::dom::bindings::cell::DomRefCell;
31use crate::dom::bindings::codegen::Bindings::BlobBinding::BlobMethods;
32use crate::dom::bindings::codegen::Bindings::WebSocketBinding::{BinaryType, WebSocketMethods};
33use crate::dom::bindings::codegen::Bindings::WindowBinding::WindowMethods;
34use crate::dom::bindings::codegen::UnionTypes::StringOrStringSequence;
35use crate::dom::bindings::error::{Error, ErrorResult, Fallible};
36use crate::dom::bindings::inheritance::Castable;
37use crate::dom::bindings::refcounted::Trusted;
38use crate::dom::bindings::reflector::{DomGlobal, DomObject, reflect_dom_object_with_proto};
39use crate::dom::bindings::root::DomRoot;
40use crate::dom::bindings::str::{DOMString, USVString, is_token};
41use crate::dom::blob::Blob;
42use crate::dom::closeevent::CloseEvent;
43use crate::dom::csp::{GlobalCspReporting, Violation};
44use crate::dom::event::{Event, EventBubbles, EventCancelable};
45use crate::dom::eventtarget::EventTarget;
46use crate::dom::globalscope::GlobalScope;
47use crate::dom::messageevent::MessageEvent;
48use crate::dom::window::Window;
49use crate::fetch::RequestWithGlobalScope;
50use crate::script_runtime::CanGc;
51use crate::task::TaskOnce;
52use crate::task_source::SendableTaskSource;
53
/// The connection state exposed to script through `WebSocket.readyState`.
///
/// The explicit discriminants are returned verbatim by `ReadyState()` (via an
/// `as u16` cast), so their values and order must not change.
#[derive(Clone, Copy, Debug, JSTraceable, MallocSizeOf, PartialEq)]
enum WebSocketRequestState {
    /// Initial state; `send()` throws `InvalidStateError` while connecting.
    Connecting = 0,
    /// The connection is established and `send()` forwards data to the network.
    Open = 1,
    /// Closing has started; `send()` only counts bytes in `bufferedAmount`.
    Closing = 2,
    /// The connection is closed (cleanly or after a failure).
    Closed = 3,
}
61
/// WebSocket close status codes, as defined in RFC 6455 section 7.4.1.
///
/// Only some of these are referenced from this file; the `dead_code`
/// expectation keeps the full table available without lint noise.
#[expect(dead_code)]
mod close_code {
    /// Normal closure.
    pub(crate) const NORMAL: u16 = 1000;
    /// The endpoint is going away (e.g. the page is being unloaded).
    pub(crate) const GOING_AWAY: u16 = 1001;
    /// Protocol error.
    pub(crate) const PROTOCOL_ERROR: u16 = 1002;
    /// Received a data type the endpoint cannot accept.
    pub(crate) const UNSUPPORTED_DATATYPE: u16 = 1003;
    /// No status code was present; used locally, never sent on the wire.
    pub(crate) const NO_STATUS: u16 = 1005;
    /// Connection closed abnormally; used locally, never sent on the wire.
    pub(crate) const ABNORMAL: u16 = 1006;
    /// Message payload was inconsistent with its type (e.g. invalid UTF-8).
    pub(crate) const INVALID_PAYLOAD: u16 = 1007;
    /// Message violated the endpoint's policy.
    pub(crate) const POLICY_VIOLATION: u16 = 1008;
    /// Message was too big to process.
    pub(crate) const TOO_LARGE: u16 = 1009;
    /// Client expected an extension the server did not negotiate.
    pub(crate) const EXTENSION_MISSING: u16 = 1010;
    /// Server hit an unexpected condition.
    pub(crate) const INTERNAL_ERROR: u16 = 1011;
    /// TLS handshake failure; used locally, never sent on the wire.
    pub(crate) const TLS_FAILED: u16 = 1015;
}
79
80fn close_the_websocket_connection(
81 address: Trusted<WebSocket>,
82 task_source: &SendableTaskSource,
83 code: Option<u16>,
84 reason: String,
85) {
86 task_source.queue(CloseTask {
87 address,
88 failed: false,
89 code,
90 reason: Some(reason),
91 });
92}
93
94fn fail_the_websocket_connection(address: Trusted<WebSocket>, task_source: &SendableTaskSource) {
95 task_source.queue(CloseTask {
96 address,
97 failed: true,
98 code: Some(close_code::ABNORMAL),
99 reason: None,
100 });
101}
102
/// DOM object backing the script-visible `WebSocket` interface.
#[dom_struct]
pub(crate) struct WebSocket {
    eventtarget: EventTarget,
    /// The socket's URL, as resolved by the constructor.
    #[no_trace]
    url: ServoUrl,
    /// Current value of `readyState`.
    ready_state: Cell<WebSocketRequestState>,
    /// Bytes passed to `send()`; reset to 0 when a `BufferedAmountTask` runs.
    buffered_amount: Cell<u64>,
    /// True while a `BufferedAmountTask` is queued but has not yet run.
    clearing_buffer: Cell<bool>,
    /// Sender for actions (outgoing messages, close requests) aimed at the
    /// network side of the connection.
    #[no_trace]
    callback: LazyCallback<WebSocketDomAction>,
    /// Current `binaryType`; selects Blob vs ArrayBuffer for received binary data.
    binary_type: Cell<BinaryType>,
    /// Subprotocol reported when the connection was established; empty otherwise.
    protocol: DomRefCell<String>,
}
116
117impl WebSocket {
118 fn new_inherited(url: ServoUrl, callback: LazyCallback<WebSocketDomAction>) -> WebSocket {
119 WebSocket {
120 eventtarget: EventTarget::new_inherited(),
121 url,
122 ready_state: Cell::new(WebSocketRequestState::Connecting),
123 buffered_amount: Cell::new(0),
124 clearing_buffer: Cell::new(false),
125 callback,
126 binary_type: Cell::new(BinaryType::Blob),
127 protocol: DomRefCell::new("".to_owned()),
128 }
129 }
130
131 fn new(
132 global: &GlobalScope,
133 proto: Option<HandleObject>,
134 url: ServoUrl,
135 sender: LazyCallback<WebSocketDomAction>,
136 can_gc: CanGc,
137 ) -> DomRoot<WebSocket> {
138 let websocket = reflect_dom_object_with_proto(
139 Box::new(WebSocket::new_inherited(url, sender)),
140 global,
141 proto,
142 can_gc,
143 );
144 if let Some(window) = global.downcast::<Window>() {
145 window.Document().track_websocket(&websocket);
146 }
147 websocket
148 }
149
150 fn send_impl(&self, data_byte_len: u64) -> Fallible<bool> {
152 let return_after_buffer = match self.ready_state.get() {
153 WebSocketRequestState::Connecting => {
154 return Err(Error::InvalidState(None));
155 },
156 WebSocketRequestState::Open => false,
157 WebSocketRequestState::Closing | WebSocketRequestState::Closed => true,
158 };
159
160 let address = Trusted::new(self);
161
162 match data_byte_len.checked_add(self.buffered_amount.get()) {
163 None => panic!(),
164 Some(new_amount) => self.buffered_amount.set(new_amount),
165 };
166
167 if return_after_buffer {
168 return Ok(false);
169 }
170
171 if !self.clearing_buffer.get() && self.ready_state.get() == WebSocketRequestState::Open {
172 self.clearing_buffer.set(true);
173
174 self.global()
176 .task_manager()
177 .websocket_task_source()
178 .queue_unconditionally(BufferedAmountTask { address });
179 }
180
181 Ok(true)
182 }
183
184 pub(crate) fn origin(&self) -> ImmutableOrigin {
185 self.url.origin()
186 }
187
188 pub(crate) fn make_disappear(&self) -> bool {
191 let result = self.ready_state.get() != WebSocketRequestState::Closed;
192 let _ = self.Close(Some(1001), None);
193 result
194 }
195}
196
197impl WebSocketMethods<crate::DomTypeHolder> for WebSocket {
198 fn Constructor(
200 global: &GlobalScope,
201 proto: Option<HandleObject>,
202 can_gc: CanGc,
203 url: DOMString,
204 protocols: Option<StringOrStringSequence>,
205 ) -> Fallible<DomRoot<WebSocket>> {
206 let base_url = global.api_base_url();
208 let mut url_record =
211 ServoUrl::parse_with_base(Some(&base_url), &url.str()).or(Err(Error::Syntax(None)))?;
212
213 match url_record.scheme() {
217 "http" => {
218 url_record
219 .as_mut_url()
220 .set_scheme("ws")
221 .expect("Can't set scheme from http to ws");
222 },
223 "https" => {
224 url_record
225 .as_mut_url()
226 .set_scheme("wss")
227 .expect("Can't set scheme from https to wss");
228 },
229 "ws" | "wss" => {},
230 _ => return Err(Error::Syntax(None)),
231 }
232
233 if url_record.fragment().is_some() {
235 return Err(Error::Syntax(None));
236 }
237
238 let protocols = protocols.map_or(vec![], |p| match p {
240 StringOrStringSequence::String(string) => vec![string.into()],
241 StringOrStringSequence::StringSequence(seq) => {
242 seq.into_iter().map(String::from).collect()
243 },
244 });
245
246 for (i, protocol) in protocols.iter().enumerate() {
250 if protocols[i + 1..]
254 .iter()
255 .any(|p| p.eq_ignore_ascii_case(protocol))
256 {
257 return Err(Error::Syntax(None));
258 }
259
260 if !is_token(protocol.as_bytes()) {
262 return Err(Error::Syntax(None));
263 }
264 }
265
266 let (dom_action_sender, resource_action_receiver) = lazy_callback();
268 let (resource_event_sender, dom_event_receiver) =
269 ProfiledIpc::channel(global.time_profiler_chan().clone()).unwrap();
270
271 let ws = WebSocket::new(global, proto, url_record.clone(), dom_action_sender, can_gc);
273 let address = Trusted::new(&*ws);
274
275 let request = RequestBuilder::new(
281 global.webview_id(),
282 UrlWithBlobClaim::from_url_without_having_claimed_blob(url_record.clone()),
283 Referrer::NoReferrer,
284 )
285 .with_global_scope(global)
286 .mode(RequestMode::WebSocket {
287 protocols,
288 original_url: url_record,
289 })
290 .service_workers_mode(ServiceWorkersMode::None)
291 .credentials_mode(CredentialsMode::Include)
292 .cache_mode(CacheMode::NoCache)
293 .redirect_mode(RedirectMode::Error);
294
295 let channels = FetchChannels::WebSocket {
296 event_sender: resource_event_sender,
297 action_receiver: resource_action_receiver,
298 };
299 let _ = global
300 .core_resource_thread()
301 .send(CoreResourceMsg::Fetch(request, channels));
302
303 let task_source = global.task_manager().websocket_task_source().to_sendable();
304 ROUTER.add_typed_route(
305 dom_event_receiver.to_ipc_receiver(),
306 Box::new(move |message| match message.unwrap() {
307 WebSocketNetworkEvent::ReportCSPViolations(violations) => {
308 let task = ReportCSPViolationTask {
309 websocket: address.clone(),
310 violations,
311 };
312 task_source.queue(task);
313 },
314 WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use } => {
315 let open_thread = ConnectionEstablishedTask {
316 address: address.clone(),
317 protocol_in_use,
318 };
319 task_source.queue(open_thread);
320 },
321 WebSocketNetworkEvent::MessageReceived(message) => {
322 let message_thread = MessageReceivedTask {
323 address: address.clone(),
324 message,
325 };
326 task_source.queue(message_thread);
327 },
328 WebSocketNetworkEvent::Fail => {
329 fail_the_websocket_connection(address.clone(), &task_source);
330 },
331 WebSocketNetworkEvent::Close(code, reason) => {
332 close_the_websocket_connection(address.clone(), &task_source, code, reason);
333 },
334 }),
335 );
336
337 Ok(ws)
338 }
339
340 event_handler!(open, GetOnopen, SetOnopen);
342
343 event_handler!(close, GetOnclose, SetOnclose);
345
346 event_handler!(error, GetOnerror, SetOnerror);
348
349 event_handler!(message, GetOnmessage, SetOnmessage);
351
352 fn Url(&self) -> DOMString {
354 DOMString::from(self.url.as_str())
355 }
356
357 fn ReadyState(&self) -> u16 {
359 self.ready_state.get() as u16
360 }
361
362 fn BufferedAmount(&self) -> u64 {
364 self.buffered_amount.get()
365 }
366
367 fn BinaryType(&self) -> BinaryType {
369 self.binary_type.get()
370 }
371
372 fn SetBinaryType(&self, btype: BinaryType) {
374 self.binary_type.set(btype)
375 }
376
377 fn Protocol(&self) -> DOMString {
379 DOMString::from(self.protocol.borrow().clone())
380 }
381
382 fn Send(&self, data: USVString) -> ErrorResult {
384 let data_byte_len = data.0.len() as u64;
385 let send_data = self.send_impl(data_byte_len)?;
386
387 if send_data {
388 let _ = self
389 .callback
390 .send(WebSocketDomAction::SendMessage(MessageData::Text(data.0)));
391 }
392
393 Ok(())
394 }
395
396 fn Send_(&self, blob: &Blob) -> ErrorResult {
398 let data_byte_len = blob.Size();
403 let send_data = self.send_impl(data_byte_len)?;
404
405 if send_data {
406 let bytes = blob.get_bytes().unwrap_or_default();
407 let _ = self
408 .callback
409 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
410 }
411
412 Ok(())
413 }
414
415 fn Send__(&self, array: CustomAutoRooterGuard<ArrayBuffer>) -> ErrorResult {
417 let bytes = array.to_vec();
418 let data_byte_len = bytes.len();
419 let send_data = self.send_impl(data_byte_len as u64)?;
420
421 if send_data {
422 let _ = self
423 .callback
424 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
425 }
426 Ok(())
427 }
428
429 fn Send___(&self, array: CustomAutoRooterGuard<ArrayBufferView>) -> ErrorResult {
431 let bytes = array.to_vec();
432 let data_byte_len = bytes.len();
433 let send_data = self.send_impl(data_byte_len as u64)?;
434
435 if send_data {
436 let _ = self
437 .callback
438 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
439 }
440 Ok(())
441 }
442
443 fn Close(&self, code: Option<u16>, reason: Option<USVString>) -> ErrorResult {
445 if let Some(code) = code {
446 if code != close_code::NORMAL && !(3000..=4999).contains(&code) {
448 return Err(Error::InvalidAccess(None));
449 }
450 }
451 if let Some(ref reason) = reason {
452 if reason.0.len() > 123 {
453 return Err(Error::Syntax(Some("Reason too long".to_string())));
455 }
456 }
457
458 match self.ready_state.get() {
459 WebSocketRequestState::Closing | WebSocketRequestState::Closed => {}, WebSocketRequestState::Connecting => {
461 self.ready_state.set(WebSocketRequestState::Closing);
465
466 fail_the_websocket_connection(
467 Trusted::new(self),
468 &self
469 .global()
470 .task_manager()
471 .websocket_task_source()
472 .to_sendable(),
473 );
474 },
475 WebSocketRequestState::Open => {
476 self.ready_state.set(WebSocketRequestState::Closing);
477
478 let reason = reason.map(|reason| reason.0);
481 let _ = self.callback.send(WebSocketDomAction::Close(code, reason));
482 },
483 }
484 Ok(()) }
486}
487
/// Task carrying CSP violations reported by the network side of a socket.
struct ReportCSPViolationTask {
    /// The socket whose global should report the violations.
    websocket: Trusted<WebSocket>,
    /// The violations to report.
    violations: Vec<Violation>,
}
492
493impl TaskOnce for ReportCSPViolationTask {
494 fn run_once(self, _cx: &mut js::context::JSContext) {
495 let global = self.websocket.root().global();
496 global.report_csp_violations(self.violations, None, None);
497 }
498}
499
/// Task queued when the network side reports that the connection is established.
struct ConnectionEstablishedTask {
    /// The socket that connected.
    address: Trusted<WebSocket>,
    /// Subprotocol reported with the established connection, if any.
    protocol_in_use: Option<String>,
}
506
507impl TaskOnce for ConnectionEstablishedTask {
508 fn run_once(self, cx: &mut js::context::JSContext) {
510 let ws = self.address.root();
511
512 ws.ready_state.set(WebSocketRequestState::Open);
514
515 if let Some(protocol_name) = self.protocol_in_use {
520 *ws.protocol.borrow_mut() = protocol_name;
521 };
522
523 ws.upcast().fire_event(atom!("open"), CanGc::from_cx(cx));
525 }
526}
527
/// Task that resets a socket's `bufferedAmount` once queued data has been handed off.
struct BufferedAmountTask {
    /// The socket whose counter is reset.
    address: Trusted<WebSocket>,
}
531
532impl TaskOnce for BufferedAmountTask {
533 fn run_once(self, _cx: &mut js::context::JSContext) {
539 let ws = self.address.root();
540
541 ws.buffered_amount.set(0);
542 ws.clearing_buffer.set(false);
543 }
544}
545
/// Task that moves a socket to CLOSED and fires the closing events.
struct CloseTask {
    /// The socket being closed.
    address: Trusted<WebSocket>,
    /// True when the connection failed; this fires `error` and marks the close as unclean.
    failed: bool,
    /// Close code to report; `None` is reported as 1005 (`NO_STATUS`).
    code: Option<u16>,
    /// Close reason to report; `None` is reported as the empty string.
    reason: Option<String>,
}
552
553impl TaskOnce for CloseTask {
554 fn run_once(self, cx: &mut js::context::JSContext) {
555 let ws = self.address.root();
556
557 if ws.ready_state.get() == WebSocketRequestState::Closed {
558 return;
560 }
561
562 ws.ready_state.set(WebSocketRequestState::Closed);
567
568 if self.failed {
570 ws.upcast().fire_event(atom!("error"), CanGc::from_cx(cx));
571 }
572
573 let clean_close = !self.failed;
575 let code = self.code.unwrap_or(close_code::NO_STATUS);
576 let reason = DOMString::from(self.reason.unwrap_or("".to_owned()));
577 let close_event = CloseEvent::new(
578 &ws.global(),
579 atom!("close"),
580 EventBubbles::DoesNotBubble,
581 EventCancelable::NotCancelable,
582 clean_close,
583 code,
584 reason,
585 CanGc::from_cx(cx),
586 );
587 close_event
588 .upcast::<Event>()
589 .fire(ws.upcast(), CanGc::from_cx(cx));
590 }
591}
592
/// Task carrying one message received from the network side of a socket.
struct MessageReceivedTask {
    /// The socket that received the message.
    address: Trusted<WebSocket>,
    /// The message payload, either text or binary.
    message: MessageData,
}
597
impl TaskOnce for MessageReceivedTask {
    /// Delivers one received message to script as a `message` event.
    ///
    /// The message is dropped unless the socket is OPEN. Text becomes a JS
    /// string. Binary data becomes a `Blob` or an `ArrayBuffer`, depending on
    /// the socket's `binaryType` at the time this task runs. The event's origin
    /// is the serialized origin of the socket's URL.
    #[expect(unsafe_code)]
    fn run_once(self, cx: &mut js::context::JSContext) {
        let ws = self.address.root();
        debug!(
            "MessageReceivedTask::handler({:p}): readyState={:?}",
            &*ws,
            ws.ready_state.get()
        );

        // E.g. after close() has moved the socket to CLOSING, incoming
        // messages are no longer delivered.
        if ws.ready_state.get() != WebSocketRequestState::Open {
            return;
        }

        let global = ws.global();
        // Enter the realm of the socket's reflector; the JS values below are
        // created through this realm-entered context.
        let mut realm = AutoRealm::new(
            cx,
            NonNull::new(ws.reflector().get_jsobject().get()).unwrap(),
        );
        let cx = &mut *realm;
        rooted!(&in(cx) let mut message = UndefinedValue());
        match self.message {
            MessageData::Text(text) => {
                text.safe_to_jsval(cx.into(), message.handle_mut(), CanGc::from_cx(cx))
            },
            MessageData::Binary(data) => match ws.binary_type.get() {
                BinaryType::Blob => {
                    // Blob with an empty MIME type wrapping the received bytes.
                    let blob = Blob::new(
                        &global,
                        BlobImpl::new_from_bytes(data, "".to_owned()),
                        CanGc::from_cx(cx),
                    );
                    blob.safe_to_jsval(cx.into(), message.handle_mut(), CanGc::from_cx(cx));
                },
                BinaryType::Arraybuffer => {
                    rooted!(&in(cx) let mut array_buffer = ptr::null_mut::<JSObject>());
                    // SAFETY: `cx.raw_cx()` comes from the realm entered above,
                    // and `array_buffer` is a rooted handle that receives the new
                    // object. The assert aborts if creation fails instead of
                    // continuing with a null object.
                    unsafe {
                        assert!(
                            ArrayBuffer::create(
                                cx.raw_cx(),
                                CreateWith::Slice(&data),
                                array_buffer.handle_mut()
                            )
                            .is_ok()
                        )
                    };

                    (*array_buffer).safe_to_jsval(
                        cx.into(),
                        message.handle_mut(),
                        CanGc::from_cx(cx),
                    );
                },
            },
        }
        MessageEvent::dispatch_jsval(
            ws.upcast(),
            &global,
            message.handle(),
            Some(&ws.origin().ascii_serialization()),
            None,
            vec![],
            CanGc::from_cx(cx),
        );
    }
}