1use std::future::Future;
2use std::net::SocketAddr;
3#[cfg(feature = "tls")]
4use std::path::Path;
5
6use futures_util::TryFuture;
7
8use crate::filter::Filter;
9use crate::reject::IsReject;
10use crate::reply::Reply;
11#[cfg(feature = "tls")]
12use crate::tls::TlsConfigBuilder;
13
14pub fn serve<F>(filter: F) -> Server<F, accept::LazyTcp, run::Standard>
16where
17 F: Filter + Clone + Send + Sync + 'static,
18 F::Extract: Reply,
19 F::Error: IsReject,
20{
21 Server {
22 acceptor: accept::LazyTcp,
23 pipeline: false,
24 filter,
25 runner: run::Standard,
26 }
27}
28
/// An HTTP server that serves requests with a filter.
///
/// The type parameters record how the server is configured: `F` is the
/// filter, `A` is the connection acceptor, and `R` is the runner strategy
/// that drives the accept loop.
#[derive(Debug)]
pub struct Server<F, A, R> {
    /// Source of incoming connections.
    acceptor: A,
    /// Filter that is cloned into a service for every accepted connection.
    filter: F,
    /// Value handed to hyper's HTTP/1 `pipeline_flush` setting per connection.
    pipeline: bool,
    /// Strategy value selecting how the accept loop is run.
    runner: R,
}
46
47impl<F, R> Server<F, accept::LazyTcp, R>
50where
51 F: Filter + Clone + Send + Sync + 'static,
52 <F::Future as TryFuture>::Ok: Reply,
53 <F::Future as TryFuture>::Error: IsReject,
54 R: run::Run,
55{
56 pub async fn run(self, addr: impl Into<SocketAddr>) {
64 self.bind(addr).await.run().await;
65 }
66
67 pub async fn bind(self, addr: impl Into<SocketAddr>) -> Server<F, tokio::net::TcpListener, R> {
75 let addr = addr.into();
76 let acceptor = tokio::net::TcpListener::bind(addr)
77 .await
78 .expect("failed to bind to address");
79
80 self.incoming(acceptor)
81 }
82
83 pub fn incoming<A>(self, acceptor: A) -> Server<F, A, R> {
85 Server {
86 acceptor,
87 filter: self.filter,
88 pipeline: self.pipeline,
89 runner: self.runner,
90 }
91 }
92
93 }
95
96impl<F, A, R> Server<F, A, R>
97where
98 F: Filter + Clone + Send + Sync + 'static,
99 <F::Future as TryFuture>::Ok: Reply,
100 <F::Future as TryFuture>::Error: IsReject,
101 A: accept::Accept,
102 R: run::Run,
103{
104 #[cfg(feature = "tls")]
105 pub fn tls(self) -> Server<F, accept::Tls<A>, R> {}
106
107 pub fn graceful<Fut>(self, shutdown_signal: Fut) -> Server<F, A, run::Graceful<Fut>>
124 where
125 Fut: Future<Output = ()> + Send + 'static,
126 {
127 Server {
128 acceptor: self.acceptor,
129 filter: self.filter,
130 pipeline: self.pipeline,
131 runner: run::Graceful(shutdown_signal),
132 }
133 }
134
135 pub async fn run(self) {
137 R::run(self).await;
138 }
139}
140
/// TLS configuration methods, available when the acceptor is wrapped in
/// `accept::Tls`.
///
/// Every public method forwards its argument to the same-named
/// `TlsConfigBuilder` method through the private `with_tls` helper.
#[cfg(feature = "tls")]
impl<F, A, R> Server<F, accept::Tls<A>, R>
where
    F: Filter + Clone + Send + Sync + 'static,
    <F::Future as TryFuture>::Ok: Reply,
    <F::Future as TryFuture>::Error: IsReject,
    A: accept::Accept,
    R: run::Run,
{
    /// Forwards `path` to `TlsConfigBuilder::key_path`.
    ///
    /// Presumably the file path of the private key — confirm against
    /// `TlsConfigBuilder`.
    pub fn key_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.key_path(path))
    }

    /// Forwards `path` to `TlsConfigBuilder::cert_path`.
    ///
    /// Presumably the file path of the certificate — confirm against
    /// `TlsConfigBuilder`.
    pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.cert_path(path))
    }

    /// Forwards `path` to `TlsConfigBuilder::client_auth_optional_path`.
    pub fn client_auth_optional_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.client_auth_optional_path(path))
    }

    /// Forwards `path` to `TlsConfigBuilder::client_auth_required_path`.
    pub fn client_auth_required_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.client_auth_required_path(path))
    }

    /// Forwards the bytes of `key` to `TlsConfigBuilder::key`.
    pub fn key(self, key: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.key(key.as_ref()))
    }

    /// Forwards the bytes of `cert` to `TlsConfigBuilder::cert`.
    pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.cert(cert.as_ref()))
    }

    /// Forwards the bytes of `trust_anchor` to
    /// `TlsConfigBuilder::client_auth_optional`.
    pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref()))
    }

    /// Forwards the bytes of `trust_anchor` to
    /// `TlsConfigBuilder::client_auth_required`.
    pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref()))
    }

    /// Forwards the bytes of `resp` to `TlsConfigBuilder::ocsp_resp`.
    pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.ocsp_resp(resp.as_ref()))
    }

    /// Applies `func` to a `TlsConfigBuilder`.
    ///
    /// NOTE(review): this looks like an unfinished stub. The `tls` passed to
    /// `func` is not defined anywhere in this scope, and the body ends without
    /// producing the `Self` the signature promises. `accept::Tls` also has no
    /// field that could hold a builder, so finishing this needs a change there
    /// as well.
    fn with_tls<Func>(self, func: Func) -> Self
    where
        Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
    {
        let tls = func(tls);
    }
}
236
237mod accept {
238 use std::net::SocketAddr;
239
240 pub trait Accept {
241 type IO: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static;
242 type AcceptError: std::fmt::Debug;
243 type Accepting: super::Future<Output = Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>
244 + Send
245 + 'static;
246 #[allow(async_fn_in_trait)]
247 async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error>;
248 }
249
250 #[derive(Debug)]
251 pub struct LazyTcp;
252
253 impl Accept for tokio::net::TcpListener {
254 type IO = hyper_util::rt::TokioIo<tokio::net::TcpStream>;
255 type AcceptError = std::convert::Infallible;
256 type Accepting =
257 std::future::Ready<Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>;
258 async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
259 let (io, addr) = <tokio::net::TcpListener>::accept(self).await?;
260 Ok(std::future::ready(Ok((
261 hyper_util::rt::TokioIo::new(io),
262 Some(addr),
263 ))))
264 }
265 }
266
267 #[cfg(unix)]
268 impl Accept for tokio::net::UnixListener {
269 type IO = hyper_util::rt::TokioIo<tokio::net::UnixStream>;
270 type AcceptError = std::convert::Infallible;
271 type Accepting =
272 std::future::Ready<Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>;
273 async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
274 let (io, _addr) = <tokio::net::UnixListener>::accept(self).await?;
275 Ok(std::future::ready(Ok((
276 hyper_util::rt::TokioIo::new(io),
277 None,
278 ))))
279 }
280 }
281
282 #[cfg(feature = "tls")]
283 #[derive(Debug)]
284 pub struct Tls<A>(pub(super) A);
285
286 #[cfg(feature = "tls")]
287 impl<A: Accept> Accept for Tls<A> {
288 type IO = hyper_util::rt::TokioIo<tokio::net::TcpStream>;
289 type AcceptError = std::convert::Infallible;
290 type Accepting =
291 std::future::Ready<Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>;
292 async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
293 let (io, addr) = self.0.accept().await?;
294 Ok(std::future::ready(Ok((
295 hyper_util::rt::TokioIo::new(io),
296 addr,
297 ))))
298 }
299 }
300}
301
302mod middleware {
303 use std::net::SocketAddr;
304 use std::task::{Context, Poll};
305 use tower_service::Service;
306
307 use crate::filters::addr::RemoteAddr;
308
309 #[derive(Clone, Debug)]
310 pub(super) struct RemoteAddrService<S> {
311 inner: S,
312 remote_addr: Option<SocketAddr>,
313 }
314
315 impl<S> RemoteAddrService<S> {
316 pub(super) fn new(inner: S, remote_addr: Option<SocketAddr>) -> Self {
317 Self { inner, remote_addr }
318 }
319 }
320
321 impl<S, B> Service<http::Request<B>> for RemoteAddrService<S>
322 where
323 S: Service<http::Request<B>>,
324 {
325 type Response = S::Response;
326 type Error = S::Error;
327 type Future = S::Future;
328
329 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330 self.inner.poll_ready(cx)
331 }
332
333 fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
334 if let Some(addr) = self.remote_addr {
335 req.extensions_mut().insert(RemoteAddr(addr));
336 }
337 self.inner.call(req)
338 }
339 }
340}
341
/// Runner strategies: the accept loops that drive a `Server`.
mod run {
    /// Strategy for driving a server's accept loop.
    ///
    /// The runner value is stored in the server's `runner` field and the
    /// strategy is selected through the server's `R` type parameter.
    pub trait Run {
        /// Accepts connections from `server.acceptor` and serves each one with
        /// a service built from `server.filter`.
        #[allow(async_fn_in_trait)]
        async fn run<F, A>(server: super::Server<F, A, Self>)
        where
            F: super::Filter + Clone + Send + Sync + 'static,
            <F::Future as super::TryFuture>::Ok: super::Reply,
            <F::Future as super::TryFuture>::Error: super::IsReject,
            A: super::accept::Accept,
            Self: Sized;
    }

    /// Runner that accepts connections in a loop with no exit path.
    #[derive(Debug)]
    pub struct Standard;

    impl Run for Standard {
        /// Accepts connections forever, spawning one task per connection.
        async fn run<F, A>(mut server: super::Server<F, A, Self>)
        where
            F: super::Filter + Clone + Send + Sync + 'static,
            <F::Future as super::TryFuture>::Ok: super::Reply,
            <F::Future as super::TryFuture>::Error: super::IsReject,
            A: super::accept::Accept,
            Self: Sized,
        {
            let pipeline = server.pipeline;
            loop {
                // First accept step: wait for the next connection.
                let accepting = match server.acceptor.accept().await {
                    Ok(fut) => fut,
                    Err(err) => {
                        handle_accept_error(err).await;
                        continue;
                    }
                };
                // Each connection gets its own service built from a clone of the filter.
                let svc = crate::service(server.filter.clone());
                tokio::spawn(async move {
                    // Second accept step runs inside the spawned task, so it
                    // does not hold up the accept loop.
                    let (io, remote_addr) = match accepting.await {
                        Ok(pair) => pair,
                        Err(err) => {
                            tracing::debug!("server accept error: {:?}", err);
                            return;
                        }
                    };
                    let svc = super::middleware::RemoteAddrService::new(svc, remote_addr);
                    let svc = hyper_util::service::TowerToHyperService::new(svc);
                    if let Err(err) = hyper_util::server::conn::auto::Builder::new(
                        hyper_util::rt::TokioExecutor::new(),
                    )
                    .http1()
                    .pipeline_flush(pipeline)
                    .serve_connection_with_upgrades(io, svc)
                    .await
                    {
                        tracing::error!("server connection error: {:?}", err)
                    }
                });
            }
        }
    }

    /// Runner that stops accepting once the wrapped shutdown future completes.
    #[derive(Debug)]
    pub struct Graceful<Fut>(pub(super) Fut);

    impl<Fut> Run for Graceful<Fut>
    where
        Fut: super::Future<Output = ()> + Send + 'static,
    {
        /// Accepts connections until the shutdown future completes, then drops
        /// the acceptor and awaits `GracefulShutdown::shutdown`.
        async fn run<F, A>(mut server: super::Server<F, A, Self>)
        where
            F: super::Filter + Clone + Send + Sync + 'static,
            <F::Future as super::TryFuture>::Ok: super::Reply,
            <F::Future as super::TryFuture>::Error: super::IsReject,
            A: super::accept::Accept,
            Self: Sized,
        {
            use futures_util::future;

            let pipeline = server.pipeline;
            let graceful_util = hyper_util::server::graceful::GracefulShutdown::new();
            // Pinned once outside the loop so the same future is polled on
            // every iteration.
            let mut shutdown_signal = std::pin::pin!(server.runner.0);
            loop {
                let accept = std::pin::pin!(server.acceptor.accept());
                // Race the next accept against the shutdown signal.
                let accepting = match future::select(accept, &mut shutdown_signal).await {
                    future::Either::Left((Ok(fut), _)) => fut,
                    future::Either::Left((Err(err), _)) => {
                        // NOTE(review): this await sits outside the select, so
                        // the shutdown signal is not observed while the
                        // back-off sleep inside `handle_accept_error` runs.
                        handle_accept_error(err).await;
                        continue;
                    }
                    future::Either::Right(((), _)) => {
                        tracing::debug!("shutdown signal received, starting graceful shutdown");
                        break;
                    }
                };
                let svc = crate::service(server.filter.clone());
                // One watcher per connection, so the connection is tracked by
                // `graceful_util`.
                let watcher = graceful_util.watcher();
                tokio::spawn(async move {
                    let (io, remote_addr) = match accepting.await {
                        Ok(pair) => pair,
                        Err(err) => {
                            tracing::debug!("server accepting error: {:?}", err);
                            return;
                        }
                    };
                    let svc = super::middleware::RemoteAddrService::new(svc, remote_addr);
                    let svc = hyper_util::service::TowerToHyperService::new(svc);
                    let mut hyper = hyper_util::server::conn::auto::Builder::new(
                        hyper_util::rt::TokioExecutor::new(),
                    );
                    hyper.http1().pipeline_flush(pipeline);
                    let conn = hyper.serve_connection_with_upgrades(io, svc);
                    let conn = watcher.watch(conn);
                    if let Err(err) = conn.await {
                        tracing::error!("server connection error: {:?}", err)
                    }
                });
            }

            // The acceptor is dropped before waiting on the watched
            // connections; presumably this closes the listener so no new
            // connections queue up — confirm for custom acceptors.
            drop(server.acceptor);
            graceful_util.shutdown().await;
        }
    }

    /// Handles an error returned by the first accept step.
    ///
    /// Connection-level errors (see `is_connection_error`) are ignored. Any
    /// other error is logged at error level and followed by a one-second
    /// sleep; presumably the pause avoids a tight retry loop on persistent
    /// failures — confirm the intended rationale.
    async fn handle_accept_error(e: std::io::Error) {
        if is_connection_error(&e) {
            return;
        }
        tracing::error!("accept error: {:?}", e);
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
    }

    /// Returns `true` when `e` is a refused, aborted or reset connection.
    fn is_connection_error(e: &std::io::Error) -> bool {
        matches!(
            e.kind(),
            std::io::ErrorKind::ConnectionRefused
                | std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
        )
    }
}