1use std::sync::Arc;
15use std::sync::atomic::{AtomicBool, Ordering};
16
17use async_tungstenite::WebSocketStream;
18use async_tungstenite::tokio::{ConnectStream, client_async_tls_with_connector_and_config};
19use base64::Engine;
20use futures::stream::StreamExt;
21use http::HeaderMap;
22use http::header::{self, HeaderName, HeaderValue};
23use ipc_channel::ipc::{IpcReceiver, IpcSender};
24use ipc_channel::router::ROUTER;
25use log::{debug, trace, warn};
26use net_traits::request::{RequestBuilder, RequestMode};
27use net_traits::{CookieSource, MessageData, WebSocketDomAction, WebSocketNetworkEvent};
28use servo_url::ServoUrl;
29use tokio::net::TcpStream;
30use tokio::select;
31use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
32use tokio_rustls::TlsConnector;
33use tungstenite::error::{Error, ProtocolError, UrlError};
34use tungstenite::handshake::client::Response;
35use tungstenite::protocol::CloseFrame;
36use tungstenite::{ClientRequestBuilder, Message};
37
38use crate::async_runtime::spawn_task;
39use crate::connector::TlsConfig;
40use crate::cookie::ServoCookie;
41use crate::hosts::replace_host;
42use crate::http_loader::HttpState;
43
44#[allow(clippy::result_large_err)]
45pub fn create_handshake_request(
51 request: RequestBuilder,
52 http_state: Arc<HttpState>,
53) -> Result<net_traits::request::Request, Error> {
54 let origin = request.url.origin();
55
56 let mut headers = HeaderMap::new();
57 headers.insert(
58 "Origin",
59 HeaderValue::from_str(&request.url.origin().ascii_serialization())?,
60 );
61
62 let host = format!(
63 "{}",
64 origin
65 .host()
66 .ok_or_else(|| Error::Url(UrlError::NoHostName))?
67 );
68 headers.insert("Host", HeaderValue::from_str(&host)?);
69 headers.insert("Upgrade", HeaderValue::from_static("websocket"));
72
73 headers.insert("Connection", HeaderValue::from_static("upgrade"));
75
76 let key = HeaderValue::from_str(&tungstenite::handshake::client::generate_key()).unwrap();
79
80 headers.insert("Sec-WebSocket-Key", key);
82
83 headers.insert("Sec-Websocket-Version", HeaderValue::from_static("13"));
85
86 let protocols = match request.mode {
89 RequestMode::WebSocket {
90 ref protocols,
91 original_url: _,
92 } => protocols,
93 _ => unreachable!("How did we get here?"),
94 };
95 if !protocols.is_empty() {
96 let protocols = protocols.join(",");
97 headers.insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocols)?);
98 }
99
100 let mut cookie_jar = http_state.cookie_jar.write().unwrap();
101 cookie_jar.remove_expired_cookies_for_url(&request.url);
102 if let Some(cookie_list) = cookie_jar.cookies_for_url(&request.url, CookieSource::HTTP) {
103 headers.insert("Cookie", HeaderValue::from_str(&cookie_list)?);
104 }
105
106 if request.url.password().is_some() || request.url.username() != "" {
107 let basic = base64::engine::general_purpose::STANDARD.encode(format!(
108 "{}:{}",
109 request.url.username(),
110 request.url.password().unwrap_or("")
111 ));
112 headers.insert(
113 "Authorization",
114 HeaderValue::from_str(&format!("Basic {}", basic))?,
115 );
116 }
117 Ok(request.headers(headers).build())
118}
119
120#[allow(clippy::result_large_err)]
121fn process_ws_response(
126 http_state: &HttpState,
127 response: &Response,
128 resource_url: &ServoUrl,
129 protocols: &[String],
130) -> Result<Option<String>, Error> {
131 trace!("processing websocket http response for {}", resource_url);
132 let mut protocol_in_use = None;
133 if let Some(protocol_name) = response.headers().get("Sec-WebSocket-Protocol") {
134 let protocol_name = protocol_name.to_str().unwrap_or("");
135 if !protocols.is_empty() && !protocols.iter().any(|p| protocol_name == (*p)) {
136 return Err(Error::Protocol(ProtocolError::InvalidHeader(Box::new(
137 HeaderName::from_static("sec-websocket-protocol"),
138 ))));
139 }
140 protocol_in_use = Some(protocol_name.to_string());
141 }
142
143 let mut jar = http_state.cookie_jar.write().unwrap();
144 for cookie in response.headers().get_all(header::SET_COOKIE) {
146 let cookie_bytes = cookie.as_bytes();
147 if !ServoCookie::is_valid_name_or_value(cookie_bytes) {
148 continue;
149 }
150 if let Ok(s) = std::str::from_utf8(cookie_bytes) {
151 if let Some(cookie) =
152 ServoCookie::from_cookie_string(s, resource_url, CookieSource::HTTP)
153 {
154 jar.push(cookie, resource_url, CookieSource::HTTP);
155 }
156 }
157 }
158
159 http_state
160 .hsts_list
161 .write()
162 .unwrap()
163 .update_hsts_list_from_response(resource_url, response.headers());
164
165 Ok(protocol_in_use)
166}
167
168#[derive(Debug)]
169enum DomMsg {
170 Send(Message),
171 Close(Option<(u16, String)>),
172}
173
174fn setup_dom_listener(
177 dom_action_receiver: IpcReceiver<WebSocketDomAction>,
178 initiated_close: Arc<AtomicBool>,
179) -> UnboundedReceiver<DomMsg> {
180 let (sender, receiver) = unbounded_channel();
181
182 ROUTER.add_typed_route(
183 dom_action_receiver,
184 Box::new(move |message| {
185 let dom_action = message.expect("Ws dom_action message to deserialize");
186 trace!("handling WS DOM action: {:?}", dom_action);
187 match dom_action {
188 WebSocketDomAction::SendMessage(MessageData::Text(data)) => {
189 if let Err(e) = sender.send(DomMsg::Send(Message::Text(data.into()))) {
190 warn!("Error sending websocket message: {:?}", e);
191 }
192 },
193 WebSocketDomAction::SendMessage(MessageData::Binary(data)) => {
194 if let Err(e) = sender.send(DomMsg::Send(Message::Binary(data.into()))) {
195 warn!("Error sending websocket message: {:?}", e);
196 }
197 },
198 WebSocketDomAction::Close(code, reason) => {
199 if initiated_close.fetch_or(true, Ordering::SeqCst) {
200 return;
201 }
202 let frame = code.map(move |c| (c, reason.unwrap_or_default()));
203 if let Err(e) = sender.send(DomMsg::Close(frame)) {
204 warn!("Error closing websocket: {:?}", e);
205 }
206 },
207 }
208 }),
209 );
210
211 receiver
212}
213
214async fn run_ws_loop(
219 mut dom_receiver: UnboundedReceiver<DomMsg>,
220 resource_event_sender: IpcSender<WebSocketNetworkEvent>,
221 mut stream: WebSocketStream<ConnectStream>,
222) {
223 loop {
224 select! {
225 dom_msg = dom_receiver.recv() => {
226 trace!("processing dom msg: {:?}", dom_msg);
227 let dom_msg = match dom_msg {
228 Some(msg) => msg,
229 None => break,
230 };
231 match dom_msg {
232 DomMsg::Send(m) => {
233 if let Err(e) = stream.send(m).await {
234 warn!("error sending websocket message: {:?}", e);
235 }
236 },
237 DomMsg::Close(frame) => {
238 if let Err(e) = stream.close(frame.map(|(code, reason)| {
239 CloseFrame {
240 code: code.into(),
241 reason: reason.into(),
242 }
243 })).await {
244 warn!("error closing websocket: {:?}", e);
245 }
246 },
247 }
248 }
249 ws_msg = stream.next() => {
250 trace!("processing WS stream: {:?}", ws_msg);
251 let msg = match ws_msg {
252 Some(Ok(msg)) => msg,
253 Some(Err(e)) => {
254 warn!("Error in WebSocket communication: {:?}", e);
255 let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
256 break;
257 },
258 None => {
259 warn!("Error in WebSocket communication");
260 let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
261 break;
262 }
263 };
264 match msg {
265 Message::Text(s) => {
266 let message = MessageData::Text(s.as_str().to_owned());
267 if let Err(e) = resource_event_sender
268 .send(WebSocketNetworkEvent::MessageReceived(message))
269 {
270 warn!("Error sending websocket notification: {:?}", e);
271 break;
272 }
273 }
274
275 Message::Binary(v) => {
276 let message = MessageData::Binary(v.to_vec());
277 if let Err(e) = resource_event_sender
278 .send(WebSocketNetworkEvent::MessageReceived(message))
279 {
280 warn!("Error sending websocket notification: {:?}", e);
281 break;
282 }
283 }
284
285 Message::Ping(_) | Message::Pong(_) => {}
286
287 Message::Close(frame) => {
288 let (reason, code) = match frame {
289 Some(frame) => (frame.reason, Some(frame.code.into())),
290 None => ("".into(), None),
291 };
292 debug!("Websocket connection closing due to ({:?}) {}", code, reason);
293 let _ = resource_event_sender.send(WebSocketNetworkEvent::Close(
294 code,
295 reason.to_string(),
296 ));
297 break;
298 }
299
300 Message::Frame(_) => {
301 warn!("Unexpected websocket frame message");
302 }
303 }
304 }
305 }
306 }
307}
308
309pub(crate) async fn start_websocket(
313 http_state: Arc<HttpState>,
314 resource_event_sender: IpcSender<WebSocketNetworkEvent>,
315 protocols: &[String],
316 client: &net_traits::request::Request,
317 tls_config: TlsConfig,
318 dom_action_receiver: IpcReceiver<WebSocketDomAction>,
319) -> Result<Response, Error> {
320 trace!("starting WS connection to {}", client.url());
321
322 let initiated_close = Arc::new(AtomicBool::new(false));
323 let dom_receiver = setup_dom_listener(dom_action_receiver, initiated_close.clone());
324
325 let url = client.url();
326 let host = replace_host(url.host_str().expect("URL has no host"));
327 let mut net_url = client.url().into_url();
328 net_url
329 .set_host(Some(&host))
330 .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
331
332 let domain = net_url
333 .host()
334 .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
335 let port = net_url
336 .port_or_known_default()
337 .ok_or_else(|| Error::Url(UrlError::UnableToConnect("Unknown port".into())))?;
338
339 let try_socket = TcpStream::connect((&*domain.to_string(), port)).await;
340 let socket = try_socket.map_err(Error::Io)?;
341 let connector = TlsConnector::from(Arc::new(tls_config));
342
343 let mut original_url = client.original_url();
345 if original_url.scheme() == "ws" && url.scheme() == "https" {
346 original_url.as_mut_url().set_scheme("wss").unwrap();
347 }
348 let mut builder =
349 ClientRequestBuilder::new(original_url.as_str().parse().expect("unable to parse URI"));
350 for (key, value) in client.headers.iter() {
351 builder = builder.with_header(
352 key.as_str(),
353 value
354 .to_str()
355 .expect("unable to convert header value to string"),
356 );
357 }
358
359 let (stream, response) =
360 client_async_tls_with_connector_and_config(builder, socket, Some(connector), None).await?;
361
362 let protocol_in_use = process_ws_response(&http_state, &response, &url, protocols)?;
363
364 if !initiated_close.load(Ordering::SeqCst) {
365 if resource_event_sender
366 .send(WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use })
367 .is_err()
368 {
369 return Ok(response);
370 }
371
372 trace!("about to start ws loop for {}", url);
373 spawn_task(run_ws_loop(dom_receiver, resource_event_sender, stream));
374 } else {
375 trace!("client closed connection for {}, not running loop", url);
376 }
377 Ok(response)
378}