net/
connector.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5use std::collections::hash_map::HashMap;
6use std::convert::TryFrom;
7use std::sync::Arc;
8use std::time::Duration;
9use std::{fmt, io};
10
11use futures::task::{Context, Poll};
12use futures::{Future, TryFutureExt};
13use http::uri::{Authority, Uri as Destination};
14use http_body_util::combinators::BoxBody;
15use hyper::body::Bytes;
16use hyper::rt::Executor;
17use hyper_rustls::{HttpsConnector as HyperRustlsHttpsConnector, MaybeHttpsStream};
18use hyper_util::client::legacy::Client;
19use hyper_util::client::legacy::connect::proxy::Tunnel;
20use hyper_util::client::legacy::connect::{
21    Connected, Connection, HttpConnector as HyperHttpConnector,
22};
23use hyper_util::rt::TokioIo;
24use log::warn;
25use parking_lot::Mutex;
26use rustls::client::danger::ServerCertVerifier;
27use rustls::client::{ClientConnection, EchStatus};
28use rustls::crypto::{CryptoProvider, aws_lc_rs};
29use rustls::{ClientConfig, ProtocolVersion};
30use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
31use servo_config::pref;
32use tokio::net::TcpStream;
33use tower::Service;
34
35use crate::async_runtime::spawn_task;
36use crate::hosts::replace_host;
37
/// Size of the buffer exported for callers elsewhere in the crate (32 KiB).
// NOTE(review): not referenced by the visible code of this file; presumably a byte count
// used by readers of response bodies — confirm against callers.
pub const BUF_SIZE: usize = 32768;

/// ALPN identifier for HTTP/2 (RFC 7540 §3.1).
pub const ALPN_H2: &str = "h2";
42
43#[derive(Clone)]
44pub struct ServoHttpConnector {
45    inner: HyperHttpConnector,
46}
47
48impl ServoHttpConnector {
49    fn new() -> ServoHttpConnector {
50        let mut inner = HyperHttpConnector::new();
51        inner.enforce_http(false);
52        inner.set_happy_eyeballs_timeout(None);
53        inner.set_connect_timeout(Some(Duration::from_secs(pref!(network_connection_timeout))));
54        ServoHttpConnector { inner }
55    }
56}
57
58impl Service<Destination> for ServoHttpConnector {
59    type Response = TokioIo<TcpStream>;
60    type Error = ConnectionError;
61    type Future =
62        std::pin::Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ConnectionError>> + Send>>;
63
64    fn call(&mut self, dest: Destination) -> Self::Future {
65        // Perform host replacement when making the actual TCP connection.
66        let mut new_dest = dest.clone();
67        let mut parts = dest.into_parts();
68
69        if let Some(auth) = parts.authority {
70            let host = auth.host();
71            let host = replace_host(host);
72
73            let authority = if let Some(port) = auth.port() {
74                format!("{}:{}", host, port.as_str())
75            } else {
76                (*host).to_string()
77            };
78
79            if let Ok(authority) = Authority::from_maybe_shared(authority) {
80                parts.authority = Some(authority);
81                if let Ok(dest) = Destination::from_parts(parts) {
82                    new_dest = dest
83                }
84            }
85        }
86
87        Box::pin(
88            self.inner
89                .call(new_dest)
90                .map_err(|e| ConnectionError::HttpError(format!("{e}"))),
91        )
92    }
93
94    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95        Ok(()).into()
96    }
97}
98
99type BoxError = Box<dyn std::error::Error + Send + Sync>;
100
101#[derive(Clone)]
102pub struct InstrumentedConnector<T> {
103    inner: HyperRustlsHttpsConnector<T>,
104}
105
106impl<T> InstrumentedConnector<T> {
107    fn new(inner: HyperRustlsHttpsConnector<T>) -> Self {
108        Self { inner }
109    }
110}
111
112impl<T> From<HyperRustlsHttpsConnector<T>> for InstrumentedConnector<T> {
113    fn from(inner: HyperRustlsHttpsConnector<T>) -> Self {
114        Self::new(inner)
115    }
116}
117
118pub struct InstrumentedStream<T> {
119    inner: MaybeHttpsStream<T>,
120    tls_info: Option<TlsHandshakeInfo>,
121}
122
123impl<T: Unpin> Unpin for InstrumentedStream<T> {}
124
125impl<T> fmt::Debug for InstrumentedStream<T> {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_struct("InstrumentedStream")
128            .field("tls_info", &self.tls_info)
129            .finish()
130    }
131}
132
/// Negotiated parameters of a TLS connection.
///
/// Attached to `Connected` as extra connection info.
#[derive(Clone, Debug)]
pub struct TlsHandshakeInfo {
    /// Human-readable protocol version, e.g. `"TLS 1.3"`.
    pub protocol_version: Option<String>,
    /// `Debug` rendering of the negotiated cipher suite.
    pub cipher_suite: Option<String>,
    /// `Debug` rendering of the negotiated key-exchange group name.
    pub kea_group_name: Option<String>,
    /// Signature scheme name; `from_connection` never fills this in (always `None`).
    pub signature_scheme_name: Option<String>,
    /// ALPN protocol selected by the server, decoded lossily as UTF-8.
    pub alpn_protocol: Option<String>,
    /// The peer's certificate chain, one DER-encoded certificate per entry.
    pub certificate_chain_der: Vec<Vec<u8>>,
    /// `true` only when Encrypted Client Hello was accepted.
    pub used_ech: bool,
}
143
144impl TlsHandshakeInfo {
145    fn from_connection(conn: &ClientConnection) -> Self {
146        let protocol_version = conn.protocol_version().map(protocol_version_to_string);
147        let cipher_suite = conn
148            .negotiated_cipher_suite()
149            .map(|suite| format!("{:?}", suite.suite()));
150        let kea_group_name = conn
151            .negotiated_key_exchange_group()
152            .map(|group| format!("{:?}", group.name()));
153        let certificate_chain_der = conn
154            .peer_certificates()
155            .map(|certs| certs.iter().map(|cert| cert.as_ref().to_vec()).collect())
156            .unwrap_or_default();
157        let alpn_protocol = conn
158            .alpn_protocol()
159            .map(|proto| String::from_utf8_lossy(proto).into_owned());
160        let used_ech = matches!(conn.ech_status(), EchStatus::Accepted);
161
162        Self {
163            protocol_version,
164            cipher_suite,
165            kea_group_name,
166            signature_scheme_name: None,
167            alpn_protocol,
168            certificate_chain_der,
169            used_ech,
170        }
171    }
172}
173
174fn protocol_version_to_string(version: ProtocolVersion) -> String {
175    match version {
176        ProtocolVersion::TLSv1_3 => "TLS 1.3".to_string(),
177        ProtocolVersion::TLSv1_2 => "TLS 1.2".to_string(),
178        ProtocolVersion::TLSv1_1 => "TLS 1.1".to_string(),
179        ProtocolVersion::TLSv1_0 => "TLS 1.0".to_string(),
180        ProtocolVersion::SSLv2 => "SSL 2.0".to_string(),
181        ProtocolVersion::SSLv3 => "SSL 3.0".to_string(),
182        ProtocolVersion::DTLSv1_0 => "DTLS 1.0".to_string(),
183        ProtocolVersion::DTLSv1_2 => "DTLS 1.2".to_string(),
184        ProtocolVersion::DTLSv1_3 => "DTLS 1.3".to_string(),
185        ProtocolVersion::Unknown(v) => format!("Unknown(0x{v:04x})"),
186        _ => format!("{version:?}"),
187    }
188}
189
190impl<T> InstrumentedStream<T>
191where
192    T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
193{
194    fn from_maybe_https_stream(stream: MaybeHttpsStream<T>) -> Self {
195        match stream {
196            MaybeHttpsStream::Http(inner) => Self {
197                inner: MaybeHttpsStream::Http(inner),
198                tls_info: None,
199            },
200            MaybeHttpsStream::Https(tls_stream) => {
201                let (_tcp, tls) = tls_stream.inner().get_ref();
202                let tls_info = TlsHandshakeInfo::from_connection(tls);
203
204                Self {
205                    inner: MaybeHttpsStream::Https(tls_stream),
206                    tls_info: Some(tls_info),
207                }
208            },
209        }
210    }
211}
212
213impl<T> Connection for InstrumentedStream<T>
214where
215    T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
216{
217    fn connected(&self) -> Connected {
218        let connected = match &self.inner {
219            MaybeHttpsStream::Http(stream) => stream.connected(),
220            MaybeHttpsStream::Https(stream) => {
221                let (tcp, tls) = stream.inner().get_ref();
222                if tls.alpn_protocol() == Some(ALPN_H2.as_bytes()) {
223                    tcp.inner().connected().negotiated_h2()
224                } else {
225                    tcp.inner().connected()
226                }
227            },
228        };
229        if let Some(info) = &self.tls_info {
230            connected.extra(info.clone())
231        } else {
232            connected
233        }
234    }
235}
236
237impl<T> hyper::rt::Read for InstrumentedStream<T>
238where
239    T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
240{
241    fn poll_read(
242        self: std::pin::Pin<&mut Self>,
243        cx: &mut Context<'_>,
244        buf: hyper::rt::ReadBufCursor<'_>,
245    ) -> Poll<Result<(), io::Error>> {
246        std::pin::Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
247    }
248}
249
250impl<T> hyper::rt::Write for InstrumentedStream<T>
251where
252    T: Connection + hyper::rt::Read + hyper::rt::Write + Unpin,
253{
254    fn poll_write(
255        self: std::pin::Pin<&mut Self>,
256        cx: &mut Context<'_>,
257        buf: &[u8],
258    ) -> Poll<Result<usize, io::Error>> {
259        std::pin::Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
260    }
261
262    fn poll_flush(
263        self: std::pin::Pin<&mut Self>,
264        cx: &mut Context<'_>,
265    ) -> Poll<Result<(), io::Error>> {
266        std::pin::Pin::new(&mut self.get_mut().inner).poll_flush(cx)
267    }
268
269    fn poll_shutdown(
270        self: std::pin::Pin<&mut Self>,
271        cx: &mut Context<'_>,
272    ) -> Poll<Result<(), io::Error>> {
273        std::pin::Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
274    }
275
276    fn is_write_vectored(&self) -> bool {
277        self.inner.is_write_vectored()
278    }
279
280    fn poll_write_vectored(
281        self: std::pin::Pin<&mut Self>,
282        cx: &mut Context<'_>,
283        bufs: &[io::IoSlice<'_>],
284    ) -> Poll<Result<usize, io::Error>> {
285        std::pin::Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
286    }
287}
288
289impl<T> Service<Destination> for InstrumentedConnector<T>
290where
291    T: Service<Destination>,
292    T::Response: Connection + hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
293    T::Future: Send + 'static,
294    T::Error: Into<BoxError>,
295{
296    type Response = InstrumentedStream<T::Response>;
297    type Error = BoxError;
298    type Future = std::pin::Pin<
299        Box<dyn Future<Output = Result<InstrumentedStream<T::Response>, BoxError>> + Send>,
300    >;
301
302    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303        self.inner.poll_ready(cx).map_err(Into::into)
304    }
305
306    fn call(&mut self, dst: Destination) -> Self::Future {
307        let future = self.inner.call(dst);
308        Box::pin(async move {
309            let stream = future.await.map_err(|error| -> BoxError { error })?;
310            Ok(InstrumentedStream::from_maybe_https_stream(stream))
311        })
312    }
313}
314
315pub type Connector = InstrumentedConnector<ServoHttpConnector>;
316pub type TlsConfig = ClientConfig;
317
318#[derive(Clone, Debug, Default)]
319struct CertificateErrorOverrideManagerInternal {
320    /// A mapping of certificates and their hosts, which have seen certificate errors.
321    /// This is used to later create an override in this [CertificateErrorOverrideManager].
322    certificates_failing_to_verify: HashMap<ServerName<'static>, CertificateDer<'static>>,
323    /// A list of certificates that should be accepted despite encountering verification
324    /// errors.
325    overrides: Vec<CertificateDer<'static>>,
326}
327
328/// This data structure is used to track certificate verification errors and overrides.
329/// It tracks:
330///  - A list of [Certificate]s with verification errors mapped by their [ServerName]
331///  - A list of [Certificate]s for which to ignore verification errors.
332#[derive(Clone, Debug, Default)]
333pub struct CertificateErrorOverrideManager(Arc<Mutex<CertificateErrorOverrideManagerInternal>>);
334
335impl CertificateErrorOverrideManager {
336    pub fn new() -> Self {
337        Self(Default::default())
338    }
339
340    /// Add a certificate to this manager's list of certificates for which to ignore
341    /// validation errors.
342    pub fn add_override(&self, certificate: &CertificateDer<'static>) {
343        self.0.lock().overrides.push(certificate.clone());
344    }
345
346    /// Given the a string representation of a sever host name, remove information about
347    /// a [Certificate] with verification errors. If a certificate with
348    /// verification errors was found, return it, otherwise None.
349    pub(crate) fn remove_certificate_failing_verification(
350        &self,
351        host: &str,
352    ) -> Option<CertificateDer<'static>> {
353        let server_name = match ServerName::try_from(host) {
354            Ok(name) => name.to_owned(),
355            Err(error) => {
356                warn!("Could not convert host string into RustTLS ServerName: {error:?}");
357                return None;
358            },
359        };
360        self.0
361            .lock()
362            .certificates_failing_to_verify
363            .remove(&server_name)
364    }
365}
366
367#[derive(Clone, Debug, Default)]
368pub enum CACertificates<'de> {
369    #[default]
370    Default,
371    Override(Vec<CertificateDer<'de>>),
372}
373
374/// Create a [TlsConfig] to use for managing a HTTP connection. This currently creates
375/// a rustls [ClientConfig].
376///
377/// FIXME: The `ignore_certificate_errors` argument ignores all certificate errors. This
378/// is used when running the WPT tests, because rustls currently rejects the WPT certificiate.
379/// See <https://github.com/servo/servo/issues/30080>
380pub fn create_tls_config(
381    ca_certificates: CACertificates<'static>,
382    ignore_certificate_errors: bool,
383    override_manager: CertificateErrorOverrideManager,
384) -> TlsConfig {
385    let verifier = CertificateVerificationOverrideVerifier::new(
386        ca_certificates,
387        ignore_certificate_errors,
388        override_manager,
389    );
390    rustls::ClientConfig::builder()
391        .dangerous()
392        .with_custom_certificate_verifier(Arc::new(verifier))
393        .with_no_client_auth()
394}
395
396#[derive(Clone)]
397struct TokioExecutor {}
398
399impl<F> Executor<F> for TokioExecutor
400where
401    F: Future<Output = ()> + 'static + std::marker::Send,
402{
403    fn execute(&self, fut: F) {
404        spawn_task(fut);
405    }
406}
407
408#[derive(Debug)]
409struct CertificateVerificationOverrideVerifier {
410    main_verifier: Arc<dyn ServerCertVerifier>,
411    ignore_certificate_errors: bool,
412    override_manager: CertificateErrorOverrideManager,
413}
414
415impl CertificateVerificationOverrideVerifier {
416    fn new(
417        ca_certficates: CACertificates<'static>,
418        ignore_certificate_errors: bool,
419        override_manager: CertificateErrorOverrideManager,
420    ) -> Self {
421        // From <https://github.com/rustls/rustls-platform-verifier/blob/main/README.md>:
422        // > Some manual setup is required, outside of cargo, to use this crate on
423        // > Android. In order to use Android's certificate verifier, the crate needs to
424        // > call into the JVM. A small Kotlin component must be included in your app's
425        // > build to support rustls-platform-verifier.
426        //
427        // Since we cannot count on embedders to do this setup, just stick with webpki roots
428        // on Android.
429        let use_webpki_roots = cfg!(target_os = "android") || pref!(network_use_webpki_roots);
430        let main_verifier = if !use_webpki_roots {
431            let crypto_provider = CryptoProvider::get_default()
432                .unwrap_or(&Arc::new(aws_lc_rs::default_provider()))
433                .clone();
434            let verifier = match ca_certficates {
435                CACertificates::Default => rustls_platform_verifier::Verifier::new(crypto_provider),
436                // Android doesn't support `Verifier::new_with_extra_roots`, but currently Android
437                // never uses the platform verifier at all.
438                CACertificates::Override(_certificates) => {
439                    #[cfg(target_os = "android")]
440                    unreachable!("Android should always use the WebPKI verifier.");
441                    #[cfg(not(target_os = "android"))]
442                    rustls_platform_verifier::Verifier::new_with_extra_roots(
443                        _certificates,
444                        crypto_provider,
445                    )
446                },
447            }
448            .expect("Could not initialize platform certificate verifier");
449            Arc::new(verifier) as Arc<dyn ServerCertVerifier>
450        } else {
451            let mut root_store =
452                rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
453            match ca_certficates {
454                CACertificates::Default => {},
455                CACertificates::Override(certificates) => {
456                    for certificate in certificates {
457                        if root_store.add(certificate).is_err() {
458                            log::error!("Could not add an override certificate.");
459                        }
460                    }
461                },
462            }
463            rustls::client::WebPkiServerVerifier::builder(root_store.into())
464                .build()
465                .expect("Could not initialize platform certificate verifier.")
466                as Arc<dyn ServerCertVerifier>
467        };
468
469        Self {
470            main_verifier,
471            ignore_certificate_errors,
472            override_manager,
473        }
474    }
475}
476
477impl rustls::client::danger::ServerCertVerifier for CertificateVerificationOverrideVerifier {
478    fn verify_tls12_signature(
479        &self,
480        message: &[u8],
481        cert: &CertificateDer<'_>,
482        dss: &rustls::DigitallySignedStruct,
483    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
484        self.main_verifier
485            .verify_tls12_signature(message, cert, dss)
486    }
487
488    fn verify_tls13_signature(
489        &self,
490        message: &[u8],
491        cert: &CertificateDer<'_>,
492        dss: &rustls::DigitallySignedStruct,
493    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
494        self.main_verifier
495            .verify_tls13_signature(message, cert, dss)
496    }
497
498    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
499        self.main_verifier.supported_verify_schemes()
500    }
501
502    fn verify_server_cert(
503        &self,
504        end_entity: &CertificateDer<'_>,
505        intermediates: &[CertificateDer<'_>],
506        server_name: &ServerName<'_>,
507        ocsp_response: &[u8],
508        now: UnixTime,
509    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
510        let error = match self.main_verifier.verify_server_cert(
511            end_entity,
512            intermediates,
513            server_name,
514            ocsp_response,
515            now,
516        ) {
517            Ok(result) => return Ok(result),
518            Err(error) => error,
519        };
520
521        if self.ignore_certificate_errors {
522            warn!("Ignoring certficate error: {error:?}");
523            return Ok(rustls::client::danger::ServerCertVerified::assertion());
524        }
525
526        // If there's an override for this certificate, just accept it.
527        for cert_with_exception in &*self.override_manager.0.lock().overrides {
528            if *end_entity == *cert_with_exception {
529                return Ok(rustls::client::danger::ServerCertVerified::assertion());
530            }
531        }
532        self.override_manager
533            .0
534            .lock()
535            .certificates_failing_to_verify
536            .insert(server_name.to_owned(), end_entity.clone().into_owned());
537        Err(error)
538    }
539}
540
541pub type BoxedBody = BoxBody<Bytes, hyper::Error>;
542
/// Error type of [`ServoHttpConnector`] and [`ProxyConnector`].
///
/// Both variants carry the rendered message of the underlying error.
#[derive(Debug)]
pub enum ConnectionError {
    /// A connection error from [`ServoHttpConnector`].
    HttpError(String),
    /// A connection error reported by [`ProxyConnector`].
    // It looks like the underlying proxy error type is currently not exported, so only its
    // message is kept.
    ProxyError(String),
}

impl std::fmt::Display for ConnectionError {
    /// Uses the `Debug` representation, e.g. `HttpError("...")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for ConnectionError {}
558
559#[derive(Clone)]
560/// A proxy connector. This will automatically open a proxy connection if the uri matches the proxy uri.
561/// Also respects 'no_proxy'.
562pub struct ProxyConnector {
563    /// A client without proxy for `no_proxy` matches.
564    client: ServoHttpConnector,
565    /// Matcher to see if we should forward to the proxy or not.
566    matcher: std::sync::Arc<hyper_util::client::proxy::matcher::Matcher>,
567}
568
569impl ProxyConnector {
570    fn new() -> Self {
571        let matcher_builder = hyper_util::client::proxy::matcher::Matcher::builder()
572            .http(servo_config::pref!(network_http_proxy_uri))
573            .https(servo_config::pref!(network_https_proxy_uri))
574            .no(servo_config::pref!(network_http_no_proxy));
575        ProxyConnector {
576            client: ServoHttpConnector::new(),
577            matcher: std::sync::Arc::new(matcher_builder.build()),
578        }
579    }
580}
581
582// Just forward everything to the inner type except that we modify the errors returned.
583impl Service<Destination> for ProxyConnector {
584    type Response = TokioIo<TcpStream>;
585    type Error = ConnectionError;
586    type Future =
587        std::pin::Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ConnectionError>> + Send>>;
588
589    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
590        self.client
591            .poll_ready(cx)
592            .map_err(|e| ConnectionError::ProxyError(format!("{e}")))
593    }
594
595    fn call(&mut self, req: Destination) -> Self::Future {
596        match self.matcher.intercept(&req) {
597            Some(intercept) => Box::pin(
598                Tunnel::new(intercept.uri().clone(), self.client.clone())
599                    .call(req)
600                    .map_err(|e| ConnectionError::ProxyError(format!("{e}"))),
601            ),
602            None => Box::pin(
603                self.client
604                    .call(req)
605                    .map_err(|e| ConnectionError::ProxyError(format!("{e}"))),
606            ),
607        }
608    }
609}
610
611pub type ServoClient = Client<InstrumentedConnector<ProxyConnector>, BoxedBody>;
612
613pub fn create_http_client(tls_config: TlsConfig) -> ServoClient {
614    let connector = hyper_rustls::HttpsConnectorBuilder::new()
615        .with_tls_config(tls_config)
616        .https_or_http()
617        .enable_http1()
618        .enable_http2()
619        .wrap_connector(ProxyConnector::new());
620
621    Client::builder(TokioExecutor {})
622        .http1_title_case_headers(true)
623        .build(InstrumentedConnector::from(connector))
624}