1use std::mem;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, Ordering};
17
18use async_tungstenite::WebSocketStream;
19use async_tungstenite::tokio::{ConnectStream, client_async_tls_with_connector_and_config};
20use base64::Engine;
21use content_security_policy as csp;
22use futures::future::TryFutureExt;
23use futures::stream::StreamExt;
24use http::header::{self, HeaderName, HeaderValue};
25use ipc_channel::ipc::{IpcReceiver, IpcSender};
26use ipc_channel::router::ROUTER;
27use log::{debug, trace, warn};
28use net_traits::policy_container::{PolicyContainer, RequestPolicyContainer};
29use net_traits::request::{Origin, RequestBuilder, RequestMode};
30use net_traits::{CookieSource, MessageData, WebSocketDomAction, WebSocketNetworkEvent};
31use servo_url::ServoUrl;
32use tokio::net::TcpStream;
33use tokio::select;
34use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
35use tokio_rustls::TlsConnector;
36use tungstenite::Message;
37use tungstenite::error::{Error, ProtocolError, Result as WebSocketResult, UrlError};
38use tungstenite::handshake::client::{Request, Response};
39use tungstenite::protocol::CloseFrame;
40use url::Url;
41
42use crate::async_runtime::spawn_task;
43use crate::connector::{CACertificates, TlsConfig, create_tls_config};
44use crate::cookie::ServoCookie;
45use crate::fetch::methods::{
46 convert_request_to_csp_request, should_request_be_blocked_by_csp,
47 should_request_be_blocked_due_to_a_bad_port,
48};
49use crate::hosts::replace_host;
50use crate::http_loader::HttpState;
51
52#[allow(clippy::result_large_err)]
53fn create_request(
59 resource_url: &ServoUrl,
60 origin: &str,
61 protocols: &[String],
62 http_state: &HttpState,
63) -> WebSocketResult<Request> {
64 let mut builder = Request::get(resource_url.as_str());
65 let headers = builder.headers_mut().unwrap();
66 headers.insert("Origin", HeaderValue::from_str(origin)?);
67
68 let origin = resource_url.origin();
69 let host = format!(
70 "{}",
71 origin
72 .host()
73 .ok_or_else(|| Error::Url(UrlError::NoHostName))?
74 );
75 headers.insert("Host", HeaderValue::from_str(&host)?);
76 headers.insert("Connection", HeaderValue::from_static("upgrade"));
77 headers.insert("Upgrade", HeaderValue::from_static("websocket"));
78 headers.insert("Sec-Websocket-Version", HeaderValue::from_static("13"));
79
80 let key = HeaderValue::from_str(&tungstenite::handshake::client::generate_key()).unwrap();
81 headers.insert("Sec-WebSocket-Key", key);
82
83 if !protocols.is_empty() {
84 let protocols = protocols.join(",");
85 headers.insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocols)?);
86 }
87
88 let mut cookie_jar = http_state.cookie_jar.write().unwrap();
89 cookie_jar.remove_expired_cookies_for_url(resource_url);
90 if let Some(cookie_list) = cookie_jar.cookies_for_url(resource_url, CookieSource::HTTP) {
91 headers.insert("Cookie", HeaderValue::from_str(&cookie_list)?);
92 }
93
94 if resource_url.password().is_some() || resource_url.username() != "" {
95 let basic = base64::engine::general_purpose::STANDARD.encode(format!(
96 "{}:{}",
97 resource_url.username(),
98 resource_url.password().unwrap_or("")
99 ));
100 headers.insert(
101 "Authorization",
102 HeaderValue::from_str(&format!("Basic {}", basic))?,
103 );
104 }
105
106 let request = builder.body(())?;
107 Ok(request)
108}
109
110#[allow(clippy::result_large_err)]
111fn process_ws_response(
116 http_state: &HttpState,
117 response: &Response,
118 resource_url: &ServoUrl,
119 protocols: &[String],
120) -> Result<Option<String>, Error> {
121 trace!("processing websocket http response for {}", resource_url);
122 let mut protocol_in_use = None;
123 if let Some(protocol_name) = response.headers().get("Sec-WebSocket-Protocol") {
124 let protocol_name = protocol_name.to_str().unwrap_or("");
125 if !protocols.is_empty() && !protocols.iter().any(|p| protocol_name == (*p)) {
126 return Err(Error::Protocol(ProtocolError::InvalidHeader(
127 HeaderName::from_static("sec-websocket-protocol"),
128 )));
129 }
130 protocol_in_use = Some(protocol_name.to_string());
131 }
132
133 let mut jar = http_state.cookie_jar.write().unwrap();
134 for cookie in response.headers().get_all(header::SET_COOKIE) {
136 if let Ok(s) = std::str::from_utf8(cookie.as_bytes()) {
137 if let Some(cookie) =
138 ServoCookie::from_cookie_string(s.into(), resource_url, CookieSource::HTTP)
139 {
140 jar.push(cookie, resource_url, CookieSource::HTTP);
141 }
142 }
143 }
144
145 http_state
146 .hsts_list
147 .write()
148 .unwrap()
149 .update_hsts_list_from_response(resource_url, response.headers());
150
151 Ok(protocol_in_use)
152}
153
154#[derive(Debug)]
155enum DomMsg {
156 Send(Message),
157 Close(Option<(u16, String)>),
158}
159
160fn setup_dom_listener(
163 dom_action_receiver: IpcReceiver<WebSocketDomAction>,
164 initiated_close: Arc<AtomicBool>,
165) -> UnboundedReceiver<DomMsg> {
166 let (sender, receiver) = unbounded_channel();
167
168 ROUTER.add_typed_route(
169 dom_action_receiver,
170 Box::new(move |message| {
171 let dom_action = message.expect("Ws dom_action message to deserialize");
172 trace!("handling WS DOM action: {:?}", dom_action);
173 match dom_action {
174 WebSocketDomAction::SendMessage(MessageData::Text(data)) => {
175 if let Err(e) = sender.send(DomMsg::Send(Message::Text(data.into()))) {
176 warn!("Error sending websocket message: {:?}", e);
177 }
178 },
179 WebSocketDomAction::SendMessage(MessageData::Binary(data)) => {
180 if let Err(e) = sender.send(DomMsg::Send(Message::Binary(data.into()))) {
181 warn!("Error sending websocket message: {:?}", e);
182 }
183 },
184 WebSocketDomAction::Close(code, reason) => {
185 if initiated_close.fetch_or(true, Ordering::SeqCst) {
186 return;
187 }
188 let frame = code.map(move |c| (c, reason.unwrap_or_default()));
189 if let Err(e) = sender.send(DomMsg::Close(frame)) {
190 warn!("Error closing websocket: {:?}", e);
191 }
192 },
193 }
194 }),
195 );
196
197 receiver
198}
199
200async fn run_ws_loop(
205 mut dom_receiver: UnboundedReceiver<DomMsg>,
206 resource_event_sender: IpcSender<WebSocketNetworkEvent>,
207 mut stream: WebSocketStream<ConnectStream>,
208) {
209 loop {
210 select! {
211 dom_msg = dom_receiver.recv() => {
212 trace!("processing dom msg: {:?}", dom_msg);
213 let dom_msg = match dom_msg {
214 Some(msg) => msg,
215 None => break,
216 };
217 match dom_msg {
218 DomMsg::Send(m) => {
219 if let Err(e) = stream.send(m).await {
220 warn!("error sending websocket message: {:?}", e);
221 }
222 },
223 DomMsg::Close(frame) => {
224 if let Err(e) = stream.close(frame.map(|(code, reason)| {
225 CloseFrame {
226 code: code.into(),
227 reason: reason.into(),
228 }
229 })).await {
230 warn!("error closing websocket: {:?}", e);
231 }
232 },
233 }
234 }
235 ws_msg = stream.next() => {
236 trace!("processing WS stream: {:?}", ws_msg);
237 let msg = match ws_msg {
238 Some(Ok(msg)) => msg,
239 Some(Err(e)) => {
240 warn!("Error in WebSocket communication: {:?}", e);
241 let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
242 break;
243 },
244 None => {
245 warn!("Error in WebSocket communication");
246 let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
247 break;
248 }
249 };
250 match msg {
251 Message::Text(s) => {
252 let message = MessageData::Text(s.as_str().to_owned());
253 if let Err(e) = resource_event_sender
254 .send(WebSocketNetworkEvent::MessageReceived(message))
255 {
256 warn!("Error sending websocket notification: {:?}", e);
257 break;
258 }
259 }
260
261 Message::Binary(v) => {
262 let message = MessageData::Binary(v.to_vec());
263 if let Err(e) = resource_event_sender
264 .send(WebSocketNetworkEvent::MessageReceived(message))
265 {
266 warn!("Error sending websocket notification: {:?}", e);
267 break;
268 }
269 }
270
271 Message::Ping(_) | Message::Pong(_) => {}
272
273 Message::Close(frame) => {
274 let (reason, code) = match frame {
275 Some(frame) => (frame.reason, Some(frame.code.into())),
276 None => ("".into(), None),
277 };
278 debug!("Websocket connection closing due to ({:?}) {}", code, reason);
279 let _ = resource_event_sender.send(WebSocketNetworkEvent::Close(
280 code,
281 reason.to_string(),
282 ));
283 break;
284 }
285
286 Message::Frame(_) => {
287 warn!("Unexpected websocket frame message");
288 }
289 }
290 }
291 }
292 }
293}
294
295async fn start_websocket(
299 http_state: Arc<HttpState>,
300 url: ServoUrl,
301 resource_event_sender: IpcSender<WebSocketNetworkEvent>,
302 protocols: Vec<String>,
303 client: Request,
304 tls_config: TlsConfig,
305 dom_action_receiver: IpcReceiver<WebSocketDomAction>,
306) -> Result<(), Error> {
307 trace!("starting WS connection to {}", url);
308
309 let initiated_close = Arc::new(AtomicBool::new(false));
310 let dom_receiver = setup_dom_listener(dom_action_receiver, initiated_close.clone());
311
312 let host_str = client
313 .uri()
314 .host()
315 .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
316 let host = replace_host(host_str);
317 let mut net_url = Url::parse(&client.uri().to_string())
318 .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
319 net_url
320 .set_host(Some(&host))
321 .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
322
323 let domain = net_url
324 .host()
325 .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
326 let port = net_url
327 .port_or_known_default()
328 .ok_or_else(|| Error::Url(UrlError::UnableToConnect("Unknown port".into())))?;
329
330 let try_socket = TcpStream::connect((&*domain.to_string(), port)).await;
331 let socket = try_socket.map_err(Error::Io)?;
332 let connector = TlsConnector::from(Arc::new(tls_config));
333
334 let (stream, response) =
335 client_async_tls_with_connector_and_config(client, socket, Some(connector), None).await?;
336
337 let protocol_in_use = process_ws_response(&http_state, &response, &url, &protocols)?;
338
339 if !initiated_close.load(Ordering::SeqCst) {
340 if resource_event_sender
341 .send(WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use })
342 .is_err()
343 {
344 return Ok(());
345 }
346
347 trace!("about to start ws loop for {}", url);
348 run_ws_loop(dom_receiver, resource_event_sender, stream).await;
349 } else {
350 trace!("client closed connection for {}, not running loop", url);
351 }
352 Ok(())
353}
354
355fn connect(
357 mut req_builder: RequestBuilder,
358 resource_event_sender: IpcSender<WebSocketNetworkEvent>,
359 dom_action_receiver: IpcReceiver<WebSocketDomAction>,
360 http_state: Arc<HttpState>,
361 ca_certificates: CACertificates,
362 ignore_certificate_errors: bool,
363) -> Result<(), String> {
364 let protocols = match req_builder.mode {
365 RequestMode::WebSocket { ref mut protocols } => mem::take(protocols),
366 _ => {
367 return Err(
368 "Received a RequestBuilder with a non-websocket mode in websocket_loader"
369 .to_string(),
370 );
371 },
372 };
373
374 http_state
376 .hsts_list
377 .read()
378 .unwrap()
379 .apply_hsts_rules(&mut req_builder.url);
380 let request = req_builder.build();
381
382 let req_url = request.url();
383 let req_origin = match request.origin {
384 Origin::Client => unreachable!(),
385 Origin::Origin(ref origin) => origin,
386 };
387
388 if should_request_be_blocked_due_to_a_bad_port(&req_url) {
389 return Err("Port blocked".to_string());
390 }
391
392 let policy_container = match &request.policy_container {
393 RequestPolicyContainer::Client => PolicyContainer::default(),
394 RequestPolicyContainer::PolicyContainer(container) => container.to_owned(),
395 };
396
397 if let Some(csp_request) = convert_request_to_csp_request(&request) {
398 let (check_result, violations) =
399 should_request_be_blocked_by_csp(&csp_request, &policy_container);
400
401 if !violations.is_empty() {
402 let _ =
403 resource_event_sender.send(WebSocketNetworkEvent::ReportCSPViolations(violations));
404 }
405
406 if check_result == csp::CheckResult::Blocked {
407 return Err("Blocked by Content-Security-Policy".to_string());
408 }
409 }
410
411 let client = match create_request(
412 &req_url,
413 &req_origin.ascii_serialization(),
414 &protocols,
415 &http_state,
416 ) {
417 Ok(c) => c,
418 Err(e) => return Err(e.to_string()),
419 };
420
421 let mut tls_config = create_tls_config(
422 ca_certificates,
423 ignore_certificate_errors,
424 http_state.override_manager.clone(),
425 );
426 tls_config.alpn_protocols = vec!["http/1.1".to_string().into()];
427
428 let resource_event_sender2 = resource_event_sender.clone();
429 spawn_task(
430 start_websocket(
431 http_state,
432 req_url.clone(),
433 resource_event_sender,
434 protocols,
435 client,
436 tls_config,
437 dom_action_receiver,
438 )
439 .map_err(move |e| {
440 warn!("Failed to establish a WebSocket connection: {:?}", e);
441 let _ = resource_event_sender2.send(WebSocketNetworkEvent::Fail);
442 }),
443 );
444 Ok(())
445}
446
447pub fn init(
449 req_builder: RequestBuilder,
450 resource_event_sender: IpcSender<WebSocketNetworkEvent>,
451 dom_action_receiver: IpcReceiver<WebSocketDomAction>,
452 http_state: Arc<HttpState>,
453 ca_certificates: CACertificates,
454 ignore_certificate_errors: bool,
455) {
456 let resource_event_sender2 = resource_event_sender.clone();
457 if let Err(e) = connect(
458 req_builder,
459 resource_event_sender,
460 dom_action_receiver,
461 http_state,
462 ca_certificates,
463 ignore_certificate_errors,
464 ) {
465 warn!("Error starting websocket: {}", e);
466 let _ = resource_event_sender2.send(WebSocketNetworkEvent::Fail);
467 }
468}