// hyper_rustls/connector.rs

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{fmt, io};

use http::Uri;
use hyper::rt;
use hyper_util::client::legacy::connect::Connection;
use hyper_util::rt::TokioIo;
use rustls::pki_types::ServerName;
use tokio_rustls::TlsConnector;
use tower_service::Service;

use crate::stream::MaybeHttpsStream;

pub(crate) mod builder;

/// Type-erased, thread-safe error used as the connector's `Service::Error`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;

21#[derive(Clone)]
23pub struct HttpsConnector<T> {
24 force_https: bool,
25 http: T,
26 tls_config: Arc<rustls::ClientConfig>,
27 server_name_resolver: Arc<dyn ResolveServerName + Sync + Send>,
28}
29
30impl<T> HttpsConnector<T> {
31 pub fn builder() -> builder::ConnectorBuilder<builder::WantsTlsConfig> {
35 builder::ConnectorBuilder::new()
36 }
37
38 pub fn new(
42 http: T,
43 tls_config: impl Into<Arc<rustls::ClientConfig>>,
44 force_https: bool,
45 server_name_resolver: Arc<dyn ResolveServerName + Send + Sync>,
46 ) -> Self {
47 Self {
48 http,
49 tls_config: tls_config.into(),
50 force_https,
51 server_name_resolver,
52 }
53 }
54
55 pub fn enforce_https(&mut self) {
59 self.force_https = true;
60 }
61}
62
63impl<T> Service<Uri> for HttpsConnector<T>
64where
65 T: Service<Uri>,
66 T::Response: Connection + rt::Read + rt::Write + Send + Unpin + 'static,
67 T::Future: Send + 'static,
68 T::Error: Into<BoxError>,
69{
70 type Response = MaybeHttpsStream<T::Response>;
71 type Error = BoxError;
72
73 #[allow(clippy::type_complexity)]
74 type Future =
75 Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T::Response>, BoxError>> + Send>>;
76
77 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78 match self.http.poll_ready(cx) {
79 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
80 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
81 Poll::Pending => Poll::Pending,
82 }
83 }
84
85 fn call(&mut self, dst: Uri) -> Self::Future {
86 match dst.scheme() {
89 Some(scheme) if scheme == &http::uri::Scheme::HTTP && !self.force_https => {
90 let future = self.http.call(dst);
91 return Box::pin(async move {
92 Ok(MaybeHttpsStream::Http(future.await.map_err(Into::into)?))
93 });
94 }
95 Some(scheme) if scheme != &http::uri::Scheme::HTTPS => {
96 let message = format!("unsupported scheme {scheme}");
97 return Box::pin(async move { Err(io::Error::other(message).into()) });
98 }
99 Some(_) => {}
100 None => return Box::pin(async move { Err(io::Error::other("missing scheme").into()) }),
101 };
102
103 let cfg = self.tls_config.clone();
104 let hostname = match self.server_name_resolver.resolve(&dst) {
105 Ok(hostname) => hostname,
106 Err(e) => {
107 return Box::pin(async move { Err(e) });
108 }
109 };
110
111 let connecting_future = self.http.call(dst);
112 Box::pin(async move {
113 let tcp = connecting_future
114 .await
115 .map_err(Into::into)?;
116 Ok(MaybeHttpsStream::Https(TokioIo::new(
117 TlsConnector::from(cfg)
118 .connect(hostname, TokioIo::new(tcp))
119 .await
120 .map_err(io::Error::other)?,
121 )))
122 })
123 }
124}
125
126impl<H, C> From<(H, C)> for HttpsConnector<H>
127where
128 C: Into<Arc<rustls::ClientConfig>>,
129{
130 fn from((http, cfg): (H, C)) -> Self {
131 Self {
132 force_https: false,
133 http,
134 tls_config: cfg.into(),
135 server_name_resolver: Arc::new(DefaultServerNameResolver::default()),
136 }
137 }
138}
139
140impl<T> fmt::Debug for HttpsConnector<T> {
141 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142 f.debug_struct("HttpsConnector")
143 .field("force_https", &self.force_https)
144 .finish()
145 }
146}
147
148#[derive(Default)]
150pub struct DefaultServerNameResolver(());
151
152impl ResolveServerName for DefaultServerNameResolver {
153 fn resolve(
154 &self,
155 uri: &Uri,
156 ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
157 let mut hostname = uri.host().unwrap_or_default();
158
159 if let Some(trimmed) = hostname
161 .strip_prefix('[')
162 .and_then(|h| h.strip_suffix(']'))
163 {
164 hostname = trimmed;
165 }
166
167 ServerName::try_from(hostname.to_string()).map_err(|e| Box::new(e) as _)
168 }
169}
170
171pub struct FixedServerNameResolver {
173 name: ServerName<'static>,
174}
175
176impl FixedServerNameResolver {
177 pub fn new(name: ServerName<'static>) -> Self {
179 Self { name }
180 }
181}
182
183impl ResolveServerName for FixedServerNameResolver {
184 fn resolve(
185 &self,
186 _: &Uri,
187 ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
188 Ok(self.name.clone())
189 }
190}
191
192impl<F, E> ResolveServerName for F
193where
194 F: Fn(&Uri) -> Result<ServerName<'static>, E>,
195 E: Into<Box<dyn std::error::Error + Sync + Send>>,
196{
197 fn resolve(
198 &self,
199 uri: &Uri,
200 ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> {
201 self(uri).map_err(Into::into)
202 }
203}
204
205pub trait ResolveServerName {
207 fn resolve(
209 &self,
210 uri: &Uri,
211 ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>>;
212}
213
#[cfg(all(
    test,
    any(feature = "ring", feature = "aws-lc-rs"),
    any(
        feature = "rustls-native-certs",
        feature = "webpki-roots",
        feature = "rustls-platform-verifier",
    )
))]
mod tests {
    // These tests dial `google.com` over the real network, so they need outbound connectivity.
    use std::future::poll_fn;

    use http::Uri;
    use hyper_util::rt::TokioIo;
    use tokio::net::TcpStream;
    use tower_service::Service;

    use super::*;
    use crate::{ConfigBuilderExt, HttpsConnectorBuilder, MaybeHttpsStream};

    #[tokio::test]
    async fn connects_https() {
        let outcome = connect(Allow::Any, Scheme::Https).await;
        outcome.unwrap();
    }

    #[tokio::test]
    async fn connects_http() {
        let outcome = connect(Allow::Any, Scheme::Http).await;
        outcome.unwrap();
    }

    #[tokio::test]
    async fn connects_https_only() {
        let outcome = connect(Allow::Https, Scheme::Https).await;
        outcome.unwrap();
    }

    #[tokio::test]
    async fn enforces_https_only() {
        let err = connect(Allow::Https, Scheme::Http).await.unwrap_err();
        assert_eq!(err.to_string(), "unsupported scheme http");
    }

    /// Builds a connector with the requested scheme policy and dials Google using `scheme`.
    async fn connect(
        allow: Allow,
        scheme: Scheme,
    ) -> Result<MaybeHttpsStream<TokioIo<TcpStream>>, BoxError> {
        let tls_builder = rustls::ClientConfig::builder();
        // Pick the root-of-trust source from the first enabled feature, in this order.
        cfg_if::cfg_if! {
            if #[cfg(feature = "rustls-platform-verifier")] {
                let tls_builder = tls_builder.try_with_platform_verifier()?;
            } else if #[cfg(feature = "rustls-native-certs")] {
                let tls_builder = tls_builder.with_native_roots().unwrap();
            } else if #[cfg(feature = "webpki-roots")] {
                let tls_builder = tls_builder.with_webpki_roots();
            }
        }
        let tls_config = tls_builder.with_no_client_auth();

        let with_tls = HttpsConnectorBuilder::new().with_tls_config(tls_config);
        let with_policy = match allow {
            Allow::Https => with_tls.https_only(),
            Allow::Any => with_tls.https_or_http(),
        };
        let mut connector = with_policy.enable_http1().build();

        let target = match scheme {
            Scheme::Https => "https://google.com",
            Scheme::Http => "http://google.com",
        };

        poll_fn(|cx| connector.poll_ready(cx)).await?;
        connector.call(Uri::from_static(target)).await
    }

    /// Which schemes the connector under test accepts.
    enum Allow {
        Https,
        Any,
    }

    /// Which scheme the test URI uses.
    enum Scheme {
        Https,
        Http,
    }
}