net/
websocket_loader.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5//! The websocket handler has three main responsibilities:
6//! 1) initiate the initial HTTP connection and process the response
7//! 2) ensure any DOM requests for sending/closing are propagated to the network
8//! 3) transmit any incoming messages/closing to the DOM
9//!
10//! In order to accomplish this, the handler uses a long-running loop that selects
11//! over events from the network and events from the DOM, using async/await to avoid
12//! the need for a dedicated thread per websocket.
13
14use std::sync::Arc;
15use std::sync::atomic::{AtomicBool, Ordering};
16
17use async_tungstenite::WebSocketStream;
18use async_tungstenite::tokio::{ConnectStream, client_async_tls_with_connector_and_config};
19use base64::Engine;
20use futures::stream::StreamExt;
21use http::HeaderMap;
22use http::header::{self, HeaderName, HeaderValue};
23use ipc_channel::ipc::{IpcReceiver, IpcSender};
24use ipc_channel::router::ROUTER;
25use log::{debug, trace, warn};
26use net_traits::request::{RequestBuilder, RequestMode};
27use net_traits::{CookieSource, MessageData, WebSocketDomAction, WebSocketNetworkEvent};
28use servo_url::ServoUrl;
29use tokio::net::TcpStream;
30use tokio::select;
31use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
32use tokio_rustls::TlsConnector;
33use tungstenite::error::{Error, ProtocolError, UrlError};
34use tungstenite::handshake::client::Response;
35use tungstenite::protocol::CloseFrame;
36use tungstenite::{ClientRequestBuilder, Message};
37
38use crate::async_runtime::spawn_task;
39use crate::connector::TlsConfig;
40use crate::cookie::ServoCookie;
41use crate::hosts::replace_host;
42use crate::http_loader::HttpState;
43
44#[allow(clippy::result_large_err)]
45/// Create a Request object for the initial HTTP request.
46/// This request contains `Origin`, `Sec-WebSocket-Protocol`, `Authorization`,
47/// and `Cookie` headers as appropriate.
48/// Returns an error if any header values are invalid or tungstenite cannot create
49/// the desired request.
50pub fn create_handshake_request(
51    request: RequestBuilder,
52    http_state: Arc<HttpState>,
53) -> Result<net_traits::request::Request, Error> {
54    let origin = request.url.origin();
55
56    let mut headers = HeaderMap::new();
57    headers.insert(
58        "Origin",
59        HeaderValue::from_str(&request.url.origin().ascii_serialization())?,
60    );
61
62    let host = format!(
63        "{}",
64        origin
65            .host()
66            .ok_or_else(|| Error::Url(UrlError::NoHostName))?
67    );
68    headers.insert("Host", HeaderValue::from_str(&host)?);
69    // https://websockets.spec.whatwg.org/#concept-websocket-establish
70    // 3. Append (`Upgrade`, `websocket`) to request’s header list.
71    headers.insert("Upgrade", HeaderValue::from_static("websocket"));
72
73    // 4. Append (`Connection`, `Upgrade`) to request’s header list.
74    headers.insert("Connection", HeaderValue::from_static("upgrade"));
75
76    // 5. Let keyValue be a nonce consisting of a randomly selected 16-byte value that has been
77    // forgiving-base64-encoded and isomorphic encoded.
78    let key = HeaderValue::from_str(&tungstenite::handshake::client::generate_key()).unwrap();
79
80    // 6. Append (`Sec-WebSocket-Key`, keyValue) to request’s header list.
81    headers.insert("Sec-WebSocket-Key", key);
82
83    // 7. Append (`Sec-WebSocket-Version`, `13`) to request’s header list.
84    headers.insert("Sec-Websocket-Version", HeaderValue::from_static("13"));
85
86    // 8. For each protocol in protocols, combine (`Sec-WebSocket-Protocol`, protocol) in request’s
87    // header list.
88    let protocols = match request.mode {
89        RequestMode::WebSocket {
90            ref protocols,
91            original_url: _,
92        } => protocols,
93        _ => unreachable!("How did we get here?"),
94    };
95    if !protocols.is_empty() {
96        let protocols = protocols.join(",");
97        headers.insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocols)?);
98    }
99
100    let mut cookie_jar = http_state.cookie_jar.write().unwrap();
101    cookie_jar.remove_expired_cookies_for_url(&request.url);
102    if let Some(cookie_list) = cookie_jar.cookies_for_url(&request.url, CookieSource::HTTP) {
103        headers.insert("Cookie", HeaderValue::from_str(&cookie_list)?);
104    }
105
106    if request.url.password().is_some() || request.url.username() != "" {
107        let basic = base64::engine::general_purpose::STANDARD.encode(format!(
108            "{}:{}",
109            request.url.username(),
110            request.url.password().unwrap_or("")
111        ));
112        headers.insert(
113            "Authorization",
114            HeaderValue::from_str(&format!("Basic {}", basic))?,
115        );
116    }
117    Ok(request.headers(headers).build())
118}
119
120#[allow(clippy::result_large_err)]
121/// Process an HTTP response resulting from a WS handshake.
122/// This ensures that any `Cookie` or HSTS headers are recognized.
123/// Returns an error if the protocol selected by the handshake doesn't
124/// match the list of provided protocols in the original request.
125fn process_ws_response(
126    http_state: &HttpState,
127    response: &Response,
128    resource_url: &ServoUrl,
129    protocols: &[String],
130) -> Result<Option<String>, Error> {
131    trace!("processing websocket http response for {}", resource_url);
132    let mut protocol_in_use = None;
133    if let Some(protocol_name) = response.headers().get("Sec-WebSocket-Protocol") {
134        let protocol_name = protocol_name.to_str().unwrap_or("");
135        if !protocols.is_empty() && !protocols.iter().any(|p| protocol_name == (*p)) {
136            return Err(Error::Protocol(ProtocolError::InvalidHeader(Box::new(
137                HeaderName::from_static("sec-websocket-protocol"),
138            ))));
139        }
140        protocol_in_use = Some(protocol_name.to_string());
141    }
142
143    let mut jar = http_state.cookie_jar.write().unwrap();
144    // TODO(eijebong): Replace thise once typed headers settled on a cookie impl
145    for cookie in response.headers().get_all(header::SET_COOKIE) {
146        if let Ok(s) = std::str::from_utf8(cookie.as_bytes()) {
147            if let Some(cookie) =
148                ServoCookie::from_cookie_string(s.into(), resource_url, CookieSource::HTTP)
149            {
150                jar.push(cookie, resource_url, CookieSource::HTTP);
151            }
152        }
153    }
154
155    http_state
156        .hsts_list
157        .write()
158        .unwrap()
159        .update_hsts_list_from_response(resource_url, response.headers());
160
161    Ok(protocol_in_use)
162}
163
164#[derive(Debug)]
165enum DomMsg {
166    Send(Message),
167    Close(Option<(u16, String)>),
168}
169
170/// Initialize a listener for DOM actions. These are routed from the IPC channel
171/// to a tokio channel that the main WS client task uses to receive them.
172fn setup_dom_listener(
173    dom_action_receiver: IpcReceiver<WebSocketDomAction>,
174    initiated_close: Arc<AtomicBool>,
175) -> UnboundedReceiver<DomMsg> {
176    let (sender, receiver) = unbounded_channel();
177
178    ROUTER.add_typed_route(
179        dom_action_receiver,
180        Box::new(move |message| {
181            let dom_action = message.expect("Ws dom_action message to deserialize");
182            trace!("handling WS DOM action: {:?}", dom_action);
183            match dom_action {
184                WebSocketDomAction::SendMessage(MessageData::Text(data)) => {
185                    if let Err(e) = sender.send(DomMsg::Send(Message::Text(data.into()))) {
186                        warn!("Error sending websocket message: {:?}", e);
187                    }
188                },
189                WebSocketDomAction::SendMessage(MessageData::Binary(data)) => {
190                    if let Err(e) = sender.send(DomMsg::Send(Message::Binary(data.into()))) {
191                        warn!("Error sending websocket message: {:?}", e);
192                    }
193                },
194                WebSocketDomAction::Close(code, reason) => {
195                    if initiated_close.fetch_or(true, Ordering::SeqCst) {
196                        return;
197                    }
198                    let frame = code.map(move |c| (c, reason.unwrap_or_default()));
199                    if let Err(e) = sender.send(DomMsg::Close(frame)) {
200                        warn!("Error closing websocket: {:?}", e);
201                    }
202                },
203            }
204        }),
205    );
206
207    receiver
208}
209
210/// Listen for WS events from the DOM and the network until one side
211/// closes the connection or an error occurs. Since this is an async
212/// function that uses the select operation, it will run as a task
213/// on the WS tokio runtime.
214async fn run_ws_loop(
215    mut dom_receiver: UnboundedReceiver<DomMsg>,
216    resource_event_sender: IpcSender<WebSocketNetworkEvent>,
217    mut stream: WebSocketStream<ConnectStream>,
218) {
219    loop {
220        select! {
221            dom_msg = dom_receiver.recv() => {
222                trace!("processing dom msg: {:?}", dom_msg);
223                let dom_msg = match dom_msg {
224                    Some(msg) => msg,
225                    None => break,
226                };
227                match dom_msg {
228                    DomMsg::Send(m) => {
229                        if let Err(e) = stream.send(m).await {
230                            warn!("error sending websocket message: {:?}", e);
231                        }
232                    },
233                    DomMsg::Close(frame) => {
234                        if let Err(e) = stream.close(frame.map(|(code, reason)| {
235                            CloseFrame {
236                                code: code.into(),
237                                reason: reason.into(),
238                            }
239                        })).await {
240                            warn!("error closing websocket: {:?}", e);
241                        }
242                    },
243                }
244            }
245            ws_msg = stream.next() => {
246                trace!("processing WS stream: {:?}", ws_msg);
247                let msg = match ws_msg {
248                    Some(Ok(msg)) => msg,
249                    Some(Err(e)) => {
250                        warn!("Error in WebSocket communication: {:?}", e);
251                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
252                        break;
253                    },
254                    None => {
255                        warn!("Error in WebSocket communication");
256                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
257                        break;
258                    }
259                };
260                match msg {
261                    Message::Text(s) => {
262                        let message = MessageData::Text(s.as_str().to_owned());
263                        if let Err(e) = resource_event_sender
264                            .send(WebSocketNetworkEvent::MessageReceived(message))
265                        {
266                            warn!("Error sending websocket notification: {:?}", e);
267                            break;
268                        }
269                    }
270
271                    Message::Binary(v) => {
272                        let message = MessageData::Binary(v.to_vec());
273                        if let Err(e) = resource_event_sender
274                            .send(WebSocketNetworkEvent::MessageReceived(message))
275                        {
276                            warn!("Error sending websocket notification: {:?}", e);
277                            break;
278                        }
279                    }
280
281                    Message::Ping(_) | Message::Pong(_) => {}
282
283                    Message::Close(frame) => {
284                        let (reason, code) = match frame {
285                            Some(frame) => (frame.reason, Some(frame.code.into())),
286                            None => ("".into(), None),
287                        };
288                        debug!("Websocket connection closing due to ({:?}) {}", code, reason);
289                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Close(
290                            code,
291                            reason.to_string(),
292                        ));
293                        break;
294                    }
295
296                    Message::Frame(_) => {
297                        warn!("Unexpected websocket frame message");
298                    }
299                }
300            }
301        }
302    }
303}
304
305/// Initiate a new async WS connection. Returns an error if the connection fails
306/// for any reason, or if the response isn't valid. Otherwise, the endless WS
307/// listening loop will be started.
308pub(crate) async fn start_websocket(
309    http_state: Arc<HttpState>,
310    resource_event_sender: IpcSender<WebSocketNetworkEvent>,
311    protocols: &[String],
312    client: &net_traits::request::Request,
313    tls_config: TlsConfig,
314    dom_action_receiver: IpcReceiver<WebSocketDomAction>,
315) -> Result<Response, Error> {
316    trace!("starting WS connection to {}", client.url());
317
318    let initiated_close = Arc::new(AtomicBool::new(false));
319    let dom_receiver = setup_dom_listener(dom_action_receiver, initiated_close.clone());
320
321    let url = client.url();
322    let host = replace_host(url.host_str().expect("URL has no host"));
323    let mut net_url = client.url().into_url();
324    net_url
325        .set_host(Some(&host))
326        .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
327
328    let domain = net_url
329        .host()
330        .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
331    let port = net_url
332        .port_or_known_default()
333        .ok_or_else(|| Error::Url(UrlError::UnableToConnect("Unknown port".into())))?;
334
335    let try_socket = TcpStream::connect((&*domain.to_string(), port)).await;
336    let socket = try_socket.map_err(Error::Io)?;
337    let connector = TlsConnector::from(Arc::new(tls_config));
338
339    // TODO(pylbrecht): move request conversion to a separate function
340    let mut original_url = client.original_url();
341    if original_url.scheme() == "ws" && url.scheme() == "https" {
342        original_url.as_mut_url().set_scheme("wss").unwrap();
343    }
344    let mut builder = ClientRequestBuilder::new(
345        original_url
346            .into_string()
347            .parse()
348            .expect("unable to parse URI"),
349    );
350    for (key, value) in client.headers.iter() {
351        builder = builder.with_header(
352            key.as_str(),
353            value
354                .to_str()
355                .expect("unable to convert header value to string"),
356        );
357    }
358
359    let (stream, response) =
360        client_async_tls_with_connector_and_config(builder, socket, Some(connector), None).await?;
361
362    let protocol_in_use = process_ws_response(&http_state, &response, &url, protocols)?;
363
364    if !initiated_close.load(Ordering::SeqCst) {
365        if resource_event_sender
366            .send(WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use })
367            .is_err()
368        {
369            return Ok(response);
370        }
371
372        trace!("about to start ws loop for {}", url);
373        spawn_task(run_ws_loop(dom_receiver, resource_event_sender, stream));
374    } else {
375        trace!("client closed connection for {}, not running loop", url);
376    }
377    Ok(response)
378}