1use std::collections::hash_map::HashMap;
6use std::convert::TryFrom;
7use std::sync::Arc;
8use std::time::Duration;
9use std::{fmt, io};
10
11use futures::task::{Context, Poll};
12use futures::{Future, TryFutureExt};
13use http::uri::{Authority, Uri as Destination};
14use http_body_util::combinators::BoxBody;
15use hyper::body::Bytes;
16use hyper::rt::Executor;
17use hyper_rustls::{HttpsConnector as HyperRustlsHttpsConnector, MaybeHttpsStream};
18use hyper_util::client::legacy::Client;
19use hyper_util::client::legacy::connect::proxy::Tunnel;
20use hyper_util::client::legacy::connect::{
21 Connected, Connection, HttpConnector as HyperHttpConnector,
22};
23use hyper_util::rt::TokioIo;
24use log::warn;
25use parking_lot::Mutex;
26use rustls::client::danger::ServerCertVerifier;
27use rustls::client::{ClientConnection, EchStatus};
28use rustls::crypto::{CryptoProvider, aws_lc_rs};
29use rustls::{ClientConfig, ProtocolVersion};
30use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
31use servo_config::pref;
32use tokio::net::TcpStream;
33use tower::Service;
34
35use crate::async_runtime::spawn_task;
36use crate::hosts::replace_host;
37
/// A 32 KiB size, expressed in bytes.
///
/// NOTE(review): no user of this constant is visible in this part of the file;
/// presumably it sizes network I/O buffers elsewhere — confirm against callers.
pub const BUF_SIZE: usize = 32 * 1024;

/// The ALPN protocol identifier that marks a connection as having negotiated HTTP/2.
pub const ALPN_H2: &str = "h2";
42
43#[derive(Clone)]
44pub struct ServoHttpConnector {
45 inner: HyperHttpConnector,
46}
47
48impl ServoHttpConnector {
49 fn new() -> ServoHttpConnector {
50 let mut inner = HyperHttpConnector::new();
51 inner.enforce_http(false);
52 inner.set_happy_eyeballs_timeout(None);
53 inner.set_connect_timeout(Some(Duration::from_secs(pref!(network_connection_timeout))));
54 ServoHttpConnector { inner }
55 }
56}
57
58impl Service<Destination> for ServoHttpConnector {
59 type Response = TokioIo<TcpStream>;
60 type Error = ConnectionError;
61 type Future =
62 std::pin::Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ConnectionError>> + Send>>;
63
64 fn call(&mut self, dest: Destination) -> Self::Future {
65 let mut new_dest = dest.clone();
67 let mut parts = dest.into_parts();
68
69 if let Some(auth) = parts.authority {
70 let host = auth.host();
71 let host = replace_host(host);
72
73 let authority = if let Some(port) = auth.port() {
74 format!("{}:{}", host, port.as_str())
75 } else {
76 (*host).to_string()
77 };
78
79 if let Ok(authority) = Authority::from_maybe_shared(authority) {
80 parts.authority = Some(authority);
81 if let Ok(dest) = Destination::from_parts(parts) {
82 new_dest = dest
83 }
84 }
85 }
86
87 Box::pin(
88 self.inner
89 .call(new_dest)
90 .map_err(|e| ConnectionError::HttpError(format!("{e}"))),
91 )
92 }
93
94 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95 Ok(()).into()
96 }
97}
98
99type BoxError = Box<dyn std::error::Error + Send + Sync>;
100
101#[derive(Clone)]
102pub struct InstrumentedConnector<T> {
103 inner: HyperRustlsHttpsConnector<T>,
104}
105
106impl<T> InstrumentedConnector<T> {
107 fn new(inner: HyperRustlsHttpsConnector<T>) -> Self {
108 Self { inner }
109 }
110}
111
112impl<T> From<HyperRustlsHttpsConnector<T>> for InstrumentedConnector<T> {
113 fn from(inner: HyperRustlsHttpsConnector<T>) -> Self {
114 Self::new(inner)
115 }
116}
117
118pub struct InstrumentedStream<T> {
119 inner: MaybeHttpsStream<T>,
120 tls_info: Option<TlsHandshakeInfo>,
121}
122
123impl<T: Unpin> Unpin for InstrumentedStream<T> {}
124
125impl<T> fmt::Debug for InstrumentedStream<T> {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("InstrumentedStream")
128 .field("tls_info", &self.tls_info)
129 .finish()
130 }
131}
132
133#[derive(Clone, Debug)]
134pub struct TlsHandshakeInfo {
135 pub protocol_version: Option<String>,
136 pub cipher_suite: Option<String>,
137 pub kea_group_name: Option<String>,
138 pub signature_scheme_name: Option<String>,
139 pub alpn_protocol: Option<String>,
140 pub certificate_chain_der: Vec<Vec<u8>>,
141 pub used_ech: bool,
142}
143
144impl TlsHandshakeInfo {
145 fn from_connection(conn: &ClientConnection) -> Self {
146 let protocol_version = conn.protocol_version().map(protocol_version_to_string);
147 let cipher_suite = conn
148 .negotiated_cipher_suite()
149 .map(|suite| format!("{:?}", suite.suite()));
150 let kea_group_name = conn
151 .negotiated_key_exchange_group()
152 .map(|group| format!("{:?}", group.name()));
153 let certificate_chain_der = conn
154 .peer_certificates()
155 .map(|certs| certs.iter().map(|cert| cert.as_ref().to_vec()).collect())
156 .unwrap_or_default();
157 let alpn_protocol = conn
158 .alpn_protocol()
159 .map(|proto| String::from_utf8_lossy(proto).into_owned());
160 let used_ech = matches!(conn.ech_status(), EchStatus::Accepted);
161
162 Self {
163 protocol_version,
164 cipher_suite,
165 kea_group_name,
166 signature_scheme_name: None,
167 alpn_protocol,
168 certificate_chain_der,
169 used_ech,
170 }
171 }
172}
173
174fn protocol_version_to_string(version: ProtocolVersion) -> String {
175 match version {
176 ProtocolVersion::TLSv1_3 => "TLS 1.3".to_string(),
177 ProtocolVersion::TLSv1_2 => "TLS 1.2".to_string(),
178 ProtocolVersion::TLSv1_1 => "TLS 1.1".to_string(),
179 ProtocolVersion::TLSv1_0 => "TLS 1.0".to_string(),
180 ProtocolVersion::SSLv2 => "SSL 2.0".to_string(),
181 ProtocolVersion::SSLv3 => "SSL 3.0".to_string(),
182 ProtocolVersion::DTLSv1_0 => "DTLS 1.0".to_string(),
183 ProtocolVersion::DTLSv1_2 => "DTLS 1.2".to_string(),
184 ProtocolVersion::DTLSv1_3 => "DTLS 1.3".to_string(),
185 ProtocolVersion::Unknown(v) => format!("Unknown(0x{v:04x})"),
186 _ => format!("{version:?}"),
187 }
188}
189
190impl<T> InstrumentedStream<T>
191where
192 T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
193{
194 fn from_maybe_https_stream(stream: MaybeHttpsStream<T>) -> Self {
195 match stream {
196 MaybeHttpsStream::Http(inner) => Self {
197 inner: MaybeHttpsStream::Http(inner),
198 tls_info: None,
199 },
200 MaybeHttpsStream::Https(tls_stream) => {
201 let (_tcp, tls) = tls_stream.inner().get_ref();
202 let tls_info = TlsHandshakeInfo::from_connection(tls);
203
204 Self {
205 inner: MaybeHttpsStream::Https(tls_stream),
206 tls_info: Some(tls_info),
207 }
208 },
209 }
210 }
211}
212
213impl<T> Connection for InstrumentedStream<T>
214where
215 T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
216{
217 fn connected(&self) -> Connected {
218 let connected = match &self.inner {
219 MaybeHttpsStream::Http(stream) => stream.connected(),
220 MaybeHttpsStream::Https(stream) => {
221 let (tcp, tls) = stream.inner().get_ref();
222 if tls.alpn_protocol() == Some(ALPN_H2.as_bytes()) {
223 tcp.inner().connected().negotiated_h2()
224 } else {
225 tcp.inner().connected()
226 }
227 },
228 };
229 if let Some(info) = &self.tls_info {
230 connected.extra(info.clone())
231 } else {
232 connected
233 }
234 }
235}
236
237impl<T> hyper::rt::Read for InstrumentedStream<T>
238where
239 T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
240{
241 fn poll_read(
242 self: std::pin::Pin<&mut Self>,
243 cx: &mut Context<'_>,
244 buf: hyper::rt::ReadBufCursor<'_>,
245 ) -> Poll<Result<(), io::Error>> {
246 std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
247 }
248}
249
250impl<T> hyper::rt::Write for InstrumentedStream<T>
251where
252 T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
253{
254 fn poll_write(
255 self: std::pin::Pin<&mut Self>,
256 cx: &mut Context<'_>,
257 buf: &[u8],
258 ) -> Poll<Result<usize, io::Error>> {
259 std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
260 }
261
262 fn poll_flush(
263 self: std::pin::Pin<&mut Self>,
264 cx: &mut Context<'_>,
265 ) -> Poll<Result<(), io::Error>> {
266 std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
267 }
268
269 fn poll_shutdown(
270 self: std::pin::Pin<&mut Self>,
271 cx: &mut Context<'_>,
272 ) -> Poll<Result<(), io::Error>> {
273 std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
274 }
275
276 fn is_write_vectored(&self) -> bool {
277 self.inner.is_write_vectored()
278 }
279
280 fn poll_write_vectored(
281 self: std::pin::Pin<&mut Self>,
282 cx: &mut Context<'_>,
283 bufs: &[io::IoSlice<'_>],
284 ) -> Poll<Result<usize, io::Error>> {
285 std::pin::Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
286 }
287}
288
289impl<T> Service<Destination> for InstrumentedConnector<T>
290where
291 T: Service<Destination>,
292 T::Response: Connection + hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
293 T::Future: Send + 'static,
294 T::Error: Into<BoxError>,
295{
296 type Response = InstrumentedStream<T::Response>;
297 type Error = BoxError;
298 type Future = std::pin::Pin<
299 Box<dyn Future<Output = Result<InstrumentedStream<T::Response>, BoxError>> + Send>,
300 >;
301
302 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303 self.inner.poll_ready(cx).map_err(Into::into)
304 }
305
306 fn call(&mut self, dst: Destination) -> Self::Future {
307 let future = self.inner.call(dst);
308 Box::pin(async move {
309 let stream = future.await.map_err(|error| -> BoxError { error })?;
310 Ok(InstrumentedStream::from_maybe_https_stream(stream))
311 })
312 }
313}
314
315pub type Connector = InstrumentedConnector<ServoHttpConnector>;
316pub type TlsConfig = ClientConfig;
317
318#[derive(Clone, Debug, Default)]
319struct CertificateErrorOverrideManagerInternal {
320 certificates_failing_to_verify: HashMap<ServerName<'static>, CertificateDer<'static>>,
323 overrides: Vec<CertificateDer<'static>>,
326}
327
328#[derive(Clone, Debug, Default)]
333pub struct CertificateErrorOverrideManager(Arc<Mutex<CertificateErrorOverrideManagerInternal>>);
334
335impl CertificateErrorOverrideManager {
336 pub fn new() -> Self {
337 Self(Default::default())
338 }
339
340 pub fn add_override(&self, certificate: &CertificateDer<'static>) {
343 self.0.lock().overrides.push(certificate.clone());
344 }
345
346 pub(crate) fn remove_certificate_failing_verification(
350 &self,
351 host: &str,
352 ) -> Option<CertificateDer<'static>> {
353 let server_name = match ServerName::try_from(host) {
354 Ok(name) => name.to_owned(),
355 Err(error) => {
356 warn!("Could not convert host string into RustTLS ServerName: {error:?}");
357 return None;
358 },
359 };
360 self.0
361 .lock()
362 .certificates_failing_to_verify
363 .remove(&server_name)
364 }
365}
366
367#[derive(Clone, Debug, Default)]
368pub enum CACertificates<'de> {
369 #[default]
370 Default,
371 Override(Vec<CertificateDer<'de>>),
372}
373
374pub fn create_tls_config(
381 ca_certificates: CACertificates<'static>,
382 ignore_certificate_errors: bool,
383 override_manager: CertificateErrorOverrideManager,
384) -> TlsConfig {
385 let verifier = CertificateVerificationOverrideVerifier::new(
386 ca_certificates,
387 ignore_certificate_errors,
388 override_manager,
389 );
390 rustls::ClientConfig::builder()
391 .dangerous()
392 .with_custom_certificate_verifier(Arc::new(verifier))
393 .with_no_client_auth()
394}
395
396#[derive(Clone)]
397struct TokioExecutor {}
398
399impl<F> Executor<F> for TokioExecutor
400where
401 F: Future<Output = ()> + 'static + std::marker::Send,
402{
403 fn execute(&self, fut: F) {
404 spawn_task(fut);
405 }
406}
407
408#[derive(Debug)]
409struct CertificateVerificationOverrideVerifier {
410 main_verifier: Arc<dyn ServerCertVerifier>,
411 ignore_certificate_errors: bool,
412 override_manager: CertificateErrorOverrideManager,
413}
414
415impl CertificateVerificationOverrideVerifier {
416 fn new(
417 ca_certficates: CACertificates<'static>,
418 ignore_certificate_errors: bool,
419 override_manager: CertificateErrorOverrideManager,
420 ) -> Self {
421 let use_webpki_roots = cfg!(target_os = "android") || pref!(network_use_webpki_roots);
430 let main_verifier = if !use_webpki_roots {
431 let crypto_provider = CryptoProvider::get_default()
432 .unwrap_or(&Arc::new(aws_lc_rs::default_provider()))
433 .clone();
434 let verifier = match ca_certficates {
435 CACertificates::Default => rustls_platform_verifier::Verifier::new(crypto_provider),
436 CACertificates::Override(_certificates) => {
439 #[cfg(target_os = "android")]
440 unreachable!("Android should always use the WebPKI verifier.");
441 #[cfg(not(target_os = "android"))]
442 rustls_platform_verifier::Verifier::new_with_extra_roots(
443 _certificates,
444 crypto_provider,
445 )
446 },
447 }
448 .expect("Could not initialize platform certificate verifier");
449 Arc::new(verifier) as Arc<dyn ServerCertVerifier>
450 } else {
451 let mut root_store =
452 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
453 match ca_certficates {
454 CACertificates::Default => {},
455 CACertificates::Override(certificates) => {
456 for certificate in certificates {
457 if root_store.add(certificate).is_err() {
458 log::error!("Could not add an override certificate.");
459 }
460 }
461 },
462 }
463 rustls::client::WebPkiServerVerifier::builder(root_store.into())
464 .build()
465 .expect("Could not initialize platform certificate verifier.")
466 as Arc<dyn ServerCertVerifier>
467 };
468
469 Self {
470 main_verifier,
471 ignore_certificate_errors,
472 override_manager,
473 }
474 }
475}
476
477impl rustls::client::danger::ServerCertVerifier for CertificateVerificationOverrideVerifier {
478 fn verify_tls12_signature(
479 &self,
480 message: &[u8],
481 cert: &CertificateDer<'_>,
482 dss: &rustls::DigitallySignedStruct,
483 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
484 self.main_verifier
485 .verify_tls12_signature(message, cert, dss)
486 }
487
488 fn verify_tls13_signature(
489 &self,
490 message: &[u8],
491 cert: &CertificateDer<'_>,
492 dss: &rustls::DigitallySignedStruct,
493 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
494 self.main_verifier
495 .verify_tls13_signature(message, cert, dss)
496 }
497
498 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
499 self.main_verifier.supported_verify_schemes()
500 }
501
502 fn verify_server_cert(
503 &self,
504 end_entity: &CertificateDer<'_>,
505 intermediates: &[CertificateDer<'_>],
506 server_name: &ServerName<'_>,
507 ocsp_response: &[u8],
508 now: UnixTime,
509 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
510 let error = match self.main_verifier.verify_server_cert(
511 end_entity,
512 intermediates,
513 server_name,
514 ocsp_response,
515 now,
516 ) {
517 Ok(result) => return Ok(result),
518 Err(error) => error,
519 };
520
521 if self.ignore_certificate_errors {
522 warn!("Ignoring certficate error: {error:?}");
523 return Ok(rustls::client::danger::ServerCertVerified::assertion());
524 }
525
526 for cert_with_exception in &*self.override_manager.0.lock().overrides {
528 if *end_entity == *cert_with_exception {
529 return Ok(rustls::client::danger::ServerCertVerified::assertion());
530 }
531 }
532 self.override_manager
533 .0
534 .lock()
535 .certificates_failing_to_verify
536 .insert(server_name.to_owned(), end_entity.clone().into_owned());
537 Err(error)
538 }
539}
540
/// A boxed HTTP body whose data chunks are `Bytes` and whose error type is `hyper::Error`.
pub type BoxedBody = BoxBody<Bytes, hyper::Error>;
542
/// Errors produced by the connectors in this module; each variant carries the
/// underlying error rendered as text.
#[derive(Debug)]
pub enum ConnectionError {
    /// A failure reported by the underlying HTTP connector.
    HttpError(String),
    /// A failure reported by `ProxyConnector`.
    ProxyError(String),
}

impl fmt::Display for ConnectionError {
    /// Writes the same text as the `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{self:?}"))
    }
}

impl std::error::Error for ConnectionError {}
558
559#[derive(Clone)]
560pub struct ProxyConnector {
563 client: ServoHttpConnector,
565 matcher: std::sync::Arc<hyper_util::client::proxy::matcher::Matcher>,
567}
568
569impl ProxyConnector {
570 fn new() -> Self {
571 let matcher_builder = hyper_util::client::proxy::matcher::Matcher::builder()
572 .http(servo_config::pref!(network_http_proxy_uri))
573 .https(servo_config::pref!(network_https_proxy_uri))
574 .no(servo_config::pref!(network_http_no_proxy));
575 ProxyConnector {
576 client: ServoHttpConnector::new(),
577 matcher: std::sync::Arc::new(matcher_builder.build()),
578 }
579 }
580}
581
582impl Service<Destination> for ProxyConnector {
584 type Response = TokioIo<TcpStream>;
585 type Error = ConnectionError;
586 type Future =
587 std::pin::Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ConnectionError>> + Send>>;
588
589 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
590 self.client
591 .poll_ready(cx)
592 .map_err(|e| ConnectionError::ProxyError(format!("{e}")))
593 }
594
595 fn call(&mut self, req: Destination) -> Self::Future {
596 match self.matcher.intercept(&req) {
597 Some(intercept) => Box::pin(
598 Tunnel::new(intercept.uri().clone(), self.client.clone())
599 .call(req)
600 .map_err(|e| ConnectionError::ProxyError(format!("{e}"))),
601 ),
602 None => Box::pin(
603 self.client
604 .call(req)
605 .map_err(|e| ConnectionError::ProxyError(format!("{e}"))),
606 ),
607 }
608 }
609}
610
611pub type ServoClient = Client<InstrumentedConnector<ProxyConnector>, BoxedBody>;
612
613pub fn create_http_client(tls_config: TlsConfig) -> ServoClient {
614 let connector = hyper_rustls::HttpsConnectorBuilder::new()
615 .with_tls_config(tls_config)
616 .https_or_http()
617 .enable_http1()
618 .enable_http2()
619 .wrap_connector(ProxyConnector::new());
620
621 Client::builder(TokioExecutor {})
622 .http1_title_case_headers(true)
623 .build(InstrumentedConnector::from(connector))
624}