1use std::collections::hash_map::HashMap;
6use std::convert::TryFrom;
7use std::sync::{Arc, Mutex};
8
9use futures::Future;
10use futures::task::{Context, Poll};
11use http::uri::{Authority, Uri as Destination};
12use http_body_util::combinators::BoxBody;
13use hyper::body::Bytes;
14use hyper::rt::Executor;
15use hyper_rustls::HttpsConnector as HyperRustlsHttpsConnector;
16use hyper_util::client::legacy::Client;
17use hyper_util::client::legacy::connect::HttpConnector as HyperHttpConnector;
18use log::warn;
19use rustls::client::WebPkiServerVerifier;
20use rustls::{ClientConfig, RootCertStore};
21use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
22use tower_service::Service;
23
24use crate::async_runtime::spawn_task;
25use crate::hosts::replace_host;
26
/// Buffer size in bytes (32 KiB). Not referenced in this part of the file;
/// see its users elsewhere in the crate for the exact purpose.
pub const BUF_SIZE: usize = 32 * 1024;
28
/// Plain-TCP connector wrapped by the hyper-rustls HTTPS connector (see the
/// `Connector` alias and `create_http_client`).
///
/// Before each connection attempt its `Service` impl rewrites the destination
/// host through `replace_host`, then delegates to the wrapped connector.
#[derive(Clone)]
pub struct ServoHttpConnector {
    /// The wrapped hyper-util connector that actually opens connections.
    inner: HyperHttpConnector,
}
33
34impl ServoHttpConnector {
35 fn new() -> ServoHttpConnector {
36 let mut inner = HyperHttpConnector::new();
37 inner.enforce_http(false);
38 inner.set_happy_eyeballs_timeout(None);
39 ServoHttpConnector { inner }
40 }
41}
42
43impl Service<Destination> for ServoHttpConnector {
44 type Response = <HyperHttpConnector as Service<Destination>>::Response;
45 type Error = <HyperHttpConnector as Service<Destination>>::Error;
46 type Future = <HyperHttpConnector as Service<Destination>>::Future;
47
48 fn call(&mut self, dest: Destination) -> Self::Future {
49 let mut new_dest = dest.clone();
51 let mut parts = dest.into_parts();
52
53 if let Some(auth) = parts.authority {
54 let host = auth.host();
55 let host = replace_host(host);
56
57 let authority = if let Some(port) = auth.port() {
58 format!("{}:{}", host, port.as_str())
59 } else {
60 (*host).to_string()
61 };
62
63 if let Ok(authority) = Authority::from_maybe_shared(authority) {
64 parts.authority = Some(authority);
65 if let Ok(dest) = Destination::from_parts(parts) {
66 new_dest = dest
67 }
68 }
69 }
70
71 self.inner.call(new_dest)
72 }
73
74 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75 Ok(()).into()
76 }
77}
78
/// Connector stack used by the HTTP client: hyper-rustls TLS layered over
/// [`ServoHttpConnector`].
pub type Connector = HyperRustlsHttpsConnector<ServoHttpConnector>;
/// rustls client configuration type, as produced by `create_tls_config`.
pub type TlsConfig = ClientConfig;
81
/// Mutable state shared behind a [`CertificateErrorOverrideManager`].
#[derive(Clone, Debug, Default)]
struct CertificateErrorOverrideManagerInternal {
    /// The end-entity certificate that most recently failed verification,
    /// keyed by the server name it was presented for. Written by the TLS
    /// verifier on failure; drained by `remove_certificate_failing_verification`.
    certificates_failing_to_verify: HashMap<ServerName<'static>, CertificateDer<'static>>,
    /// End-entity certificates registered via `add_override`. A server that
    /// presents an exact match is accepted even if WebPKI verification fails.
    overrides: Vec<CertificateDer<'static>>,
}
91
/// Shareable handle to the certificate-error override state.
///
/// The state lives in an `Arc<Mutex<..>>`, so clones of this handle all see
/// and modify the same overrides and recorded failures.
#[derive(Clone, Debug, Default)]
pub struct CertificateErrorOverrideManager(Arc<Mutex<CertificateErrorOverrideManagerInternal>>);
98
99impl CertificateErrorOverrideManager {
100 pub fn new() -> Self {
101 Self(Default::default())
102 }
103
104 pub fn add_override(&self, certificate: &CertificateDer<'static>) {
107 self.0.lock().unwrap().overrides.push(certificate.clone());
108 }
109
110 pub(crate) fn remove_certificate_failing_verification(
114 &self,
115 host: &str,
116 ) -> Option<CertificateDer<'static>> {
117 let server_name = match ServerName::try_from(host) {
118 Ok(name) => name.to_owned(),
119 Err(error) => {
120 warn!("Could not convert host string into RustTLS ServerName: {error:?}");
121 return None;
122 },
123 };
124 self.0
125 .lock()
126 .unwrap()
127 .certificates_failing_to_verify
128 .remove(&server_name)
129 }
130}
131
/// Selects the trust anchors used by the TLS certificate verifier.
#[derive(Clone, Debug)]
pub enum CACertificates {
    /// The roots bundled by the `webpki_roots` crate (`TLS_SERVER_ROOTS`).
    Default,
    /// A caller-supplied root store, used instead of the bundled roots.
    Override(RootCertStore),
}
137
138pub fn create_tls_config(
145 ca_certificates: CACertificates,
146 ignore_certificate_errors: bool,
147 override_manager: CertificateErrorOverrideManager,
148) -> TlsConfig {
149 let verifier = CertificateVerificationOverrideVerifier::new(
150 ca_certificates,
151 ignore_certificate_errors,
152 override_manager,
153 );
154 rustls::ClientConfig::builder()
155 .dangerous()
156 .with_custom_certificate_verifier(Arc::new(verifier))
157 .with_no_client_auth()
158}
159
160#[derive(Clone)]
161struct TokioExecutor {}
162
163impl<F> Executor<F> for TokioExecutor
164where
165 F: Future<Output = ()> + 'static + std::marker::Send,
166{
167 fn execute(&self, fut: F) {
168 spawn_task(fut);
169 }
170}
171
/// rustls server-certificate verifier that defers to WebPKI. It can still
/// accept certificates that fail verification, either globally (via
/// `ignore_certificate_errors`) or per certificate (via `override_manager`).
#[derive(Debug)]
struct CertificateVerificationOverrideVerifier {
    /// Standard WebPKI verifier; chain and signature checks are delegated to it.
    webpki_verifier: Arc<WebPkiServerVerifier>,
    /// When true, every chain-verification failure is logged and then accepted.
    ignore_certificate_errors: bool,
    /// Source of per-certificate overrides, and where certificates that fail
    /// verification are recorded.
    override_manager: CertificateErrorOverrideManager,
}
178
179impl CertificateVerificationOverrideVerifier {
180 fn new(
181 ca_certficates: CACertificates,
182 ignore_certificate_errors: bool,
183 override_manager: CertificateErrorOverrideManager,
184 ) -> Self {
185 let root_cert_store = match ca_certficates {
186 CACertificates::Default => rustls::RootCertStore {
187 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
188 },
189 CACertificates::Override(root_cert_store) => root_cert_store,
190 };
191
192 Self {
193 webpki_verifier: WebPkiServerVerifier::builder(root_cert_store.into())
196 .build()
197 .unwrap(),
198 ignore_certificate_errors,
199 override_manager,
200 }
201 }
202}
203
204impl rustls::client::danger::ServerCertVerifier for CertificateVerificationOverrideVerifier {
205 fn verify_tls12_signature(
206 &self,
207 message: &[u8],
208 cert: &CertificateDer<'_>,
209 dss: &rustls::DigitallySignedStruct,
210 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
211 self.webpki_verifier
212 .verify_tls12_signature(message, cert, dss)
213 }
214
215 fn verify_tls13_signature(
216 &self,
217 message: &[u8],
218 cert: &CertificateDer<'_>,
219 dss: &rustls::DigitallySignedStruct,
220 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
221 self.webpki_verifier
222 .verify_tls13_signature(message, cert, dss)
223 }
224
225 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
226 self.webpki_verifier.supported_verify_schemes()
227 }
228
229 fn verify_server_cert(
230 &self,
231 end_entity: &CertificateDer<'_>,
232 intermediates: &[CertificateDer<'_>],
233 server_name: &ServerName<'_>,
234 ocsp_response: &[u8],
235 now: UnixTime,
236 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
237 let error = match self.webpki_verifier.verify_server_cert(
238 end_entity,
239 intermediates,
240 server_name,
241 ocsp_response,
242 now,
243 ) {
244 Ok(result) => return Ok(result),
245 Err(error) => error,
246 };
247
248 if self.ignore_certificate_errors {
249 warn!("Ignoring certficate error: {error:?}");
250 return Ok(rustls::client::danger::ServerCertVerified::assertion());
251 }
252
253 for cert_with_exception in &*self.override_manager.0.lock().unwrap().overrides {
255 if *end_entity == *cert_with_exception {
256 return Ok(rustls::client::danger::ServerCertVerified::assertion());
257 }
258 }
259 self.override_manager
260 .0
261 .lock()
262 .unwrap()
263 .certificates_failing_to_verify
264 .insert(server_name.to_owned(), end_entity.clone().into_owned());
265 Err(error)
266 }
267}
268
/// Boxed body type (`Bytes` chunks, `hyper::Error` errors) used as the request
/// body parameter of the client built by `create_http_client`.
pub type BoxedBody = BoxBody<Bytes, hyper::Error>;
270
271pub fn create_http_client(tls_config: TlsConfig) -> Client<Connector, BoxedBody> {
272 let connector = hyper_rustls::HttpsConnectorBuilder::new()
273 .with_tls_config(tls_config)
274 .https_or_http()
275 .enable_http1()
276 .enable_http2()
277 .wrap_connector(ServoHttpConnector::new());
278
279 Client::builder(TokioExecutor {})
280 .http1_title_case_headers(true)
281 .build(connector)
282}