net/
websocket_loader.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5//! The websocket handler has three main responsibilities:
6//! 1) initiate the initial HTTP connection and process the response
7//! 2) ensure any DOM requests for sending/closing are propagated to the network
8//! 3) transmit any incoming messages/closing to the DOM
9//!
10//! In order to accomplish this, the handler uses a long-running loop that selects
11//! over events from the network and events from the DOM, using async/await to avoid
12//! the need for a dedicated thread per websocket.
13
14use std::mem;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, Ordering};
17
18use async_tungstenite::WebSocketStream;
19use async_tungstenite::tokio::{ConnectStream, client_async_tls_with_connector_and_config};
20use base64::Engine;
21use content_security_policy as csp;
22use futures::future::TryFutureExt;
23use futures::stream::StreamExt;
24use http::header::{self, HeaderName, HeaderValue};
25use ipc_channel::ipc::{IpcReceiver, IpcSender};
26use ipc_channel::router::ROUTER;
27use log::{debug, trace, warn};
28use net_traits::policy_container::{PolicyContainer, RequestPolicyContainer};
29use net_traits::request::{Origin, RequestBuilder, RequestMode};
30use net_traits::{CookieSource, MessageData, WebSocketDomAction, WebSocketNetworkEvent};
31use servo_url::ServoUrl;
32use tokio::net::TcpStream;
33use tokio::select;
34use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
35use tokio_rustls::TlsConnector;
36use tungstenite::Message;
37use tungstenite::error::{Error, ProtocolError, Result as WebSocketResult, UrlError};
38use tungstenite::handshake::client::{Request, Response};
39use tungstenite::protocol::CloseFrame;
40use url::Url;
41
42use crate::async_runtime::spawn_task;
43use crate::connector::{CACertificates, TlsConfig, create_tls_config};
44use crate::cookie::ServoCookie;
45use crate::fetch::methods::{
46    convert_request_to_csp_request, should_request_be_blocked_by_csp,
47    should_request_be_blocked_due_to_a_bad_port,
48};
49use crate::hosts::replace_host;
50use crate::http_loader::HttpState;
51
52#[allow(clippy::result_large_err)]
53/// Create a tungstenite Request object for the initial HTTP request.
54/// This request contains `Origin`, `Sec-WebSocket-Protocol`, `Authorization`,
55/// and `Cookie` headers as appropriate.
56/// Returns an error if any header values are invalid or tungstenite cannot create
57/// the desired request.
58fn create_request(
59    resource_url: &ServoUrl,
60    origin: &str,
61    protocols: &[String],
62    http_state: &HttpState,
63) -> WebSocketResult<Request> {
64    let mut builder = Request::get(resource_url.as_str());
65    let headers = builder.headers_mut().unwrap();
66    headers.insert("Origin", HeaderValue::from_str(origin)?);
67
68    let origin = resource_url.origin();
69    let host = format!(
70        "{}",
71        origin
72            .host()
73            .ok_or_else(|| Error::Url(UrlError::NoHostName))?
74    );
75    headers.insert("Host", HeaderValue::from_str(&host)?);
76    headers.insert("Connection", HeaderValue::from_static("upgrade"));
77    headers.insert("Upgrade", HeaderValue::from_static("websocket"));
78    headers.insert("Sec-Websocket-Version", HeaderValue::from_static("13"));
79
80    let key = HeaderValue::from_str(&tungstenite::handshake::client::generate_key()).unwrap();
81    headers.insert("Sec-WebSocket-Key", key);
82
83    if !protocols.is_empty() {
84        let protocols = protocols.join(",");
85        headers.insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocols)?);
86    }
87
88    let mut cookie_jar = http_state.cookie_jar.write().unwrap();
89    cookie_jar.remove_expired_cookies_for_url(resource_url);
90    if let Some(cookie_list) = cookie_jar.cookies_for_url(resource_url, CookieSource::HTTP) {
91        headers.insert("Cookie", HeaderValue::from_str(&cookie_list)?);
92    }
93
94    if resource_url.password().is_some() || resource_url.username() != "" {
95        let basic = base64::engine::general_purpose::STANDARD.encode(format!(
96            "{}:{}",
97            resource_url.username(),
98            resource_url.password().unwrap_or("")
99        ));
100        headers.insert(
101            "Authorization",
102            HeaderValue::from_str(&format!("Basic {}", basic))?,
103        );
104    }
105
106    let request = builder.body(())?;
107    Ok(request)
108}
109
110#[allow(clippy::result_large_err)]
111/// Process an HTTP response resulting from a WS handshake.
112/// This ensures that any `Cookie` or HSTS headers are recognized.
113/// Returns an error if the protocol selected by the handshake doesn't
114/// match the list of provided protocols in the original request.
115fn process_ws_response(
116    http_state: &HttpState,
117    response: &Response,
118    resource_url: &ServoUrl,
119    protocols: &[String],
120) -> Result<Option<String>, Error> {
121    trace!("processing websocket http response for {}", resource_url);
122    let mut protocol_in_use = None;
123    if let Some(protocol_name) = response.headers().get("Sec-WebSocket-Protocol") {
124        let protocol_name = protocol_name.to_str().unwrap_or("");
125        if !protocols.is_empty() && !protocols.iter().any(|p| protocol_name == (*p)) {
126            return Err(Error::Protocol(ProtocolError::InvalidHeader(
127                HeaderName::from_static("sec-websocket-protocol"),
128            )));
129        }
130        protocol_in_use = Some(protocol_name.to_string());
131    }
132
133    let mut jar = http_state.cookie_jar.write().unwrap();
134    // TODO(eijebong): Replace thise once typed headers settled on a cookie impl
135    for cookie in response.headers().get_all(header::SET_COOKIE) {
136        if let Ok(s) = std::str::from_utf8(cookie.as_bytes()) {
137            if let Some(cookie) =
138                ServoCookie::from_cookie_string(s.into(), resource_url, CookieSource::HTTP)
139            {
140                jar.push(cookie, resource_url, CookieSource::HTTP);
141            }
142        }
143    }
144
145    http_state
146        .hsts_list
147        .write()
148        .unwrap()
149        .update_hsts_list_from_response(resource_url, response.headers());
150
151    Ok(protocol_in_use)
152}
153
154#[derive(Debug)]
155enum DomMsg {
156    Send(Message),
157    Close(Option<(u16, String)>),
158}
159
160/// Initialize a listener for DOM actions. These are routed from the IPC channel
161/// to a tokio channel that the main WS client task uses to receive them.
162fn setup_dom_listener(
163    dom_action_receiver: IpcReceiver<WebSocketDomAction>,
164    initiated_close: Arc<AtomicBool>,
165) -> UnboundedReceiver<DomMsg> {
166    let (sender, receiver) = unbounded_channel();
167
168    ROUTER.add_typed_route(
169        dom_action_receiver,
170        Box::new(move |message| {
171            let dom_action = message.expect("Ws dom_action message to deserialize");
172            trace!("handling WS DOM action: {:?}", dom_action);
173            match dom_action {
174                WebSocketDomAction::SendMessage(MessageData::Text(data)) => {
175                    if let Err(e) = sender.send(DomMsg::Send(Message::Text(data.into()))) {
176                        warn!("Error sending websocket message: {:?}", e);
177                    }
178                },
179                WebSocketDomAction::SendMessage(MessageData::Binary(data)) => {
180                    if let Err(e) = sender.send(DomMsg::Send(Message::Binary(data.into()))) {
181                        warn!("Error sending websocket message: {:?}", e);
182                    }
183                },
184                WebSocketDomAction::Close(code, reason) => {
185                    if initiated_close.fetch_or(true, Ordering::SeqCst) {
186                        return;
187                    }
188                    let frame = code.map(move |c| (c, reason.unwrap_or_default()));
189                    if let Err(e) = sender.send(DomMsg::Close(frame)) {
190                        warn!("Error closing websocket: {:?}", e);
191                    }
192                },
193            }
194        }),
195    );
196
197    receiver
198}
199
200/// Listen for WS events from the DOM and the network until one side
201/// closes the connection or an error occurs. Since this is an async
202/// function that uses the select operation, it will run as a task
203/// on the WS tokio runtime.
204async fn run_ws_loop(
205    mut dom_receiver: UnboundedReceiver<DomMsg>,
206    resource_event_sender: IpcSender<WebSocketNetworkEvent>,
207    mut stream: WebSocketStream<ConnectStream>,
208) {
209    loop {
210        select! {
211            dom_msg = dom_receiver.recv() => {
212                trace!("processing dom msg: {:?}", dom_msg);
213                let dom_msg = match dom_msg {
214                    Some(msg) => msg,
215                    None => break,
216                };
217                match dom_msg {
218                    DomMsg::Send(m) => {
219                        if let Err(e) = stream.send(m).await {
220                            warn!("error sending websocket message: {:?}", e);
221                        }
222                    },
223                    DomMsg::Close(frame) => {
224                        if let Err(e) = stream.close(frame.map(|(code, reason)| {
225                            CloseFrame {
226                                code: code.into(),
227                                reason: reason.into(),
228                            }
229                        })).await {
230                            warn!("error closing websocket: {:?}", e);
231                        }
232                    },
233                }
234            }
235            ws_msg = stream.next() => {
236                trace!("processing WS stream: {:?}", ws_msg);
237                let msg = match ws_msg {
238                    Some(Ok(msg)) => msg,
239                    Some(Err(e)) => {
240                        warn!("Error in WebSocket communication: {:?}", e);
241                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
242                        break;
243                    },
244                    None => {
245                        warn!("Error in WebSocket communication");
246                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
247                        break;
248                    }
249                };
250                match msg {
251                    Message::Text(s) => {
252                        let message = MessageData::Text(s.as_str().to_owned());
253                        if let Err(e) = resource_event_sender
254                            .send(WebSocketNetworkEvent::MessageReceived(message))
255                        {
256                            warn!("Error sending websocket notification: {:?}", e);
257                            break;
258                        }
259                    }
260
261                    Message::Binary(v) => {
262                        let message = MessageData::Binary(v.to_vec());
263                        if let Err(e) = resource_event_sender
264                            .send(WebSocketNetworkEvent::MessageReceived(message))
265                        {
266                            warn!("Error sending websocket notification: {:?}", e);
267                            break;
268                        }
269                    }
270
271                    Message::Ping(_) | Message::Pong(_) => {}
272
273                    Message::Close(frame) => {
274                        let (reason, code) = match frame {
275                            Some(frame) => (frame.reason, Some(frame.code.into())),
276                            None => ("".into(), None),
277                        };
278                        debug!("Websocket connection closing due to ({:?}) {}", code, reason);
279                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Close(
280                            code,
281                            reason.to_string(),
282                        ));
283                        break;
284                    }
285
286                    Message::Frame(_) => {
287                        warn!("Unexpected websocket frame message");
288                    }
289                }
290            }
291        }
292    }
293}
294
295/// Initiate a new async WS connection. Returns an error if the connection fails
296/// for any reason, or if the response isn't valid. Otherwise, the endless WS
297/// listening loop will be started.
298async fn start_websocket(
299    http_state: Arc<HttpState>,
300    url: ServoUrl,
301    resource_event_sender: IpcSender<WebSocketNetworkEvent>,
302    protocols: Vec<String>,
303    client: Request,
304    tls_config: TlsConfig,
305    dom_action_receiver: IpcReceiver<WebSocketDomAction>,
306) -> Result<(), Error> {
307    trace!("starting WS connection to {}", url);
308
309    let initiated_close = Arc::new(AtomicBool::new(false));
310    let dom_receiver = setup_dom_listener(dom_action_receiver, initiated_close.clone());
311
312    let host_str = client
313        .uri()
314        .host()
315        .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
316    let host = replace_host(host_str);
317    let mut net_url = Url::parse(&client.uri().to_string())
318        .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
319    net_url
320        .set_host(Some(&host))
321        .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
322
323    let domain = net_url
324        .host()
325        .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
326    let port = net_url
327        .port_or_known_default()
328        .ok_or_else(|| Error::Url(UrlError::UnableToConnect("Unknown port".into())))?;
329
330    let try_socket = TcpStream::connect((&*domain.to_string(), port)).await;
331    let socket = try_socket.map_err(Error::Io)?;
332    let connector = TlsConnector::from(Arc::new(tls_config));
333
334    let (stream, response) =
335        client_async_tls_with_connector_and_config(client, socket, Some(connector), None).await?;
336
337    let protocol_in_use = process_ws_response(&http_state, &response, &url, &protocols)?;
338
339    if !initiated_close.load(Ordering::SeqCst) {
340        if resource_event_sender
341            .send(WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use })
342            .is_err()
343        {
344            return Ok(());
345        }
346
347        trace!("about to start ws loop for {}", url);
348        run_ws_loop(dom_receiver, resource_event_sender, stream).await;
349    } else {
350        trace!("client closed connection for {}, not running loop", url);
351    }
352    Ok(())
353}
354
355/// Create a new websocket connection for the given request.
356fn connect(
357    mut req_builder: RequestBuilder,
358    resource_event_sender: IpcSender<WebSocketNetworkEvent>,
359    dom_action_receiver: IpcReceiver<WebSocketDomAction>,
360    http_state: Arc<HttpState>,
361    ca_certificates: CACertificates,
362    ignore_certificate_errors: bool,
363) -> Result<(), String> {
364    let protocols = match req_builder.mode {
365        RequestMode::WebSocket { ref mut protocols } => mem::take(protocols),
366        _ => {
367            return Err(
368                "Received a RequestBuilder with a non-websocket mode in websocket_loader"
369                    .to_string(),
370            );
371        },
372    };
373
374    // https://fetch.spec.whatwg.org/#websocket-opening-handshake
375    http_state
376        .hsts_list
377        .read()
378        .unwrap()
379        .apply_hsts_rules(&mut req_builder.url);
380    let request = req_builder.build();
381
382    let req_url = request.url();
383    let req_origin = match request.origin {
384        Origin::Client => unreachable!(),
385        Origin::Origin(ref origin) => origin,
386    };
387
388    if should_request_be_blocked_due_to_a_bad_port(&req_url) {
389        return Err("Port blocked".to_string());
390    }
391
392    let policy_container = match &request.policy_container {
393        RequestPolicyContainer::Client => PolicyContainer::default(),
394        RequestPolicyContainer::PolicyContainer(container) => container.to_owned(),
395    };
396
397    if let Some(csp_request) = convert_request_to_csp_request(&request) {
398        let (check_result, violations) =
399            should_request_be_blocked_by_csp(&csp_request, &policy_container);
400
401        if !violations.is_empty() {
402            let _ =
403                resource_event_sender.send(WebSocketNetworkEvent::ReportCSPViolations(violations));
404        }
405
406        if check_result == csp::CheckResult::Blocked {
407            return Err("Blocked by Content-Security-Policy".to_string());
408        }
409    }
410
411    let client = match create_request(
412        &req_url,
413        &req_origin.ascii_serialization(),
414        &protocols,
415        &http_state,
416    ) {
417        Ok(c) => c,
418        Err(e) => return Err(e.to_string()),
419    };
420
421    let mut tls_config = create_tls_config(
422        ca_certificates,
423        ignore_certificate_errors,
424        http_state.override_manager.clone(),
425    );
426    tls_config.alpn_protocols = vec!["http/1.1".to_string().into()];
427
428    let resource_event_sender2 = resource_event_sender.clone();
429    spawn_task(
430        start_websocket(
431            http_state,
432            req_url.clone(),
433            resource_event_sender,
434            protocols,
435            client,
436            tls_config,
437            dom_action_receiver,
438        )
439        .map_err(move |e| {
440            warn!("Failed to establish a WebSocket connection: {:?}", e);
441            let _ = resource_event_sender2.send(WebSocketNetworkEvent::Fail);
442        }),
443    );
444    Ok(())
445}
446
447/// Create a new websocket connection for the given request.
448pub fn init(
449    req_builder: RequestBuilder,
450    resource_event_sender: IpcSender<WebSocketNetworkEvent>,
451    dom_action_receiver: IpcReceiver<WebSocketDomAction>,
452    http_state: Arc<HttpState>,
453    ca_certificates: CACertificates,
454    ignore_certificate_errors: bool,
455) {
456    let resource_event_sender2 = resource_event_sender.clone();
457    if let Err(e) = connect(
458        req_builder,
459        resource_event_sender,
460        dom_action_receiver,
461        http_state,
462        ca_certificates,
463        ignore_certificate_errors,
464    ) {
465        warn!("Error starting websocket: {}", e);
466        let _ = resource_event_sender2.send(WebSocketNetworkEvent::Fail);
467    }
468}