net/
websocket_loader.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5//! The websocket handler has three main responsibilities:
6//! 1) initiate the initial HTTP connection and process the response
7//! 2) ensure any DOM requests for sending/closing are propagated to the network
8//! 3) transmit any incoming messages/closing to the DOM
9//!
10//! In order to accomplish this, the handler uses a long-running loop that selects
11//! over events from the network and events from the DOM, using async/await to avoid
12//! the need for a dedicated thread per websocket.
13
14use std::sync::Arc;
15use std::sync::atomic::{AtomicBool, Ordering};
16
17use async_tungstenite::WebSocketStream;
18use async_tungstenite::tokio::{ConnectStream, client_async_tls_with_connector_and_config};
19use base::generic_channel::CallbackSetter;
20use base64::Engine;
21use futures::stream::StreamExt;
22use http::HeaderMap;
23use http::header::{self, HeaderName, HeaderValue};
24use ipc_channel::ipc::IpcSender;
25use log::{debug, trace, warn};
26use net_traits::request::{RequestBuilder, RequestMode};
27use net_traits::{CookieSource, MessageData, WebSocketDomAction, WebSocketNetworkEvent};
28use servo_url::ServoUrl;
29use tokio::net::TcpStream;
30use tokio::select;
31use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
32use tokio_rustls::TlsConnector;
33use tungstenite::error::{Error, ProtocolError, UrlError};
34use tungstenite::handshake::client::Response;
35use tungstenite::protocol::CloseFrame;
36use tungstenite::{ClientRequestBuilder, Message};
37
38use crate::async_runtime::spawn_task;
39use crate::connector::TlsConfig;
40use crate::cookie::ServoCookie;
41use crate::hosts::replace_host;
42use crate::http_loader::HttpState;
43
44/// Create a Request object for the initial HTTP request.
45/// This request contains `Origin`, `Sec-WebSocket-Protocol`, `Authorization`,
46/// and `Cookie` headers as appropriate.
47/// Returns an error if any header values are invalid or tungstenite cannot create
48/// the desired request.
49pub fn create_handshake_request(
50    request: RequestBuilder,
51    http_state: Arc<HttpState>,
52) -> Result<net_traits::request::Request, Error> {
53    let origin = request.url.origin();
54
55    let mut headers = HeaderMap::new();
56    headers.insert(
57        "Origin",
58        HeaderValue::from_str(&request.url.origin().ascii_serialization())?,
59    );
60
61    let host = format!(
62        "{}",
63        origin
64            .host()
65            .ok_or_else(|| Error::Url(UrlError::NoHostName))?
66    );
67    headers.insert("Host", HeaderValue::from_str(&host)?);
68    // https://websockets.spec.whatwg.org/#concept-websocket-establish
69    // 3. Append (`Upgrade`, `websocket`) to request’s header list.
70    headers.insert("Upgrade", HeaderValue::from_static("websocket"));
71
72    // 4. Append (`Connection`, `Upgrade`) to request’s header list.
73    headers.insert("Connection", HeaderValue::from_static("upgrade"));
74
75    // 5. Let keyValue be a nonce consisting of a randomly selected 16-byte value that has been
76    // forgiving-base64-encoded and isomorphic encoded.
77    let key = HeaderValue::from_str(&tungstenite::handshake::client::generate_key()).unwrap();
78
79    // 6. Append (`Sec-WebSocket-Key`, keyValue) to request’s header list.
80    headers.insert("Sec-WebSocket-Key", key);
81
82    // 7. Append (`Sec-WebSocket-Version`, `13`) to request’s header list.
83    headers.insert("Sec-Websocket-Version", HeaderValue::from_static("13"));
84
85    // 8. For each protocol in protocols, combine (`Sec-WebSocket-Protocol`, protocol) in request’s
86    // header list.
87    let protocols = match request.mode {
88        RequestMode::WebSocket {
89            ref protocols,
90            original_url: _,
91        } => protocols,
92        _ => unreachable!("How did we get here?"),
93    };
94    if !protocols.is_empty() {
95        let protocols = protocols.join(",");
96        headers.insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocols)?);
97    }
98
99    let mut cookie_jar = http_state.cookie_jar.write();
100    cookie_jar.remove_expired_cookies_for_url(&request.url);
101    if let Some(cookie_list) = cookie_jar.cookies_for_url(&request.url, CookieSource::HTTP) {
102        headers.insert("Cookie", HeaderValue::from_str(&cookie_list)?);
103    }
104
105    if request.url.password().is_some() || request.url.username() != "" {
106        let basic = base64::engine::general_purpose::STANDARD.encode(format!(
107            "{}:{}",
108            request.url.username(),
109            request.url.password().unwrap_or("")
110        ));
111        headers.insert(
112            "Authorization",
113            HeaderValue::from_str(&format!("Basic {}", basic))?,
114        );
115    }
116    Ok(request.headers(headers).build())
117}
118
119/// Process an HTTP response resulting from a WS handshake.
120/// This ensures that any `Cookie` or HSTS headers are recognized.
121/// Returns an error if the protocol selected by the handshake doesn't
122/// match the list of provided protocols in the original request.
123fn process_ws_response(
124    http_state: &HttpState,
125    response: &Response,
126    resource_url: &ServoUrl,
127    protocols: &[String],
128) -> Result<Option<String>, Error> {
129    trace!("processing websocket http response for {}", resource_url);
130    let mut protocol_in_use = None;
131    if let Some(protocol_name) = response.headers().get("Sec-WebSocket-Protocol") {
132        let protocol_name = protocol_name.to_str().unwrap_or("");
133        if !protocols.is_empty() && !protocols.iter().any(|p| protocol_name == (*p)) {
134            return Err(Error::Protocol(ProtocolError::InvalidHeader(Box::new(
135                HeaderName::from_static("sec-websocket-protocol"),
136            ))));
137        }
138        protocol_in_use = Some(protocol_name.to_string());
139    }
140
141    let mut jar = http_state.cookie_jar.write();
142    // TODO(eijebong): Replace thise once typed headers settled on a cookie impl
143    for cookie in response.headers().get_all(header::SET_COOKIE) {
144        let cookie_bytes = cookie.as_bytes();
145        if !ServoCookie::is_valid_name_or_value(cookie_bytes) {
146            continue;
147        }
148        if let Ok(s) = std::str::from_utf8(cookie_bytes) {
149            if let Some(cookie) =
150                ServoCookie::from_cookie_string(s, resource_url, CookieSource::HTTP)
151            {
152                jar.push(cookie, resource_url, CookieSource::HTTP);
153            }
154        }
155    }
156
157    http_state
158        .hsts_list
159        .write()
160        .update_hsts_list_from_response(resource_url, response.headers());
161
162    Ok(protocol_in_use)
163}
164
165#[derive(Debug)]
166enum DomMsg {
167    Send(Message),
168    Close(Option<(u16, String)>),
169}
170
171/// Initialize a listener for DOM actions. These are routed from the IPC channel
172/// to a tokio channel that the main WS client task uses to receive them.
173fn setup_dom_listener(
174    dom_action_receiver: CallbackSetter<WebSocketDomAction>,
175    initiated_close: Arc<AtomicBool>,
176) -> UnboundedReceiver<DomMsg> {
177    let (sender, receiver) = unbounded_channel();
178
179    dom_action_receiver.set_callback(move |message| {
180        let dom_action = message.expect("Ws dom_action message to deserialize");
181        trace!("handling WS DOM action: {:?}", dom_action);
182        match dom_action {
183            WebSocketDomAction::SendMessage(MessageData::Text(data)) => {
184                if let Err(e) = sender.send(DomMsg::Send(Message::Text(data.into()))) {
185                    warn!("Error sending websocket message: {:?}", e);
186                }
187            },
188            WebSocketDomAction::SendMessage(MessageData::Binary(data)) => {
189                if let Err(e) = sender.send(DomMsg::Send(Message::Binary(data.into()))) {
190                    warn!("Error sending websocket message: {:?}", e);
191                }
192            },
193            WebSocketDomAction::Close(code, reason) => {
194                if initiated_close.fetch_or(true, Ordering::SeqCst) {
195                    return;
196                }
197                let frame = code.map(move |c| (c, reason.unwrap_or_default()));
198                if let Err(e) = sender.send(DomMsg::Close(frame)) {
199                    warn!("Error closing websocket: {:?}", e);
200                }
201            },
202        }
203    });
204
205    receiver
206}
207
208/// Listen for WS events from the DOM and the network until one side
209/// closes the connection or an error occurs. Since this is an async
210/// function that uses the select operation, it will run as a task
211/// on the WS tokio runtime.
212async fn run_ws_loop(
213    mut dom_receiver: UnboundedReceiver<DomMsg>,
214    resource_event_sender: IpcSender<WebSocketNetworkEvent>,
215    mut stream: WebSocketStream<ConnectStream>,
216) {
217    loop {
218        select! {
219            dom_msg = dom_receiver.recv() => {
220                trace!("processing dom msg: {:?}", dom_msg);
221                let dom_msg = match dom_msg {
222                    Some(msg) => msg,
223                    None => break,
224                };
225                match dom_msg {
226                    DomMsg::Send(m) => {
227                        if let Err(e) = stream.send(m).await {
228                            warn!("error sending websocket message: {:?}", e);
229                        }
230                    },
231                    DomMsg::Close(frame) => {
232                        if let Err(e) = stream.close(frame.map(|(code, reason)| {
233                            CloseFrame {
234                                code: code.into(),
235                                reason: reason.into(),
236                            }
237                        })).await {
238                            warn!("error closing websocket: {:?}", e);
239                        }
240                    },
241                }
242            }
243            ws_msg = stream.next() => {
244                trace!("processing WS stream: {:?}", ws_msg);
245                let msg = match ws_msg {
246                    Some(Ok(msg)) => msg,
247                    Some(Err(e)) => {
248                        warn!("Error in WebSocket communication: {:?}", e);
249                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
250                        break;
251                    },
252                    None => {
253                        warn!("Error in WebSocket communication");
254                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
255                        break;
256                    }
257                };
258                match msg {
259                    Message::Text(s) => {
260                        let message = MessageData::Text(s.as_str().to_owned());
261                        if let Err(e) = resource_event_sender
262                            .send(WebSocketNetworkEvent::MessageReceived(message))
263                        {
264                            warn!("Error sending websocket notification: {:?}", e);
265                            break;
266                        }
267                    }
268
269                    Message::Binary(v) => {
270                        let message = MessageData::Binary(v.to_vec());
271                        if let Err(e) = resource_event_sender
272                            .send(WebSocketNetworkEvent::MessageReceived(message))
273                        {
274                            warn!("Error sending websocket notification: {:?}", e);
275                            break;
276                        }
277                    }
278
279                    Message::Ping(_) | Message::Pong(_) => {}
280
281                    Message::Close(frame) => {
282                        let (reason, code) = match frame {
283                            Some(frame) => (frame.reason, Some(frame.code.into())),
284                            None => ("".into(), None),
285                        };
286                        debug!("Websocket connection closing due to ({:?}) {}", code, reason);
287                        let _ = resource_event_sender.send(WebSocketNetworkEvent::Close(
288                            code,
289                            reason.to_string(),
290                        ));
291                        break;
292                    }
293
294                    Message::Frame(_) => {
295                        warn!("Unexpected websocket frame message");
296                    }
297                }
298            }
299        }
300    }
301}
302
303/// Initiate a new async WS connection. Returns an error if the connection fails
304/// for any reason, or if the response isn't valid. Otherwise, the endless WS
305/// listening loop will be started.
306pub(crate) async fn start_websocket(
307    http_state: Arc<HttpState>,
308    resource_event_sender: IpcSender<WebSocketNetworkEvent>,
309    protocols: &[String],
310    client: &net_traits::request::Request,
311    tls_config: TlsConfig,
312    dom_action_receiver: CallbackSetter<WebSocketDomAction>,
313) -> Result<Response, Error> {
314    trace!("starting WS connection to {}", client.url());
315
316    let initiated_close = Arc::new(AtomicBool::new(false));
317    let dom_receiver = setup_dom_listener(dom_action_receiver, initiated_close.clone());
318
319    let url = client.url();
320    let host = replace_host(url.host_str().expect("URL has no host"));
321    let mut net_url = client.url().into_url();
322    net_url
323        .set_host(Some(&host))
324        .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
325
326    let domain = net_url
327        .host()
328        .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
329    let port = net_url
330        .port_or_known_default()
331        .ok_or_else(|| Error::Url(UrlError::UnableToConnect("Unknown port".into())))?;
332
333    let try_socket = TcpStream::connect((&*domain.to_string(), port)).await;
334    let socket = try_socket.map_err(Error::Io)?;
335    let connector = TlsConnector::from(Arc::new(tls_config));
336
337    // TODO(pylbrecht): move request conversion to a separate function
338    let mut original_url = client.original_url();
339    if original_url.scheme() == "ws" && url.scheme() == "https" {
340        original_url.as_mut_url().set_scheme("wss").unwrap();
341    }
342    let mut builder =
343        ClientRequestBuilder::new(original_url.as_str().parse().expect("unable to parse URI"));
344    for (key, value) in client.headers.iter() {
345        builder = builder.with_header(
346            key.as_str(),
347            value
348                .to_str()
349                .expect("unable to convert header value to string"),
350        );
351    }
352
353    let (stream, response) =
354        client_async_tls_with_connector_and_config(builder, socket, Some(connector), None).await?;
355
356    let protocol_in_use = process_ws_response(&http_state, &response, &url, protocols)?;
357
358    if !initiated_close.load(Ordering::SeqCst) {
359        if resource_event_sender
360            .send(WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use })
361            .is_err()
362        {
363            return Ok(response);
364        }
365
366        trace!("about to start ws loop for {}", url);
367        spawn_task(run_ws_loop(dom_receiver, resource_event_sender, stream));
368    } else {
369        trace!("client closed connection for {}, not running loop", url);
370    }
371    Ok(response)
372}