1use std::sync::Arc;
15use std::sync::atomic::{AtomicBool, Ordering};
16
17use async_tungstenite::WebSocketStream;
18use async_tungstenite::tokio::{ConnectStream, client_async_tls_with_connector_and_config};
19use base64::Engine;
20use futures::stream::StreamExt;
21use http::HeaderMap;
22use http::header::{self, HeaderName, HeaderValue};
23use ipc_channel::ipc::{IpcReceiver, IpcSender};
24use ipc_channel::router::ROUTER;
25use log::{debug, trace, warn};
26use net_traits::request::{RequestBuilder, RequestMode};
27use net_traits::{CookieSource, MessageData, WebSocketDomAction, WebSocketNetworkEvent};
28use servo_url::ServoUrl;
29use tokio::net::TcpStream;
30use tokio::select;
31use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};
32use tokio_rustls::TlsConnector;
33use tungstenite::error::{Error, ProtocolError, UrlError};
34use tungstenite::handshake::client::Response;
35use tungstenite::protocol::CloseFrame;
36use tungstenite::{ClientRequestBuilder, Message};
37
38use crate::async_runtime::spawn_task;
39use crate::connector::TlsConfig;
40use crate::cookie::ServoCookie;
41use crate::hosts::replace_host;
42use crate::http_loader::HttpState;
43
44pub fn create_handshake_request(
50 request: RequestBuilder,
51 http_state: Arc<HttpState>,
52) -> Result<net_traits::request::Request, Error> {
53 let origin = request.url.origin();
54
55 let mut headers = HeaderMap::new();
56 headers.insert(
57 "Origin",
58 HeaderValue::from_str(&request.url.origin().ascii_serialization())?,
59 );
60
61 let host = format!(
62 "{}",
63 origin
64 .host()
65 .ok_or_else(|| Error::Url(UrlError::NoHostName))?
66 );
67 headers.insert("Host", HeaderValue::from_str(&host)?);
68 headers.insert("Upgrade", HeaderValue::from_static("websocket"));
71
72 headers.insert("Connection", HeaderValue::from_static("upgrade"));
74
75 let key = HeaderValue::from_str(&tungstenite::handshake::client::generate_key()).unwrap();
78
79 headers.insert("Sec-WebSocket-Key", key);
81
82 headers.insert("Sec-Websocket-Version", HeaderValue::from_static("13"));
84
85 let protocols = match request.mode {
88 RequestMode::WebSocket {
89 ref protocols,
90 original_url: _,
91 } => protocols,
92 _ => unreachable!("How did we get here?"),
93 };
94 if !protocols.is_empty() {
95 let protocols = protocols.join(",");
96 headers.insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocols)?);
97 }
98
99 let mut cookie_jar = http_state.cookie_jar.write();
100 cookie_jar.remove_expired_cookies_for_url(&request.url);
101 if let Some(cookie_list) = cookie_jar.cookies_for_url(&request.url, CookieSource::HTTP) {
102 headers.insert("Cookie", HeaderValue::from_str(&cookie_list)?);
103 }
104
105 if request.url.password().is_some() || request.url.username() != "" {
106 let basic = base64::engine::general_purpose::STANDARD.encode(format!(
107 "{}:{}",
108 request.url.username(),
109 request.url.password().unwrap_or("")
110 ));
111 headers.insert(
112 "Authorization",
113 HeaderValue::from_str(&format!("Basic {}", basic))?,
114 );
115 }
116 Ok(request.headers(headers).build())
117}
118
119fn process_ws_response(
124 http_state: &HttpState,
125 response: &Response,
126 resource_url: &ServoUrl,
127 protocols: &[String],
128) -> Result<Option<String>, Error> {
129 trace!("processing websocket http response for {}", resource_url);
130 let mut protocol_in_use = None;
131 if let Some(protocol_name) = response.headers().get("Sec-WebSocket-Protocol") {
132 let protocol_name = protocol_name.to_str().unwrap_or("");
133 if !protocols.is_empty() && !protocols.iter().any(|p| protocol_name == (*p)) {
134 return Err(Error::Protocol(ProtocolError::InvalidHeader(Box::new(
135 HeaderName::from_static("sec-websocket-protocol"),
136 ))));
137 }
138 protocol_in_use = Some(protocol_name.to_string());
139 }
140
141 let mut jar = http_state.cookie_jar.write();
142 for cookie in response.headers().get_all(header::SET_COOKIE) {
144 let cookie_bytes = cookie.as_bytes();
145 if !ServoCookie::is_valid_name_or_value(cookie_bytes) {
146 continue;
147 }
148 if let Ok(s) = std::str::from_utf8(cookie_bytes) {
149 if let Some(cookie) =
150 ServoCookie::from_cookie_string(s, resource_url, CookieSource::HTTP)
151 {
152 jar.push(cookie, resource_url, CookieSource::HTTP);
153 }
154 }
155 }
156
157 http_state
158 .hsts_list
159 .write()
160 .update_hsts_list_from_response(resource_url, response.headers());
161
162 Ok(protocol_in_use)
163}
164
/// A request from the DOM side of a WebSocket.
///
/// The IPC router callback installed by `setup_dom_listener` sends these over an
/// unbounded channel, and `run_ws_loop` consumes them.
#[derive(Debug)]
enum DomMsg {
    /// Write this message to the socket.
    Send(Message),
    /// Begin closing the connection. The optional `(status code, reason)` pair becomes
    /// the outgoing close frame; `None` closes without a frame payload.
    Close(Option<(u16, String)>),
}
170
171fn setup_dom_listener(
174 dom_action_receiver: IpcReceiver<WebSocketDomAction>,
175 initiated_close: Arc<AtomicBool>,
176) -> UnboundedReceiver<DomMsg> {
177 let (sender, receiver) = unbounded_channel();
178
179 ROUTER.add_typed_route(
180 dom_action_receiver,
181 Box::new(move |message| {
182 let dom_action = message.expect("Ws dom_action message to deserialize");
183 trace!("handling WS DOM action: {:?}", dom_action);
184 match dom_action {
185 WebSocketDomAction::SendMessage(MessageData::Text(data)) => {
186 if let Err(e) = sender.send(DomMsg::Send(Message::Text(data.into()))) {
187 warn!("Error sending websocket message: {:?}", e);
188 }
189 },
190 WebSocketDomAction::SendMessage(MessageData::Binary(data)) => {
191 if let Err(e) = sender.send(DomMsg::Send(Message::Binary(data.into()))) {
192 warn!("Error sending websocket message: {:?}", e);
193 }
194 },
195 WebSocketDomAction::Close(code, reason) => {
196 if initiated_close.fetch_or(true, Ordering::SeqCst) {
197 return;
198 }
199 let frame = code.map(move |c| (c, reason.unwrap_or_default()));
200 if let Err(e) = sender.send(DomMsg::Close(frame)) {
201 warn!("Error closing websocket: {:?}", e);
202 }
203 },
204 }
205 }),
206 );
207
208 receiver
209}
210
211async fn run_ws_loop(
216 mut dom_receiver: UnboundedReceiver<DomMsg>,
217 resource_event_sender: IpcSender<WebSocketNetworkEvent>,
218 mut stream: WebSocketStream<ConnectStream>,
219) {
220 loop {
221 select! {
222 dom_msg = dom_receiver.recv() => {
223 trace!("processing dom msg: {:?}", dom_msg);
224 let dom_msg = match dom_msg {
225 Some(msg) => msg,
226 None => break,
227 };
228 match dom_msg {
229 DomMsg::Send(m) => {
230 if let Err(e) = stream.send(m).await {
231 warn!("error sending websocket message: {:?}", e);
232 }
233 },
234 DomMsg::Close(frame) => {
235 if let Err(e) = stream.close(frame.map(|(code, reason)| {
236 CloseFrame {
237 code: code.into(),
238 reason: reason.into(),
239 }
240 })).await {
241 warn!("error closing websocket: {:?}", e);
242 }
243 },
244 }
245 }
246 ws_msg = stream.next() => {
247 trace!("processing WS stream: {:?}", ws_msg);
248 let msg = match ws_msg {
249 Some(Ok(msg)) => msg,
250 Some(Err(e)) => {
251 warn!("Error in WebSocket communication: {:?}", e);
252 let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
253 break;
254 },
255 None => {
256 warn!("Error in WebSocket communication");
257 let _ = resource_event_sender.send(WebSocketNetworkEvent::Fail);
258 break;
259 }
260 };
261 match msg {
262 Message::Text(s) => {
263 let message = MessageData::Text(s.as_str().to_owned());
264 if let Err(e) = resource_event_sender
265 .send(WebSocketNetworkEvent::MessageReceived(message))
266 {
267 warn!("Error sending websocket notification: {:?}", e);
268 break;
269 }
270 }
271
272 Message::Binary(v) => {
273 let message = MessageData::Binary(v.to_vec());
274 if let Err(e) = resource_event_sender
275 .send(WebSocketNetworkEvent::MessageReceived(message))
276 {
277 warn!("Error sending websocket notification: {:?}", e);
278 break;
279 }
280 }
281
282 Message::Ping(_) | Message::Pong(_) => {}
283
284 Message::Close(frame) => {
285 let (reason, code) = match frame {
286 Some(frame) => (frame.reason, Some(frame.code.into())),
287 None => ("".into(), None),
288 };
289 debug!("Websocket connection closing due to ({:?}) {}", code, reason);
290 let _ = resource_event_sender.send(WebSocketNetworkEvent::Close(
291 code,
292 reason.to_string(),
293 ));
294 break;
295 }
296
297 Message::Frame(_) => {
298 warn!("Unexpected websocket frame message");
299 }
300 }
301 }
302 }
303 }
304}
305
306pub(crate) async fn start_websocket(
310 http_state: Arc<HttpState>,
311 resource_event_sender: IpcSender<WebSocketNetworkEvent>,
312 protocols: &[String],
313 client: &net_traits::request::Request,
314 tls_config: TlsConfig,
315 dom_action_receiver: IpcReceiver<WebSocketDomAction>,
316) -> Result<Response, Error> {
317 trace!("starting WS connection to {}", client.url());
318
319 let initiated_close = Arc::new(AtomicBool::new(false));
320 let dom_receiver = setup_dom_listener(dom_action_receiver, initiated_close.clone());
321
322 let url = client.url();
323 let host = replace_host(url.host_str().expect("URL has no host"));
324 let mut net_url = client.url().into_url();
325 net_url
326 .set_host(Some(&host))
327 .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
328
329 let domain = net_url
330 .host()
331 .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
332 let port = net_url
333 .port_or_known_default()
334 .ok_or_else(|| Error::Url(UrlError::UnableToConnect("Unknown port".into())))?;
335
336 let try_socket = TcpStream::connect((&*domain.to_string(), port)).await;
337 let socket = try_socket.map_err(Error::Io)?;
338 let connector = TlsConnector::from(Arc::new(tls_config));
339
340 let mut original_url = client.original_url();
342 if original_url.scheme() == "ws" && url.scheme() == "https" {
343 original_url.as_mut_url().set_scheme("wss").unwrap();
344 }
345 let mut builder =
346 ClientRequestBuilder::new(original_url.as_str().parse().expect("unable to parse URI"));
347 for (key, value) in client.headers.iter() {
348 builder = builder.with_header(
349 key.as_str(),
350 value
351 .to_str()
352 .expect("unable to convert header value to string"),
353 );
354 }
355
356 let (stream, response) =
357 client_async_tls_with_connector_and_config(builder, socket, Some(connector), None).await?;
358
359 let protocol_in_use = process_ws_response(&http_state, &response, &url, protocols)?;
360
361 if !initiated_close.load(Ordering::SeqCst) {
362 if resource_event_sender
363 .send(WebSocketNetworkEvent::ConnectionEstablished { protocol_in_use })
364 .is_err()
365 {
366 return Ok(response);
367 }
368
369 trace!("about to start ws loop for {}", url);
370 spawn_task(run_ws_loop(dom_receiver, resource_event_sender, stream));
371 } else {
372 trace!("client closed connection for {}, not running loop", url);
373 }
374 Ok(response)
375}