1use std::borrow::ToOwned;
6use std::cell::Cell;
7use std::ptr;
8
9use constellation_traits::BlobImpl;
10use dom_struct::dom_struct;
11use ipc_channel::ipc::{self, IpcReceiver, IpcSender};
12use ipc_channel::router::ROUTER;
13use js::jsapi::{JSAutoRealm, JSObject};
14use js::jsval::UndefinedValue;
15use js::rust::{CustomAutoRooterGuard, HandleObject};
16use js::typedarray::{ArrayBuffer, ArrayBufferView, CreateWith};
17use net_traits::request::{
18 CacheMode, CredentialsMode, RedirectMode, Referrer, RequestBuilder, RequestMode,
19 ServiceWorkersMode,
20};
21use net_traits::{
22 CoreResourceMsg, FetchChannels, MessageData, WebSocketDomAction, WebSocketNetworkEvent,
23};
24use profile_traits::ipc as ProfiledIpc;
25use script_bindings::conversions::SafeToJSValConvertible;
26use servo_url::{ImmutableOrigin, ServoUrl};
27
28use crate::dom::bindings::cell::DomRefCell;
29use crate::dom::bindings::codegen::Bindings::BlobBinding::BlobMethods;
30use crate::dom::bindings::codegen::Bindings::WebSocketBinding::{BinaryType, WebSocketMethods};
31use crate::dom::bindings::codegen::UnionTypes::StringOrStringSequence;
32use crate::dom::bindings::error::{Error, ErrorResult, Fallible};
33use crate::dom::bindings::inheritance::Castable;
34use crate::dom::bindings::refcounted::Trusted;
35use crate::dom::bindings::reflector::{DomGlobal, DomObject, reflect_dom_object_with_proto};
36use crate::dom::bindings::root::DomRoot;
37use crate::dom::bindings::str::{DOMString, USVString, is_token};
38use crate::dom::blob::Blob;
39use crate::dom::closeevent::CloseEvent;
40use crate::dom::csp::{GlobalCspReporting, Violation};
41use crate::dom::event::{Event, EventBubbles, EventCancelable};
42use crate::dom::eventtarget::EventTarget;
43use crate::dom::globalscope::GlobalScope;
44use crate::dom::messageevent::MessageEvent;
45use crate::script_runtime::CanGc;
46use crate::task::TaskOnce;
47use crate::task_source::SendableTaskSource;
48
/// The possible `readyState` values of a [`WebSocket`].
///
/// The explicit discriminants are significant: `ReadyState()` returns
/// `self.ready_state.get() as u16`, so these numbers are exactly what script observes.
#[derive(Clone, Copy, Debug, JSTraceable, MallocSizeOf, PartialEq)]
enum WebSocketRequestState {
    /// Initial state, set in `new_inherited`; left when the connection is
    /// established or `close()` is called.
    Connecting = 0,
    /// Set by `ConnectionEstablishedTask`; `send()` forwards data only in this state.
    Open = 1,
    /// Set by `close()`.
    Closing = 2,
    /// Set by `CloseTask`; terminal.
    Closed = 3,
}
56
/// WebSocket close status codes (RFC 6455, section 7.4.1).
// Only NORMAL, NO_STATUS and ABNORMAL are referenced in this file; the rest are
// kept so the table stays complete, which is why dead code is allowed here.
#[allow(dead_code)]
mod close_code {
    pub(crate) const NORMAL: u16 = 1000;
    pub(crate) const GOING_AWAY: u16 = 1001;
    pub(crate) const PROTOCOL_ERROR: u16 = 1002;
    pub(crate) const UNSUPPORTED_DATATYPE: u16 = 1003;
    pub(crate) const NO_STATUS: u16 = 1005;
    pub(crate) const ABNORMAL: u16 = 1006;
    pub(crate) const INVALID_PAYLOAD: u16 = 1007;
    pub(crate) const POLICY_VIOLATION: u16 = 1008;
    pub(crate) const TOO_LARGE: u16 = 1009;
    pub(crate) const EXTENSION_MISSING: u16 = 1010;
    pub(crate) const INTERNAL_ERROR: u16 = 1011;
    pub(crate) const TLS_FAILED: u16 = 1015;
}
74
75fn close_the_websocket_connection(
76 address: Trusted<WebSocket>,
77 task_source: &SendableTaskSource,
78 code: Option<u16>,
79 reason: String,
80) {
81 task_source.queue(CloseTask {
82 address,
83 failed: false,
84 code,
85 reason: Some(reason),
86 });
87}
88
89fn fail_the_websocket_connection(address: Trusted<WebSocket>, task_source: &SendableTaskSource) {
90 task_source.queue(CloseTask {
91 address,
92 failed: true,
93 code: Some(close_code::ABNORMAL),
94 reason: None,
95 });
96}
97
/// DOM-side state of a `WebSocket` object. The network connection is not owned
/// here; this object communicates with it through `sender`, and network events
/// come back as tasks routed in the constructor.
#[dom_struct]
pub(crate) struct WebSocket {
    eventtarget: EventTarget,
    /// The connection URL after the constructor's scheme normalization
    /// (`http` -> `ws`, `https` -> `wss`).
    #[no_trace]
    url: ServoUrl,
    /// Current `readyState`.
    ready_state: Cell<WebSocketRequestState>,
    /// Bytes counted by `send()` and not yet reset by a `BufferedAmountTask`.
    buffered_amount: Cell<u64>,
    /// True while a `BufferedAmountTask` is queued, so that at most one is pending.
    clearing_buffer: Cell<bool>,
    /// Channel carrying outgoing messages and close requests to the network side.
    #[ignore_malloc_size_of = "Defined in std"]
    #[no_trace]
    sender: IpcSender<WebSocketDomAction>,
    /// Whether received binary messages are exposed as `Blob` or `ArrayBuffer`.
    binary_type: Cell<BinaryType>,
    /// Subprotocol reported with `ConnectionEstablished`; empty until then.
    protocol: DomRefCell<String>,
}
112
113impl WebSocket {
114 fn new_inherited(url: ServoUrl, sender: IpcSender<WebSocketDomAction>) -> WebSocket {
115 WebSocket {
116 eventtarget: EventTarget::new_inherited(),
117 url,
118 ready_state: Cell::new(WebSocketRequestState::Connecting),
119 buffered_amount: Cell::new(0),
120 clearing_buffer: Cell::new(false),
121 sender,
122 binary_type: Cell::new(BinaryType::Blob),
123 protocol: DomRefCell::new("".to_owned()),
124 }
125 }
126
127 fn new(
128 global: &GlobalScope,
129 proto: Option<HandleObject>,
130 url: ServoUrl,
131 sender: IpcSender<WebSocketDomAction>,
132 can_gc: CanGc,
133 ) -> DomRoot<WebSocket> {
134 reflect_dom_object_with_proto(
135 Box::new(WebSocket::new_inherited(url, sender)),
136 global,
137 proto,
138 can_gc,
139 )
140 }
141
142 fn send_impl(&self, data_byte_len: u64) -> Fallible<bool> {
144 let return_after_buffer = match self.ready_state.get() {
145 WebSocketRequestState::Connecting => {
146 return Err(Error::InvalidState);
147 },
148 WebSocketRequestState::Open => false,
149 WebSocketRequestState::Closing | WebSocketRequestState::Closed => true,
150 };
151
152 let address = Trusted::new(self);
153
154 match data_byte_len.checked_add(self.buffered_amount.get()) {
155 None => panic!(),
156 Some(new_amount) => self.buffered_amount.set(new_amount),
157 };
158
159 if return_after_buffer {
160 return Ok(false);
161 }
162
163 if !self.clearing_buffer.get() && self.ready_state.get() == WebSocketRequestState::Open {
164 self.clearing_buffer.set(true);
165
166 self.global()
168 .task_manager()
169 .websocket_task_source()
170 .queue_unconditionally(BufferedAmountTask { address });
171 }
172
173 Ok(true)
174 }
175
176 pub(crate) fn origin(&self) -> ImmutableOrigin {
177 self.url.origin()
178 }
179}
180
181impl WebSocketMethods<crate::DomTypeHolder> for WebSocket {
182 fn Constructor(
184 global: &GlobalScope,
185 proto: Option<HandleObject>,
186 can_gc: CanGc,
187 url: DOMString,
188 protocols: Option<StringOrStringSequence>,
189 ) -> Fallible<DomRoot<WebSocket>> {
190 let mut url_record = ServoUrl::parse(&url).or(Err(Error::Syntax))?;
194
195 match url_record.scheme() {
199 "http" => {
200 url_record
201 .as_mut_url()
202 .set_scheme("ws")
203 .expect("Can't set scheme from http to ws");
204 },
205 "https" => {
206 url_record
207 .as_mut_url()
208 .set_scheme("wss")
209 .expect("Can't set scheme from https to wss");
210 },
211 "ws" | "wss" => {},
212 _ => return Err(Error::Syntax),
213 }
214
215 if url_record.fragment().is_some() {
217 return Err(Error::Syntax);
218 }
219
220 let protocols = protocols.map_or(vec![], |p| match p {
222 StringOrStringSequence::String(string) => vec![string.into()],
223 StringOrStringSequence::StringSequence(seq) => {
224 seq.into_iter().map(String::from).collect()
225 },
226 });
227
228 for (i, protocol) in protocols.iter().enumerate() {
232 if protocols[i + 1..]
236 .iter()
237 .any(|p| p.eq_ignore_ascii_case(protocol))
238 {
239 return Err(Error::Syntax);
240 }
241
242 if !is_token(protocol.as_bytes()) {
244 return Err(Error::Syntax);
245 }
246 }
247
248 let (dom_action_sender, resource_action_receiver): (
250 IpcSender<WebSocketDomAction>,
251 IpcReceiver<WebSocketDomAction>,
252 ) = ipc::channel().unwrap();
253 let (resource_event_sender, dom_event_receiver): (
254 IpcSender<WebSocketNetworkEvent>,
255 ProfiledIpc::IpcReceiver<WebSocketNetworkEvent>,
256 ) = ProfiledIpc::channel(global.time_profiler_chan().clone()).unwrap();
257
258 let ws = WebSocket::new(global, proto, url_record.clone(), dom_action_sender, can_gc);
260 let address = Trusted::new(&*ws);
261
262 let request = RequestBuilder::new(global.webview_id(), url_record, Referrer::NoReferrer)
263 .origin(global.origin().immutable().clone())
264 .insecure_requests_policy(global.insecure_requests_policy())
265 .has_trustworthy_ancestor_origin(global.has_trustworthy_ancestor_or_current_origin())
266 .mode(RequestMode::WebSocket { protocols })
267 .service_workers_mode(ServiceWorkersMode::None)
268 .credentials_mode(CredentialsMode::Include)
269 .cache_mode(CacheMode::NoCache)
270 .policy_container(global.policy_container())
271 .redirect_mode(RedirectMode::Error);
272
273 let channels = FetchChannels::WebSocket {
274 event_sender: resource_event_sender,
275 action_receiver: resource_action_receiver,
276 };
277 let _ = global
278 .core_resource_thread()
279 .send(CoreResourceMsg::Fetch(request, channels));
280
281 let task_source = global.task_manager().websocket_task_source().to_sendable();
282 ROUTER.add_typed_route(
283 dom_event_receiver.to_ipc_receiver(),
284 Box::new(move |message| match message.unwrap() {
285 WebSocketNetworkEvent::ReportCSPViolations(violations) => {
286 let task = ReportCSPViolationTask {
287 websocket: address.clone(),
288 violations,
289 };
290 task_source.queue(task);
291 },
292 WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use } => {
293 let open_thread = ConnectionEstablishedTask {
294 address: address.clone(),
295 protocol_in_use,
296 };
297 task_source.queue(open_thread);
298 },
299 WebSocketNetworkEvent::MessageReceived(message) => {
300 let message_thread = MessageReceivedTask {
301 address: address.clone(),
302 message,
303 };
304 task_source.queue(message_thread);
305 },
306 WebSocketNetworkEvent::Fail => {
307 fail_the_websocket_connection(address.clone(), &task_source);
308 },
309 WebSocketNetworkEvent::Close(code, reason) => {
310 close_the_websocket_connection(address.clone(), &task_source, code, reason);
311 },
312 }),
313 );
314
315 Ok(ws)
316 }
317
318 event_handler!(open, GetOnopen, SetOnopen);
320
321 event_handler!(close, GetOnclose, SetOnclose);
323
324 event_handler!(error, GetOnerror, SetOnerror);
326
327 event_handler!(message, GetOnmessage, SetOnmessage);
329
330 fn Url(&self) -> DOMString {
332 DOMString::from(self.url.as_str())
333 }
334
335 fn ReadyState(&self) -> u16 {
337 self.ready_state.get() as u16
338 }
339
340 fn BufferedAmount(&self) -> u64 {
342 self.buffered_amount.get()
343 }
344
345 fn BinaryType(&self) -> BinaryType {
347 self.binary_type.get()
348 }
349
350 fn SetBinaryType(&self, btype: BinaryType) {
352 self.binary_type.set(btype)
353 }
354
355 fn Protocol(&self) -> DOMString {
357 DOMString::from(self.protocol.borrow().clone())
358 }
359
360 fn Send(&self, data: USVString) -> ErrorResult {
362 let data_byte_len = data.0.len() as u64;
363 let send_data = self.send_impl(data_byte_len)?;
364
365 if send_data {
366 let _ = self
367 .sender
368 .send(WebSocketDomAction::SendMessage(MessageData::Text(data.0)));
369 }
370
371 Ok(())
372 }
373
374 fn Send_(&self, blob: &Blob) -> ErrorResult {
376 let data_byte_len = blob.Size();
381 let send_data = self.send_impl(data_byte_len)?;
382
383 if send_data {
384 let bytes = blob.get_bytes().unwrap_or_default();
385 let _ = self
386 .sender
387 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
388 }
389
390 Ok(())
391 }
392
393 fn Send__(&self, array: CustomAutoRooterGuard<ArrayBuffer>) -> ErrorResult {
395 let bytes = array.to_vec();
396 let data_byte_len = bytes.len();
397 let send_data = self.send_impl(data_byte_len as u64)?;
398
399 if send_data {
400 let _ = self
401 .sender
402 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
403 }
404 Ok(())
405 }
406
407 fn Send___(&self, array: CustomAutoRooterGuard<ArrayBufferView>) -> ErrorResult {
409 let bytes = array.to_vec();
410 let data_byte_len = bytes.len();
411 let send_data = self.send_impl(data_byte_len as u64)?;
412
413 if send_data {
414 let _ = self
415 .sender
416 .send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
417 }
418 Ok(())
419 }
420
421 fn Close(&self, code: Option<u16>, reason: Option<USVString>) -> ErrorResult {
423 if let Some(code) = code {
424 if code != close_code::NORMAL && !(3000..=4999).contains(&code) {
426 return Err(Error::InvalidAccess);
427 }
428 }
429 if let Some(ref reason) = reason {
430 if reason.0.len() > 123 {
431 return Err(Error::Syntax);
433 }
434 }
435
436 match self.ready_state.get() {
437 WebSocketRequestState::Closing | WebSocketRequestState::Closed => {}, WebSocketRequestState::Connecting => {
439 self.ready_state.set(WebSocketRequestState::Closing);
443
444 fail_the_websocket_connection(
445 Trusted::new(self),
446 &self
447 .global()
448 .task_manager()
449 .websocket_task_source()
450 .to_sendable(),
451 );
452 },
453 WebSocketRequestState::Open => {
454 self.ready_state.set(WebSocketRequestState::Closing);
455
456 let reason = reason.map(|reason| reason.0);
459 let _ = self.sender.send(WebSocketDomAction::Close(code, reason));
460 },
461 }
462 Ok(()) }
464}
465
/// Task that reports CSP violations sent by the network layer to the WebSocket's global.
struct ReportCSPViolationTask {
    /// The socket whose global receives the report.
    websocket: Trusted<WebSocket>,
    /// Violations reported via `WebSocketNetworkEvent::ReportCSPViolations`.
    violations: Vec<Violation>,
}
470
471impl TaskOnce for ReportCSPViolationTask {
472 fn run_once(self) {
473 let global = self.websocket.root().global();
474 global.report_csp_violations(self.violations, None, None);
475 }
476}
477
/// Task queued when the network layer reports the connection as established.
struct ConnectionEstablishedTask {
    /// The socket being opened.
    address: Trusted<WebSocket>,
    /// Subprotocol in use, if one was reported.
    protocol_in_use: Option<String>,
}
484
485impl TaskOnce for ConnectionEstablishedTask {
486 fn run_once(self) {
488 let ws = self.address.root();
489
490 ws.ready_state.set(WebSocketRequestState::Open);
492
493 if let Some(protocol_name) = self.protocol_in_use {
498 *ws.protocol.borrow_mut() = protocol_name;
499 };
500
501 ws.upcast().fire_event(atom!("open"), CanGc::note());
503 }
504}
505
/// Task queued by `send_impl` (at most one at a time) to reset `bufferedAmount`.
struct BufferedAmountTask {
    /// The socket whose counter is reset.
    address: Trusted<WebSocket>,
}
509
510impl TaskOnce for BufferedAmountTask {
511 fn run_once(self) {
517 let ws = self.address.root();
518
519 ws.buffered_amount.set(0);
520 ws.clearing_buffer.set(false);
521 }
522}
523
/// Task that finishes closing a socket and dispatches its `error`/`close` events.
struct CloseTask {
    /// The socket being closed.
    address: Trusted<WebSocket>,
    /// True when the connection failed; fires `error` and makes `wasClean` false.
    failed: bool,
    /// Close code for the `close` event; `None` becomes 1005 (NO_STATUS).
    code: Option<u16>,
    /// Close reason for the `close` event; `None` becomes the empty string.
    reason: Option<String>,
}
530
531impl TaskOnce for CloseTask {
532 fn run_once(self) {
533 let ws = self.address.root();
534
535 if ws.ready_state.get() == WebSocketRequestState::Closed {
536 return;
538 }
539
540 ws.ready_state.set(WebSocketRequestState::Closed);
545
546 if self.failed {
548 ws.upcast().fire_event(atom!("error"), CanGc::note());
549 }
550
551 let clean_close = !self.failed;
553 let code = self.code.unwrap_or(close_code::NO_STATUS);
554 let reason = DOMString::from(self.reason.unwrap_or("".to_owned()));
555 let close_event = CloseEvent::new(
556 &ws.global(),
557 atom!("close"),
558 EventBubbles::DoesNotBubble,
559 EventCancelable::NotCancelable,
560 clean_close,
561 code,
562 reason,
563 CanGc::note(),
564 );
565 close_event
566 .upcast::<Event>()
567 .fire(ws.upcast(), CanGc::note());
568 }
569}
570
/// Task that delivers one message received from the network as a `message` event.
struct MessageReceivedTask {
    /// The receiving socket.
    address: Trusted<WebSocket>,
    /// Text or binary payload.
    message: MessageData,
}
575
impl TaskOnce for MessageReceivedTask {
    /// Dispatches the received message as a `message` event on the socket.
    ///
    /// The message is dropped unless the socket is OPEN. Text becomes a JS
    /// string. Binary data becomes a `Blob` or an `ArrayBuffer`, according to the
    /// socket's current `binaryType`. The event's origin is the serialized
    /// origin of the socket's URL.
    #[allow(unsafe_code)]
    fn run_once(self) {
        let ws = self.address.root();
        debug!(
            "MessageReceivedTask::handler({:p}): readyState={:?}",
            &*ws,
            ws.ready_state.get()
        );

        // Messages that arrive after close() (or before open) are not delivered.
        if ws.ready_state.get() != WebSocketRequestState::Open {
            return;
        }

        let global = ws.global();
        let cx = GlobalScope::get_cx();
        // Enter the realm of the socket's reflector before creating JS values below.
        let _ac = JSAutoRealm::new(*cx, ws.reflector().get_jsobject().get());
        rooted!(in(*cx) let mut message = UndefinedValue());
        match self.message {
            MessageData::Text(text) => text.safe_to_jsval(cx, message.handle_mut()),
            MessageData::Binary(data) => match ws.binary_type.get() {
                BinaryType::Blob => {
                    let blob = Blob::new(
                        &global,
                        BlobImpl::new_from_bytes(data, "".to_owned()),
                        CanGc::note(),
                    );
                    blob.safe_to_jsval(cx, message.handle_mut());
                },
                BinaryType::Arraybuffer => {
                    rooted!(in(*cx) let mut array_buffer = ptr::null_mut::<JSObject>());
                    // SAFETY: `cx` comes from `GlobalScope::get_cx()`, the realm was
                    // entered above, and `array_buffer` is a rooted handle that
                    // receives the new object. Creation failure panics via the
                    // assert, so a null object is never converted below.
                    unsafe {
                        assert!(
                            ArrayBuffer::create(
                                *cx,
                                CreateWith::Slice(&data),
                                array_buffer.handle_mut()
                            )
                            .is_ok()
                        )
                    };

                    (*array_buffer).safe_to_jsval(cx, message.handle_mut());
                },
            },
        }
        MessageEvent::dispatch_jsval(
            ws.upcast(),
            &global,
            message.handle(),
            Some(&ws.origin().ascii_serialization()),
            None,
            vec![],
            CanGc::note(),
        );
    }
}