hyper_rustls/connector/
builder.rs1use std::sync::Arc;
2
3use hyper_util::client::legacy::connect::HttpConnector;
4#[cfg(any(
5    feature = "rustls-native-certs",
6    feature = "rustls-platform-verifier",
7    feature = "webpki-roots"
8))]
9use rustls::crypto::CryptoProvider;
10use rustls::ClientConfig;
11
12use super::{DefaultServerNameResolver, HttpsConnector, ResolveServerName};
13#[cfg(any(
14    feature = "rustls-native-certs",
15    feature = "webpki-roots",
16    feature = "rustls-platform-verifier"
17))]
18use crate::config::ConfigBuilderExt;
19use pki_types::ServerName;
20
/// Type-state builder for [`HttpsConnector`].
///
/// The `State` parameter records which configuration step comes next
/// (`WantsTlsConfig` -> `WantsSchemes` -> `WantsProtocols1` -> ...), so each
/// builder method is only callable at the step it belongs to.
pub struct ConnectorBuilder<State>(State);
42
/// Initial builder state: a TLS client configuration must be chosen next.
///
/// The private unit field keeps this state constructible only from within
/// the defining module.
pub struct WantsTlsConfig(());
45
46impl ConnectorBuilder<WantsTlsConfig> {
47    pub fn new() -> Self {
49        Self(WantsTlsConfig(()))
50    }
51
52    pub fn with_tls_config(self, config: ClientConfig) -> ConnectorBuilder<WantsSchemes> {
61        assert!(
62            config.alpn_protocols.is_empty(),
63            "ALPN protocols should not be pre-defined"
64        );
65        ConnectorBuilder(WantsSchemes { tls_config: config })
66    }
67
68    #[cfg(all(
73        any(feature = "ring", feature = "aws-lc-rs"),
74        feature = "rustls-platform-verifier"
75    ))]
76    pub fn with_platform_verifier(self) -> ConnectorBuilder<WantsSchemes> {
77        self.try_with_platform_verifier()
78            .expect("failure to initialize platform verifier")
79    }
80
81    #[cfg(all(
86        any(feature = "ring", feature = "aws-lc-rs"),
87        feature = "rustls-platform-verifier"
88    ))]
89    pub fn try_with_platform_verifier(
90        self,
91    ) -> Result<ConnectorBuilder<WantsSchemes>, rustls::Error> {
92        Ok(self.with_tls_config(
93            ClientConfig::builder()
94                .try_with_platform_verifier()?
95                .with_no_client_auth(),
96        ))
97    }
98
99    #[cfg(feature = "rustls-platform-verifier")]
103    pub fn with_provider_and_platform_verifier(
104        self,
105        provider: impl Into<Arc<CryptoProvider>>,
106    ) -> std::io::Result<ConnectorBuilder<WantsSchemes>> {
107        Ok(self.with_tls_config(
108            ClientConfig::builder_with_provider(provider.into())
109                .with_safe_default_protocol_versions()
110                .and_then(|builder| builder.try_with_platform_verifier())
111                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
112                .with_no_client_auth(),
113        ))
114    }
115
116    #[cfg(all(
121        any(feature = "ring", feature = "aws-lc-rs"),
122        feature = "rustls-native-certs"
123    ))]
124    pub fn with_native_roots(self) -> std::io::Result<ConnectorBuilder<WantsSchemes>> {
125        Ok(self.with_tls_config(
126            ClientConfig::builder()
127                .with_native_roots()?
128                .with_no_client_auth(),
129        ))
130    }
131
132    #[cfg(feature = "rustls-native-certs")]
136    pub fn with_provider_and_native_roots(
137        self,
138        provider: impl Into<Arc<CryptoProvider>>,
139    ) -> std::io::Result<ConnectorBuilder<WantsSchemes>> {
140        Ok(self.with_tls_config(
141            ClientConfig::builder_with_provider(provider.into())
142                .with_safe_default_protocol_versions()
143                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
144                .with_native_roots()?
145                .with_no_client_auth(),
146        ))
147    }
148
149    #[cfg(all(any(feature = "ring", feature = "aws-lc-rs"), feature = "webpki-roots"))]
154    pub fn with_webpki_roots(self) -> ConnectorBuilder<WantsSchemes> {
155        self.with_tls_config(
156            ClientConfig::builder()
157                .with_webpki_roots()
158                .with_no_client_auth(),
159        )
160    }
161
162    #[cfg(feature = "webpki-roots")]
167    pub fn with_provider_and_webpki_roots(
168        self,
169        provider: impl Into<Arc<CryptoProvider>>,
170    ) -> Result<ConnectorBuilder<WantsSchemes>, rustls::Error> {
171        Ok(self.with_tls_config(
172            ClientConfig::builder_with_provider(provider.into())
173                .with_safe_default_protocol_versions()?
174                .with_webpki_roots()
175                .with_no_client_auth(),
176        ))
177    }
178}
179
180impl Default for ConnectorBuilder<WantsTlsConfig> {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
186pub struct WantsSchemes {
189    tls_config: ClientConfig,
190}
191
192impl ConnectorBuilder<WantsSchemes> {
193    pub fn https_only(self) -> ConnectorBuilder<WantsProtocols1> {
197        ConnectorBuilder(WantsProtocols1 {
198            tls_config: self.0.tls_config,
199            https_only: true,
200            server_name_resolver: None,
201        })
202    }
203
204    pub fn https_or_http(self) -> ConnectorBuilder<WantsProtocols1> {
209        ConnectorBuilder(WantsProtocols1 {
210            tls_config: self.0.tls_config,
211            https_only: false,
212            server_name_resolver: None,
213        })
214    }
215}
216
217pub struct WantsProtocols1 {
222    tls_config: ClientConfig,
223    https_only: bool,
224    server_name_resolver: Option<Arc<dyn ResolveServerName + Sync + Send>>,
225}
226
227impl WantsProtocols1 {
228    fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
229        HttpsConnector {
230            force_https: self.https_only,
231            http: conn,
232            tls_config: std::sync::Arc::new(self.tls_config),
233            server_name_resolver: self
234                .server_name_resolver
235                .unwrap_or_else(|| Arc::new(DefaultServerNameResolver::default())),
236        }
237    }
238
239    fn build(self) -> HttpsConnector<HttpConnector> {
240        let mut http = HttpConnector::new();
241        http.enforce_http(false);
243        self.wrap_connector(http)
244    }
245}
246
247impl ConnectorBuilder<WantsProtocols1> {
248    #[cfg(feature = "http1")]
252    pub fn enable_http1(self) -> ConnectorBuilder<WantsProtocols2> {
253        ConnectorBuilder(WantsProtocols2 { inner: self.0 })
254    }
255
256    #[cfg(feature = "http2")]
260    pub fn enable_http2(mut self) -> ConnectorBuilder<WantsProtocols3> {
261        self.0.tls_config.alpn_protocols = vec![b"h2".to_vec()];
262        ConnectorBuilder(WantsProtocols3 {
263            inner: self.0,
264            enable_http1: false,
265        })
266    }
267
268    #[cfg(feature = "http2")]
273    pub fn enable_all_versions(mut self) -> ConnectorBuilder<WantsProtocols3> {
274        #[cfg(feature = "http1")]
275        let alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
276        #[cfg(not(feature = "http1"))]
277        let alpn_protocols = vec![b"h2".to_vec()];
278
279        self.0.tls_config.alpn_protocols = alpn_protocols;
280        ConnectorBuilder(WantsProtocols3 {
281            inner: self.0,
282            enable_http1: cfg!(feature = "http1"),
283        })
284    }
285
286    pub fn with_server_name_resolver(
295        mut self,
296        resolver: impl ResolveServerName + 'static + Sync + Send,
297    ) -> Self {
298        self.0.server_name_resolver = Some(Arc::new(resolver));
299        self
300    }
301
302    #[deprecated(
312        since = "0.27.1",
313        note = "use Self::with_server_name_resolver with FixedServerNameResolver instead"
314    )]
315    pub fn with_server_name(self, mut override_server_name: String) -> Self {
316        if let Some(trimmed) = override_server_name
318            .strip_prefix('[')
319            .and_then(|s| s.strip_suffix(']'))
320        {
321            override_server_name = trimmed.to_string();
322        }
323
324        self.with_server_name_resolver(move |_: &_| {
325            ServerName::try_from(override_server_name.clone())
326        })
327    }
328}
329
330pub struct WantsProtocols2 {
337    inner: WantsProtocols1,
338}
339
340impl ConnectorBuilder<WantsProtocols2> {
341    #[cfg(feature = "http2")]
345    pub fn enable_http2(mut self) -> ConnectorBuilder<WantsProtocols3> {
346        self.0.inner.tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
347        ConnectorBuilder(WantsProtocols3 {
348            inner: self.0.inner,
349            enable_http1: true,
350        })
351    }
352
353    pub fn build(self) -> HttpsConnector<HttpConnector> {
355        self.0.inner.build()
356    }
357
358    pub fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
360        self.0.inner.wrap_connector(conn)
365    }
366}
367
/// Builder state with HTTP/2 enabled (and possibly HTTP/1.1); only building
/// remains.
#[cfg(feature = "http2")]
pub struct WantsProtocols3 {
    /// Settings accumulated by the earlier builder steps.
    inner: WantsProtocols1,
    /// Whether HTTP/1.1 was also enabled. The builder steps write this field,
    /// but nothing in this module reads it back, hence the lint allowance.
    #[allow(dead_code)]
    enable_http1: bool,
}
380
#[cfg(feature = "http2")]
impl ConnectorBuilder<WantsProtocols3> {
    /// Builds an [`HttpsConnector`] over a new [`HttpConnector`].
    pub fn build(self) -> HttpsConnector<HttpConnector> {
        let ConnectorBuilder(WantsProtocols3 { inner, .. }) = self;
        inner.build()
    }

    /// Wraps the caller's own connector `conn` in an [`HttpsConnector`].
    pub fn wrap_connector<H>(self, conn: H) -> HttpsConnector<H> {
        let ConnectorBuilder(WantsProtocols3 { inner, .. }) = self;
        inner.wrap_connector(conn)
    }
}
396
#[cfg(test)]
mod tests {
    // Every test here builds a rustls `ClientConfig`, which needs a crypto
    // provider. Without `ring` or `aws-lc-rs` there is none to install, and
    // `with_webpki_roots` is not even compiled, so each test is gated on one
    // of those features.

    #[test]
    #[cfg(all(
        any(feature = "ring", feature = "aws-lc-rs"),
        feature = "webpki-roots",
        feature = "http1"
    ))]
    fn test_builder() {
        ensure_global_state();
        let _connector = super::ConnectorBuilder::new()
            .with_webpki_roots()
            .https_only()
            .enable_http1()
            .build();
    }

    #[test]
    #[cfg(all(any(feature = "ring", feature = "aws-lc-rs"), feature = "http1"))]
    #[should_panic(expected = "ALPN protocols should not be pre-defined")]
    fn test_reject_predefined_alpn() {
        ensure_global_state();
        let roots = rustls::RootCertStore::empty();
        let mut config_with_alpn = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        config_with_alpn.alpn_protocols = vec![b"fancyprotocol".to_vec()];
        let _connector = super::ConnectorBuilder::new()
            .with_tls_config(config_with_alpn)
            .https_only()
            .enable_http1()
            .build();
    }

    #[test]
    #[cfg(all(
        any(feature = "ring", feature = "aws-lc-rs"),
        feature = "http1",
        feature = "http2"
    ))]
    fn test_alpn() {
        ensure_global_state();
        let roots = rustls::RootCertStore::empty();
        let tls_config = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http1()
            .build();
        assert!(connector
            .tls_config
            .alpn_protocols
            .is_empty());
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http2()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http1()
            .enable_http2()
            .build();
        assert_eq!(
            &connector.tls_config.alpn_protocols,
            &[b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config)
            .https_only()
            .enable_all_versions()
            .build();
        assert_eq!(
            &connector.tls_config.alpn_protocols,
            &[b"h2".to_vec(), b"http/1.1".to_vec()]
        );
    }

    #[test]
    #[cfg(all(
        any(feature = "ring", feature = "aws-lc-rs"),
        not(feature = "http1"),
        feature = "http2"
    ))]
    fn test_alpn_http2() {
        // Like the other tests, install a provider first so `builder()` works
        // even when both crypto backends are compiled in.
        ensure_global_state();
        let roots = rustls::RootCertStore::empty();
        // `ClientConfig::builder()` already applies rustls' safe defaults; the
        // old `with_safe_defaults()` step no longer exists in rustls 0.22+.
        let tls_config = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config.clone())
            .https_only()
            .enable_http2()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);
        let connector = super::ConnectorBuilder::new()
            .with_tls_config(tls_config)
            .https_only()
            .enable_all_versions()
            .build();
        assert_eq!(&connector.tls_config.alpn_protocols, &[b"h2".to_vec()]);
    }

    /// Installs a process-wide default crypto provider from whichever backend
    /// is compiled in.
    ///
    /// `install_default` errors (for example, a provider was already installed
    /// by an earlier test or by the other backend) are deliberately ignored.
    #[cfg(any(feature = "ring", feature = "aws-lc-rs"))]
    fn ensure_global_state() {
        #[cfg(feature = "ring")]
        let _ = rustls::crypto::ring::default_provider().install_default();
        #[cfg(feature = "aws-lc-rs")]
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    }
}