1use std::collections::hash_map::HashMap;
6use std::convert::TryFrom;
7use std::sync::{Arc, LazyLock};
8use std::time::Duration;
9use std::{fmt, io};
10
11use futures::task::{Context, Poll};
12use futures::{Future, TryFutureExt};
13use http::uri::{Authority, Uri as Destination};
14use http_body_util::combinators::BoxBody;
15use hyper::body::Bytes;
16use hyper::rt::Executor;
17use hyper_rustls::{HttpsConnector as HyperRustlsHttpsConnector, MaybeHttpsStream};
18use hyper_util::client::legacy::Client;
19use hyper_util::client::legacy::connect::proxy::Tunnel;
20use hyper_util::client::legacy::connect::{
21 Connected, Connection, HttpConnector as HyperHttpConnector,
22};
23use hyper_util::rt::TokioIo;
24use log::warn;
25use parking_lot::Mutex;
26use rustls::client::danger::ServerCertVerifier;
27use rustls::client::{ClientConnection, EchStatus};
28use rustls::crypto::{CryptoProvider, aws_lc_rs};
29use rustls::{ClientConfig, ProtocolVersion};
30use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
31use servo_config::pref;
32use tokio::net::TcpStream;
33use tower::Service;
34
35use crate::async_runtime::spawn_task;
36use crate::hosts::replace_host;
37
/// Buffer size in bytes (32 KiB), exported for other networking code; its consumers are not in
/// this section of the file.
pub const BUF_SIZE: usize = 32 * 1024;

/// ALPN protocol identifier for HTTP/2, compared against the protocol negotiated during the TLS
/// handshake.
pub const ALPN_H2: &str = "h2";
42
43#[derive(Clone)]
44pub struct ServoHttpConnector {
45 inner: HyperHttpConnector,
46}
47
48impl ServoHttpConnector {
49 fn new() -> ServoHttpConnector {
50 let mut inner = HyperHttpConnector::new();
51 inner.enforce_http(false);
52 inner.set_happy_eyeballs_timeout(None);
53 inner.set_connect_timeout(Some(Duration::from_secs(pref!(network_connection_timeout))));
54 ServoHttpConnector { inner }
55 }
56}
57
58impl Service<Destination> for ServoHttpConnector {
59 type Response = TokioIo<TcpStream>;
60 type Error = ConnectionError;
61 type Future =
62 std::pin::Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ConnectionError>> + Send>>;
63
64 fn call(&mut self, dest: Destination) -> Self::Future {
65 let mut new_dest = dest.clone();
67 let mut parts = dest.into_parts();
68
69 if let Some(auth) = parts.authority {
70 let host = auth.host();
71 let host = replace_host(host);
72
73 let authority = if let Some(port) = auth.port() {
74 format!("{}:{}", host, port.as_str())
75 } else {
76 (*host).to_string()
77 };
78
79 if let Ok(authority) = Authority::from_maybe_shared(authority) {
80 parts.authority = Some(authority);
81 if let Ok(dest) = Destination::from_parts(parts) {
82 new_dest = dest
83 }
84 }
85 }
86
87 Box::pin(
88 self.inner
89 .call(new_dest)
90 .map_err(|e| ConnectionError::HttpError(format!("{e}"))),
91 )
92 }
93
94 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95 Ok(()).into()
96 }
97}
98
/// Type-erased, thread-safe error, matching the error type of the wrapped HTTPS connector.
type BoxError = Box<dyn std::error::Error + Sync + Send>;
100
101#[derive(Clone)]
102pub struct InstrumentedConnector<T> {
103 inner: HyperRustlsHttpsConnector<T>,
104}
105
106impl<T> InstrumentedConnector<T> {
107 fn new(inner: HyperRustlsHttpsConnector<T>) -> Self {
108 Self { inner }
109 }
110}
111
112impl<T> From<HyperRustlsHttpsConnector<T>> for InstrumentedConnector<T> {
113 fn from(inner: HyperRustlsHttpsConnector<T>) -> Self {
114 Self::new(inner)
115 }
116}
117
118pub struct InstrumentedStream<T> {
119 inner: MaybeHttpsStream<T>,
120 tls_info: Option<TlsHandshakeInfo>,
121}
122
123impl<T: Unpin> Unpin for InstrumentedStream<T> {}
124
125impl<T> fmt::Debug for InstrumentedStream<T> {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("InstrumentedStream")
128 .field("tls_info", &self.tls_info)
129 .finish()
130 }
131}
132
133#[derive(Clone, Debug)]
134pub struct TlsHandshakeInfo {
135 pub protocol_version: Option<String>,
136 pub cipher_suite: Option<String>,
137 pub kea_group_name: Option<String>,
138 pub signature_scheme_name: Option<String>,
139 pub alpn_protocol: Option<String>,
140 pub certificate_chain_der: Vec<Vec<u8>>,
141 pub used_ech: bool,
142}
143
144impl TlsHandshakeInfo {
145 fn from_connection(conn: &ClientConnection) -> Self {
146 let protocol_version = conn.protocol_version().map(protocol_version_to_string);
147 let cipher_suite = conn
148 .negotiated_cipher_suite()
149 .map(|suite| format!("{:?}", suite.suite()));
150 let kea_group_name = conn
151 .negotiated_key_exchange_group()
152 .map(|group| format!("{:?}", group.name()));
153 let certificate_chain_der = conn
154 .peer_certificates()
155 .map(|certs| certs.iter().map(|cert| cert.as_ref().to_vec()).collect())
156 .unwrap_or_default();
157 let alpn_protocol = conn
158 .alpn_protocol()
159 .map(|proto| String::from_utf8_lossy(proto).into_owned());
160 let used_ech = matches!(conn.ech_status(), EchStatus::Accepted);
161
162 Self {
163 protocol_version,
164 cipher_suite,
165 kea_group_name,
166 signature_scheme_name: None,
167 alpn_protocol,
168 certificate_chain_der,
169 used_ech,
170 }
171 }
172}
173
174fn protocol_version_to_string(version: ProtocolVersion) -> String {
175 match version {
176 ProtocolVersion::TLSv1_3 => "TLS 1.3".to_string(),
177 ProtocolVersion::TLSv1_2 => "TLS 1.2".to_string(),
178 ProtocolVersion::TLSv1_1 => "TLS 1.1".to_string(),
179 ProtocolVersion::TLSv1_0 => "TLS 1.0".to_string(),
180 ProtocolVersion::SSLv2 => "SSL 2.0".to_string(),
181 ProtocolVersion::SSLv3 => "SSL 3.0".to_string(),
182 ProtocolVersion::DTLSv1_0 => "DTLS 1.0".to_string(),
183 ProtocolVersion::DTLSv1_2 => "DTLS 1.2".to_string(),
184 ProtocolVersion::DTLSv1_3 => "DTLS 1.3".to_string(),
185 ProtocolVersion::Unknown(v) => format!("Unknown(0x{v:04x})"),
186 _ => format!("{version:?}"),
187 }
188}
189
190impl<T> InstrumentedStream<T>
191where
192 T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
193{
194 fn from_maybe_https_stream(stream: MaybeHttpsStream<T>) -> Self {
195 match stream {
196 MaybeHttpsStream::Http(inner) => Self {
197 inner: MaybeHttpsStream::Http(inner),
198 tls_info: None,
199 },
200 MaybeHttpsStream::Https(tls_stream) => {
201 let (_tcp, tls) = tls_stream.inner().get_ref();
202 let tls_info = TlsHandshakeInfo::from_connection(tls);
203
204 Self {
205 inner: MaybeHttpsStream::Https(tls_stream),
206 tls_info: Some(tls_info),
207 }
208 },
209 }
210 }
211}
212
213impl<T> Connection for InstrumentedStream<T>
214where
215 T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
216{
217 fn connected(&self) -> Connected {
218 let connected = match &self.inner {
219 MaybeHttpsStream::Http(stream) => stream.connected(),
220 MaybeHttpsStream::Https(stream) => {
221 let (tcp, tls) = stream.inner().get_ref();
222 if tls.alpn_protocol() == Some(ALPN_H2.as_bytes()) {
223 tcp.inner().connected().negotiated_h2()
224 } else {
225 tcp.inner().connected()
226 }
227 },
228 };
229 if let Some(info) = &self.tls_info {
230 connected.extra(info.clone())
231 } else {
232 connected
233 }
234 }
235}
236
237impl<T> hyper::rt::Read for InstrumentedStream<T>
238where
239 T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
240{
241 fn poll_read(
242 self: std::pin::Pin<&mut Self>,
243 cx: &mut Context<'_>,
244 buf: hyper::rt::ReadBufCursor<'_>,
245 ) -> Poll<Result<(), io::Error>> {
246 std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
247 }
248}
249
250impl<T> hyper::rt::Write for InstrumentedStream<T>
251where
252 T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
253{
254 fn poll_write(
255 self: std::pin::Pin<&mut Self>,
256 cx: &mut Context<'_>,
257 buf: &[u8],
258 ) -> Poll<Result<usize, io::Error>> {
259 std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
260 }
261
262 fn poll_flush(
263 self: std::pin::Pin<&mut Self>,
264 cx: &mut Context<'_>,
265 ) -> Poll<Result<(), io::Error>> {
266 std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
267 }
268
269 fn poll_shutdown(
270 self: std::pin::Pin<&mut Self>,
271 cx: &mut Context<'_>,
272 ) -> Poll<Result<(), io::Error>> {
273 std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
274 }
275
276 fn is_write_vectored(&self) -> bool {
277 self.inner.is_write_vectored()
278 }
279
280 fn poll_write_vectored(
281 self: std::pin::Pin<&mut Self>,
282 cx: &mut Context<'_>,
283 bufs: &[io::IoSlice<'_>],
284 ) -> Poll<Result<usize, io::Error>> {
285 std::pin::Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
286 }
287}
288
289impl<T> Service<Destination> for InstrumentedConnector<T>
290where
291 T: Service<Destination>,
292 T::Response: Connection + hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
293 T::Future: Send + 'static,
294 T::Error: Into<BoxError>,
295{
296 type Response = InstrumentedStream<T::Response>;
297 type Error = BoxError;
298 type Future = std::pin::Pin<
299 Box<dyn Future<Output = Result<InstrumentedStream<T::Response>, BoxError>> + Send>,
300 >;
301
302 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303 self.inner.poll_ready(cx).map_err(Into::into)
304 }
305
306 fn call(&mut self, dst: Destination) -> Self::Future {
307 let future = self.inner.call(dst);
308 Box::pin(async move {
309 let stream = future.await.map_err(|error| -> BoxError { error })?;
310 Ok(InstrumentedStream::from_maybe_https_stream(stream))
311 })
312 }
313}
314
315pub type Connector = InstrumentedConnector<ServoHttpConnector>;
316pub type TlsConfig = ClientConfig;
317
318#[derive(Clone, Debug, Default)]
319struct CertificateErrorOverrideManagerInternal {
320 certificates_failing_to_verify: HashMap<ServerName<'static>, CertificateDer<'static>>,
323 overrides: Vec<CertificateDer<'static>>,
326}
327
328#[derive(Clone, Debug, Default)]
333pub struct CertificateErrorOverrideManager(Arc<Mutex<CertificateErrorOverrideManagerInternal>>);
334
335impl CertificateErrorOverrideManager {
336 pub fn new() -> Self {
337 Self(Default::default())
338 }
339
340 pub fn add_override(&self, certificate: &CertificateDer<'static>) {
343 self.0.lock().overrides.push(certificate.clone());
344 }
345
346 pub(crate) fn remove_certificate_failing_verification(
350 &self,
351 host: &str,
352 ) -> Option<CertificateDer<'static>> {
353 let server_name = match ServerName::try_from(host) {
354 Ok(name) => name.to_owned(),
355 Err(error) => {
356 warn!("Could not convert host string into RustTLS ServerName: {error:?}");
357 return None;
358 },
359 };
360 self.0
361 .lock()
362 .certificates_failing_to_verify
363 .remove(&server_name)
364 }
365}
366
367#[derive(Clone, Debug, Default)]
368pub enum CACertificates<'de> {
369 #[default]
370 Default,
371 Override(Vec<CertificateDer<'de>>),
372}
373
374#[servo_tracing::instrument(skip_all)]
381pub fn create_tls_config(
382 ca_certificates: CACertificates<'static>,
383 ignore_certificate_errors: bool,
384 override_manager: CertificateErrorOverrideManager,
385) -> TlsConfig {
386 let verifier = CertificateVerificationOverrideVerifier::new(
387 ca_certificates,
388 ignore_certificate_errors,
389 override_manager,
390 );
391 rustls::ClientConfig::builder()
394 .dangerous()
395 .with_custom_certificate_verifier(Arc::new(verifier))
396 .with_no_client_auth()
397}
398
399#[derive(Clone)]
400struct TokioExecutor {}
401
402impl<F> Executor<F> for TokioExecutor
403where
404 F: Future<Output = ()> + 'static + std::marker::Send,
405{
406 fn execute(&self, fut: F) {
407 spawn_task(fut);
408 }
409}
410
411static CRYPTO_PROVIDER_CACHE: LazyLock<Arc<CryptoProvider>> = LazyLock::new(|| {
412 CryptoProvider::get_default()
413 .cloned()
414 .unwrap_or_else(|| {
417 warn!("Default crypto provider not initialized before first access in connector.");
418 Arc::new(aws_lc_rs::default_provider())
419 })
420});
421
422static RUSTLS_PLATFORM_VERIFIER_CACHE: LazyLock<Arc<rustls_platform_verifier::Verifier>> =
427 LazyLock::new(|| {
428 Arc::new(
429 rustls_platform_verifier::Verifier::new(CRYPTO_PROVIDER_CACHE.clone())
430 .expect("Could not initialize platform certificate verifier"),
431 )
432 });
433
434#[inline]
441pub fn prewarm_tls() {
442 #[servo_tracing::instrument]
443 fn prewarm_tls_impl() {
444 let mut sink = [0u8; 32];
445 let _ = CRYPTO_PROVIDER_CACHE.secure_random.fill(&mut sink);
447 }
450
451 if let Err(error) = std::thread::Builder::new()
452 .name("Net-TLS-prewarm".into())
453 .spawn(prewarm_tls_impl)
454 {
455 warn!("Failed to spawn thread to prewarm TLS: {error:?}");
456 }
457}
458
459#[derive(Debug)]
460struct CertificateVerificationOverrideVerifier {
461 main_verifier: Arc<dyn ServerCertVerifier>,
462 ignore_certificate_errors: bool,
463 override_manager: CertificateErrorOverrideManager,
464}
465
466impl CertificateVerificationOverrideVerifier {
467 fn new(
468 ca_certficates: CACertificates<'static>,
469 ignore_certificate_errors: bool,
470 override_manager: CertificateErrorOverrideManager,
471 ) -> Self {
472 let use_webpki_roots = cfg!(target_os = "android") || pref!(network_use_webpki_roots);
481 let main_verifier = if !use_webpki_roots {
482 let verifier = match ca_certficates {
483 CACertificates::Default => RUSTLS_PLATFORM_VERIFIER_CACHE.clone(),
484 CACertificates::Override(_certificates) => {
487 #[cfg(target_os = "android")]
488 unreachable!("Android should always use the WebPKI verifier.");
489 #[cfg(not(target_os = "android"))]
490 {
491 let verifier = rustls_platform_verifier::Verifier::new_with_extra_roots(
492 _certificates,
493 CRYPTO_PROVIDER_CACHE.clone(),
494 )
495 .expect("Could not initialize platform certificate verifier");
496 Arc::new(verifier)
497 }
498 },
499 };
500 verifier as Arc<dyn ServerCertVerifier>
501 } else {
502 let mut root_store =
503 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
504 match ca_certficates {
505 CACertificates::Default => {},
506 CACertificates::Override(certificates) => {
507 for certificate in certificates {
508 if root_store.add(certificate).is_err() {
509 log::error!("Could not add an override certificate.");
510 }
511 }
512 },
513 }
514 rustls::client::WebPkiServerVerifier::builder(root_store.into())
515 .build()
516 .expect("Could not initialize platform certificate verifier.")
517 as Arc<dyn ServerCertVerifier>
518 };
519
520 Self {
521 main_verifier,
522 ignore_certificate_errors,
523 override_manager,
524 }
525 }
526}
527
528impl rustls::client::danger::ServerCertVerifier for CertificateVerificationOverrideVerifier {
529 fn verify_tls12_signature(
530 &self,
531 message: &[u8],
532 cert: &CertificateDer<'_>,
533 dss: &rustls::DigitallySignedStruct,
534 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
535 self.main_verifier
536 .verify_tls12_signature(message, cert, dss)
537 }
538
539 fn verify_tls13_signature(
540 &self,
541 message: &[u8],
542 cert: &CertificateDer<'_>,
543 dss: &rustls::DigitallySignedStruct,
544 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
545 self.main_verifier
546 .verify_tls13_signature(message, cert, dss)
547 }
548
549 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
550 self.main_verifier.supported_verify_schemes()
551 }
552
553 fn verify_server_cert(
554 &self,
555 end_entity: &CertificateDer<'_>,
556 intermediates: &[CertificateDer<'_>],
557 server_name: &ServerName<'_>,
558 ocsp_response: &[u8],
559 now: UnixTime,
560 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
561 let error = match self.main_verifier.verify_server_cert(
562 end_entity,
563 intermediates,
564 server_name,
565 ocsp_response,
566 now,
567 ) {
568 Ok(result) => return Ok(result),
569 Err(error) => error,
570 };
571
572 if self.ignore_certificate_errors {
573 warn!("Ignoring certficate error: {error:?}");
574 return Ok(rustls::client::danger::ServerCertVerified::assertion());
575 }
576
577 for cert_with_exception in &*self.override_manager.0.lock().overrides {
579 if *end_entity == *cert_with_exception {
580 return Ok(rustls::client::danger::ServerCertVerified::assertion());
581 }
582 }
583 self.override_manager
584 .0
585 .lock()
586 .certificates_failing_to_verify
587 .insert(server_name.to_owned(), end_entity.clone().into_owned());
588 Err(error)
589 }
590}
591
592pub type BoxedBody = BoxBody<Bytes, hyper::Error>;
593
/// Errors produced while establishing a connection.
#[derive(Debug)]
pub enum ConnectionError {
    /// Failure from the direct TCP connector; holds the formatted underlying error.
    HttpError(String),
    /// Failure reported by the proxy-aware connector; holds the formatted underlying error.
    ///
    /// This variant is used whether or not a proxy was actually involved.
    ProxyError(String),
}

/// Displays exactly the same text as the `Debug` representation.
impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl std::error::Error for ConnectionError {}
609
610#[derive(Clone)]
611pub struct ProxyConnector {
614 client: ServoHttpConnector,
616 matcher: std::sync::Arc<hyper_util::client::proxy::matcher::Matcher>,
618}
619
620impl ProxyConnector {
621 fn new() -> Self {
622 let matcher_builder = hyper_util::client::proxy::matcher::Matcher::builder()
623 .http(servo_config::pref!(network_http_proxy_uri))
624 .https(servo_config::pref!(network_https_proxy_uri))
625 .no(servo_config::pref!(network_http_no_proxy));
626 ProxyConnector {
627 client: ServoHttpConnector::new(),
628 matcher: std::sync::Arc::new(matcher_builder.build()),
629 }
630 }
631}
632
633impl Service<Destination> for ProxyConnector {
635 type Response = TokioIo<TcpStream>;
636 type Error = ConnectionError;
637 type Future =
638 std::pin::Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ConnectionError>> + Send>>;
639
640 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
641 self.client
642 .poll_ready(cx)
643 .map_err(|e| ConnectionError::ProxyError(format!("{e}")))
644 }
645
646 fn call(&mut self, req: Destination) -> Self::Future {
647 match self.matcher.intercept(&req) {
648 Some(intercept) => Box::pin(
649 Tunnel::new(intercept.uri().clone(), self.client.clone())
650 .call(req)
651 .map_err(|e| ConnectionError::ProxyError(format!("{e}"))),
652 ),
653 None => Box::pin(
654 self.client
655 .call(req)
656 .map_err(|e| ConnectionError::ProxyError(format!("{e}"))),
657 ),
658 }
659 }
660}
661
662pub type ServoClient = Client<InstrumentedConnector<ProxyConnector>, BoxedBody>;
663
664pub fn create_http_client(tls_config: TlsConfig) -> ServoClient {
665 let connector = hyper_rustls::HttpsConnectorBuilder::new()
666 .with_tls_config(tls_config)
667 .https_or_http()
668 .enable_http1()
669 .enable_http2()
670 .wrap_connector(ProxyConnector::new());
671
672 Client::builder(TokioExecutor {})
673 .http1_title_case_headers(true)
674 .build(InstrumentedConnector::from(connector))
675}