use std::borrow::ToOwned;
use std::cell::Cell;
use std::ptr;
use dom_struct::dom_struct;
use ipc_channel::ipc::{self, IpcReceiver, IpcSender};
use ipc_channel::router::ROUTER;
use js::jsapi::{JSAutoRealm, JSObject};
use js::jsval::UndefinedValue;
use js::rust::{CustomAutoRooterGuard, HandleObject};
use js::typedarray::{ArrayBuffer, ArrayBufferView, CreateWith};
use net_traits::request::{Referrer, RequestBuilder, RequestMode};
use net_traits::{
CoreResourceMsg, FetchChannels, MessageData, WebSocketDomAction, WebSocketNetworkEvent,
};
use profile_traits::ipc as ProfiledIpc;
use script_traits::serializable::BlobImpl;
use servo_url::{ImmutableOrigin, ServoUrl};
use crate::dom::bindings::cell::DomRefCell;
use crate::dom::bindings::codegen::Bindings::BlobBinding::BlobMethods;
use crate::dom::bindings::codegen::Bindings::WebSocketBinding::{BinaryType, WebSocketMethods};
use crate::dom::bindings::codegen::UnionTypes::StringOrStringSequence;
use crate::dom::bindings::conversions::ToJSValConvertible;
use crate::dom::bindings::error::{Error, ErrorResult, Fallible};
use crate::dom::bindings::inheritance::Castable;
use crate::dom::bindings::refcounted::Trusted;
use crate::dom::bindings::reflector::{reflect_dom_object_with_proto, DomObject};
use crate::dom::bindings::root::DomRoot;
use crate::dom::bindings::str::{is_token, DOMString, USVString};
use crate::dom::blob::Blob;
use crate::dom::closeevent::CloseEvent;
use crate::dom::event::{Event, EventBubbles, EventCancelable};
use crate::dom::eventtarget::EventTarget;
use crate::dom::globalscope::GlobalScope;
use crate::dom::messageevent::MessageEvent;
use crate::script_runtime::ScriptThreadEventCategory::WebSocketEvent;
use crate::script_runtime::{CanGc, CommonScriptMsg};
use crate::task::{TaskCanceller, TaskOnce};
use crate::task_source::websocket::WebsocketTaskSource;
use crate::task_source::TaskSource;
/// Lifecycle state of a [`WebSocket`], exposed to script through `readyState`.
///
/// The explicit discriminants are the values `ReadyState()` returns (via an `as u16` cast),
/// so they must stay in sync with the IDL constants CONNECTING/OPEN/CLOSING/CLOSED.
#[derive(Clone, Copy, Debug, PartialEq, JSTraceable, MallocSizeOf)]
enum WebSocketRequestState {
    /// The connection has been requested but not yet established.
    Connecting = 0,
    /// The connection is established; messages may be sent and received.
    Open = 1,
    /// `close()` has been called and the closing handshake is in progress.
    Closing = 2,
    /// The connection is closed (cleanly or after a failure).
    Closed = 3,
}
// Not every status code is referenced by this file; the full table is kept for reference
// and completeness, hence the `dead_code` allowance.
#[allow(dead_code)]
mod close_code {
    //! WebSocket close status codes (RFC 6455, section 7.4.1).

    /// 1000: normal closure; the purpose of the connection has been fulfilled.
    pub const NORMAL: u16 = 1000;
    /// 1001: an endpoint is going away (e.g. server shutdown, page navigation).
    pub const GOING_AWAY: u16 = 1001;
    /// 1002: the connection is terminated because of a protocol error.
    pub const PROTOCOL_ERROR: u16 = 1002;
    /// 1003: an endpoint received a type of data it cannot accept.
    pub const UNSUPPORTED_DATATYPE: u16 = 1003;

    /// 1005: reserved; stands for "no status code was present" and is never sent on the wire.
    pub const NO_STATUS: u16 = 1005;
    /// 1006: reserved; stands for "closed abnormally, without a close frame".
    pub const ABNORMAL: u16 = 1006;

    /// 1007: a message contained data inconsistent with its type (e.g. non-UTF-8 text).
    pub const INVALID_PAYLOAD: u16 = 1007;
    /// 1008: a message violated the endpoint's policy.
    pub const POLICY_VIOLATION: u16 = 1008;
    /// 1009: a message was too big to process.
    pub const TOO_LARGE: u16 = 1009;
    /// 1010: the client expected the server to negotiate an extension it did not.
    pub const EXTENSION_MISSING: u16 = 1010;
    /// 1011: the server hit an unexpected condition.
    pub const INTERNAL_ERROR: u16 = 1011;

    /// 1015: reserved; stands for "TLS handshake failed".
    pub const TLS_FAILED: u16 = 1015;
}
/// Queues a non-failure [`CloseTask`] for the socket at `address`.
///
/// The task carries the close `code` (if any) and `reason` reported for the connection and
/// is queued on `task_source`, cancellable through `canceller`. Any error from queueing is
/// deliberately ignored.
fn close_the_websocket_connection(
    address: Trusted<WebSocket>,
    task_source: &WebsocketTaskSource,
    canceller: &TaskCanceller,
    code: Option<u16>,
    reason: String,
) {
    let _ = task_source.queue_with_canceller(
        CloseTask {
            address,
            failed: false,
            code,
            reason: Some(reason),
        },
        canceller,
    );
}
/// Queues a failure [`CloseTask`] for the socket at `address`.
///
/// The task is marked as failed, so it fires `error` before `close`. It uses the
/// `ABNORMAL` (1006) close code and carries no reason. Any error from queueing is
/// deliberately ignored.
fn fail_the_websocket_connection(
    address: Trusted<WebSocket>,
    task_source: &WebsocketTaskSource,
    canceller: &TaskCanceller,
) {
    let failure = CloseTask {
        reason: None,
        code: Some(close_code::ABNORMAL),
        failed: true,
        address,
    };
    let _ = task_source.queue_with_canceller(failure, canceller);
}
/// DOM object backing the script-visible `WebSocket` interface.
#[dom_struct]
pub struct WebSocket {
    eventtarget: EventTarget,

    /// The parsed `ws:`/`wss:` URL this socket was created with.
    #[no_trace]
    url: ServoUrl,

    /// Current lifecycle state, reported by `readyState`.
    ready_state: Cell<WebSocketRequestState>,

    /// Bytes accepted by `send()` and counted towards `bufferedAmount`.
    buffered_amount: Cell<u64>,

    /// True while a `BufferedAmountTask` is pending to reset `buffered_amount`.
    clearing_buffer: Cell<bool>,

    /// Channel for forwarding DOM actions (outgoing messages, close requests) to the
    /// resource side of the connection.
    #[ignore_malloc_size_of = "Defined in std"]
    #[no_trace]
    sender: IpcSender<WebSocketDomAction>,

    /// How received binary messages are exposed to script (`Blob` or `ArrayBuffer`).
    binary_type: Cell<BinaryType>,

    /// Subprotocol selected by the server; empty until the connection is established.
    protocol: DomRefCell<String>,
}
impl WebSocket {
fn new_inherited(url: ServoUrl, sender: IpcSender<WebSocketDomAction>) -> WebSocket {
WebSocket {
eventtarget: EventTarget::new_inherited(),
url,
ready_state: Cell::new(WebSocketRequestState::Connecting),
buffered_amount: Cell::new(0),
clearing_buffer: Cell::new(false),
sender,
binary_type: Cell::new(BinaryType::Blob),
protocol: DomRefCell::new("".to_owned()),
}
}
fn new(
global: &GlobalScope,
proto: Option<HandleObject>,
url: ServoUrl,
sender: IpcSender<WebSocketDomAction>,
can_gc: CanGc,
) -> DomRoot<WebSocket> {
reflect_dom_object_with_proto(
Box::new(WebSocket::new_inherited(url, sender)),
global,
proto,
can_gc,
)
}
fn send_impl(&self, data_byte_len: u64) -> Fallible<bool> {
let return_after_buffer = match self.ready_state.get() {
WebSocketRequestState::Connecting => {
return Err(Error::InvalidState);
},
WebSocketRequestState::Open => false,
WebSocketRequestState::Closing | WebSocketRequestState::Closed => true,
};
let address = Trusted::new(self);
match data_byte_len.checked_add(self.buffered_amount.get()) {
None => panic!(),
Some(new_amount) => self.buffered_amount.set(new_amount),
};
if return_after_buffer {
return Ok(false);
}
if !self.clearing_buffer.get() && self.ready_state.get() == WebSocketRequestState::Open {
self.clearing_buffer.set(true);
let task = Box::new(BufferedAmountTask { address });
let pipeline_id = self.global().pipeline_id();
self.global()
.script_chan()
.send(CommonScriptMsg::Task(
WebSocketEvent,
task,
Some(pipeline_id),
WebsocketTaskSource::NAME,
))
.unwrap();
}
Ok(true)
}
pub fn origin(&self) -> ImmutableOrigin {
self.url.origin()
}
}
impl WebSocketMethods<crate::DomTypeHolder> for WebSocket {
fn Constructor(
global: &GlobalScope,
proto: Option<HandleObject>,
can_gc: CanGc,
url: DOMString,
protocols: Option<StringOrStringSequence>,
) -> Fallible<DomRoot<WebSocket>> {
let url_record = ServoUrl::parse(&url).or(Err(Error::Syntax))?;
match url_record.scheme() {
"ws" | "wss" => {},
_ => return Err(Error::Syntax),
}
if url_record.fragment().is_some() {
return Err(Error::Syntax);
}
let protocols = protocols.map_or(vec![], |p| match p {
StringOrStringSequence::String(string) => vec![string.into()],
StringOrStringSequence::StringSequence(seq) => {
seq.into_iter().map(String::from).collect()
},
});
for (i, protocol) in protocols.iter().enumerate() {
if protocols[i + 1..]
.iter()
.any(|p| p.eq_ignore_ascii_case(protocol))
{
return Err(Error::Syntax);
}
if !is_token(protocol.as_bytes()) {
return Err(Error::Syntax);
}
}
let (dom_action_sender, resource_action_receiver): (
IpcSender<WebSocketDomAction>,
IpcReceiver<WebSocketDomAction>,
) = ipc::channel().unwrap();
let (resource_event_sender, dom_event_receiver): (
IpcSender<WebSocketNetworkEvent>,
ProfiledIpc::IpcReceiver<WebSocketNetworkEvent>,
) = ProfiledIpc::channel(global.time_profiler_chan().clone()).unwrap();
let ws = WebSocket::new(global, proto, url_record.clone(), dom_action_sender, can_gc);
let address = Trusted::new(&*ws);
let request = RequestBuilder::new(url_record, Referrer::NoReferrer)
.origin(global.origin().immutable().clone())
.mode(RequestMode::WebSocket { protocols });
let channels = FetchChannels::WebSocket {
event_sender: resource_event_sender,
action_receiver: resource_action_receiver,
};
let _ = global
.core_resource_thread()
.send(CoreResourceMsg::Fetch(request, channels));
let task_source = global.websocket_task_source();
let canceller = global.task_canceller(WebsocketTaskSource::NAME);
ROUTER.add_typed_route(
dom_event_receiver.to_ipc_receiver(),
Box::new(move |message| match message.unwrap() {
WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use } => {
let open_thread = ConnectionEstablishedTask {
address: address.clone(),
protocol_in_use,
};
let _ = task_source.queue_with_canceller(open_thread, &canceller);
},
WebSocketNetworkEvent::MessageReceived(message) => {
let message_thread = MessageReceivedTask {
address: address.clone(),
message,
};
let _ = task_source.queue_with_canceller(message_thread, &canceller);
},
WebSocketNetworkEvent::Fail => {
fail_the_websocket_connection(address.clone(), &task_source, &canceller);
},
WebSocketNetworkEvent::Close(code, reason) => {
close_the_websocket_connection(
address.clone(),
&task_source,
&canceller,
code,
reason,
);
},
}),
);
Ok(ws)
}
event_handler!(open, GetOnopen, SetOnopen);
event_handler!(close, GetOnclose, SetOnclose);
event_handler!(error, GetOnerror, SetOnerror);
event_handler!(message, GetOnmessage, SetOnmessage);
fn Url(&self) -> DOMString {
DOMString::from(self.url.as_str())
}
fn ReadyState(&self) -> u16 {
self.ready_state.get() as u16
}
fn BufferedAmount(&self) -> u64 {
self.buffered_amount.get()
}
fn BinaryType(&self) -> BinaryType {
self.binary_type.get()
}
fn SetBinaryType(&self, btype: BinaryType) {
self.binary_type.set(btype)
}
fn Protocol(&self) -> DOMString {
DOMString::from(self.protocol.borrow().clone())
}
fn Send(&self, data: USVString) -> ErrorResult {
let data_byte_len = data.0.as_bytes().len() as u64;
let send_data = self.send_impl(data_byte_len)?;
if send_data {
let _ = self
.sender
.send(WebSocketDomAction::SendMessage(MessageData::Text(data.0)));
}
Ok(())
}
fn Send_(&self, blob: &Blob) -> ErrorResult {
let data_byte_len = blob.Size();
let send_data = self.send_impl(data_byte_len)?;
if send_data {
let bytes = blob.get_bytes().unwrap_or_default();
let _ = self
.sender
.send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
}
Ok(())
}
fn Send__(&self, array: CustomAutoRooterGuard<ArrayBuffer>) -> ErrorResult {
let bytes = array.to_vec();
let data_byte_len = bytes.len();
let send_data = self.send_impl(data_byte_len as u64)?;
if send_data {
let _ = self
.sender
.send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
}
Ok(())
}
fn Send___(&self, array: CustomAutoRooterGuard<ArrayBufferView>) -> ErrorResult {
let bytes = array.to_vec();
let data_byte_len = bytes.len();
let send_data = self.send_impl(data_byte_len as u64)?;
if send_data {
let _ = self
.sender
.send(WebSocketDomAction::SendMessage(MessageData::Binary(bytes)));
}
Ok(())
}
fn Close(&self, code: Option<u16>, reason: Option<USVString>) -> ErrorResult {
if let Some(code) = code {
if code != close_code::NORMAL && !(3000..=4999).contains(&code) {
return Err(Error::InvalidAccess);
}
}
if let Some(ref reason) = reason {
if reason.0.as_bytes().len() > 123 {
return Err(Error::Syntax);
}
}
match self.ready_state.get() {
WebSocketRequestState::Closing | WebSocketRequestState::Closed => {}, WebSocketRequestState::Connecting => {
self.ready_state.set(WebSocketRequestState::Closing);
let address = Trusted::new(self);
let task_source = self.global().websocket_task_source();
fail_the_websocket_connection(
address,
&task_source,
&self.global().task_canceller(WebsocketTaskSource::NAME),
);
},
WebSocketRequestState::Open => {
self.ready_state.set(WebSocketRequestState::Closing);
let reason = reason.map(|reason| reason.0);
let _ = self.sender.send(WebSocketDomAction::Close(code, reason));
},
}
Ok(()) }
}
/// Task queued when the network side reports that the connection was established.
struct ConnectionEstablishedTask {
    /// The socket the notification is for.
    address: Trusted<WebSocket>,
    /// Subprotocol the server selected, if any.
    protocol_in_use: Option<String>,
}
impl TaskOnce for ConnectionEstablishedTask {
    /// Runs the "WebSocket connection is established" steps: moves the socket to OPEN,
    /// records the negotiated subprotocol (if any) and fires `open`.
    ///
    /// The steps are skipped when the socket is no longer CONNECTING. `close()` on a
    /// connecting socket moves it to CLOSING and queues a failure `CloseTask`, which may
    /// already have moved it to CLOSED. An establishment notification arriving after that
    /// must not reopen the socket or fire `open` for a connection that was failed.
    fn run_once(self) {
        let ws = self.address.root();
        if ws.ready_state.get() != WebSocketRequestState::Connecting {
            return;
        }

        ws.ready_state.set(WebSocketRequestState::Open);
        if let Some(protocol_name) = self.protocol_in_use {
            *ws.protocol.borrow_mut() = protocol_name;
        }
        ws.upcast().fire_event(atom!("open"), CanGc::note());
    }
}
/// Task that resets `bufferedAmount` once data handed to `send()` has been dispatched.
struct BufferedAmountTask {
    /// The socket whose buffered amount is reset.
    address: Trusted<WebSocket>,
}
impl TaskOnce for BufferedAmountTask {
fn run_once(self) {
let ws = self.address.root();
ws.buffered_amount.set(0);
ws.clearing_buffer.set(false);
}
}
/// Task that performs the "WebSocket connection is closed" steps.
struct CloseTask {
    /// The socket being closed.
    address: Trusted<WebSocket>,
    /// Whether the connection failed; if so, `error` is fired before `close`.
    failed: bool,
    /// Close code for the `close` event; `None` is reported as `NO_STATUS` (1005).
    code: Option<u16>,
    /// Close reason for the `close` event; `None` is reported as the empty string.
    reason: Option<String>,
}
impl TaskOnce for CloseTask {
fn run_once(self) {
let ws = self.address.root();
if ws.ready_state.get() == WebSocketRequestState::Closed {
return;
}
ws.ready_state.set(WebSocketRequestState::Closed);
if self.failed {
ws.upcast().fire_event(atom!("error"), CanGc::note());
}
let clean_close = !self.failed;
let code = self.code.unwrap_or(close_code::NO_STATUS);
let reason = DOMString::from(self.reason.unwrap_or("".to_owned()));
let close_event = CloseEvent::new(
&ws.global(),
atom!("close"),
EventBubbles::DoesNotBubble,
EventCancelable::NotCancelable,
clean_close,
code,
reason,
CanGc::note(),
);
close_event
.upcast::<Event>()
.fire(ws.upcast(), CanGc::note());
}
}
/// Task that delivers a message received from the network as a `message` event.
struct MessageReceivedTask {
    /// The socket the message arrived on.
    address: Trusted<WebSocket>,
    /// The received payload, either text or binary.
    message: MessageData,
}
impl TaskOnce for MessageReceivedTask {
    /// Converts the received payload to a JS value and dispatches it as a `message` event.
    ///
    /// Text payloads are converted directly. Binary payloads become a `Blob` or an
    /// `ArrayBuffer`, depending on the socket's current `binaryType`. The event's origin is
    /// the ASCII serialization of the socket URL's origin. Messages arriving when the
    /// socket is not OPEN are dropped.
    ///
    /// # Panics
    ///
    /// Panics if creating the `ArrayBuffer` fails (the result is `assert!`ed).
    // `unsafe_code` is allowed because the JS conversion below runs inside an `unsafe` block.
    #[allow(unsafe_code)]
    fn run_once(self) {
        let ws = self.address.root();
        debug!(
            "MessageReceivedTask::handler({:p}): readyState={:?}",
            &*ws,
            ws.ready_state.get()
        );
        // Only an OPEN socket delivers messages; anything else silently drops the payload.
        if ws.ready_state.get() != WebSocketRequestState::Open {
            return;
        }
        let global = ws.global();
        // NOTE(review): the whole conversion runs in one unsafe block. Soundness presumably
        // relies on this task running on the script thread with a live JSContext, and on
        // every JS value below being rooted. TODO: confirm and narrow the block.
        unsafe {
            let cx = GlobalScope::get_cx();
            // Enter the realm of this socket's reflector so the created JS values belong to it.
            let _ac = JSAutoRealm::new(*cx, ws.reflector().get_jsobject().get());
            // Rooted slot that receives the event's `data` value.
            rooted!(in(*cx) let mut message = UndefinedValue());
            match self.message {
                MessageData::Text(text) => text.to_jsval(*cx, message.handle_mut()),
                // The `binaryType` at delivery time picks the JS representation.
                MessageData::Binary(data) => match ws.binary_type.get() {
                    BinaryType::Blob => {
                        // Blob with an empty type string.
                        let blob = Blob::new(
                            &global,
                            BlobImpl::new_from_bytes(data, "".to_owned()),
                            CanGc::note(),
                        );
                        blob.to_jsval(*cx, message.handle_mut());
                    },
                    BinaryType::Arraybuffer => {
                        // Rooted slot for the new ArrayBuffer; a copy of `data` is created
                        // into it, and failure panics via the assert.
                        rooted!(in(*cx) let mut array_buffer = ptr::null_mut::<JSObject>());
                        assert!(ArrayBuffer::create(
                            *cx,
                            CreateWith::Slice(&data),
                            array_buffer.handle_mut()
                        )
                        .is_ok());
                        (*array_buffer).to_jsval(*cx, message.handle_mut());
                    },
                },
            }
            // Fire `message` at the socket, with the URL origin and no ports.
            MessageEvent::dispatch_jsval(
                ws.upcast(),
                &global,
                message.handle(),
                Some(&ws.origin().ascii_serialization()),
                None,
                vec![],
                CanGc::note(),
            );
        }
    }
}