// warp/server.rs

1use std::future::Future;
2use std::net::SocketAddr;
3#[cfg(feature = "tls")]
4use std::path::Path;
5
6use futures_util::TryFuture;
7
8use crate::filter::Filter;
9use crate::reject::IsReject;
10use crate::reply::Reply;
11#[cfg(feature = "tls")]
12use crate::tls::TlsConfigBuilder;
13
14/// Create a `Server` with the provided `Filter`.
15pub fn serve<F>(filter: F) -> Server<F, accept::LazyTcp, run::Standard>
16where
17    F: Filter + Clone + Send + Sync + 'static,
18    F::Extract: Reply,
19    F::Error: IsReject,
20{
21    Server {
22        acceptor: accept::LazyTcp,
23        pipeline: false,
24        filter,
25        runner: run::Standard,
26    }
27}
28
/// A warp Server ready to filter requests.
///
/// Construct this type using [`serve()`].
///
/// # Unnameable
///
/// This type is publicly available in the docs only.
///
/// It is not otherwise nameable, since it is a builder type using typestate
/// to allow for ergonomic configuration.
///
/// The type parameters carry that typestate: `F` is the filter that handles
/// requests, `A` is the source of incoming connections (it starts out as
/// `accept::LazyTcp`, before any listener is bound), and `R` selects how the
/// accept loop is run (`run::Standard`, or `run::Graceful` once a shutdown
/// future has been supplied).
#[derive(Debug)]
pub struct Server<F, A, R> {
    /// Source of incoming connections.
    acceptor: A,
    /// Filter that is cloned into a fresh service for every accepted connection.
    filter: F,
    /// Value passed to hyper's HTTP/1 `pipeline_flush` option for each connection.
    pipeline: bool,
    /// Strategy that drives the accept loop when the server is run.
    runner: R,
}
46
47// ===== impl Server =====
48
49impl<F, R> Server<F, accept::LazyTcp, R>
50where
51    F: Filter + Clone + Send + Sync + 'static,
52    <F::Future as TryFuture>::Ok: Reply,
53    <F::Future as TryFuture>::Error: IsReject,
54    R: run::Run,
55{
56    /// Binds and runs this server.
57    ///
58    /// # Panics
59    ///
60    /// Panics if we are unable to bind to the provided address.
61    ///
62    /// To handle bind failures, bind a listener and call `incoming()`.
63    pub async fn run(self, addr: impl Into<SocketAddr>) {
64        self.bind(addr).await.run().await;
65    }
66
67    /// Binds this server.
68    ///
69    /// # Panics
70    ///
71    /// Panics if we are unable to bind to the provided address.
72    ///
73    /// To handle bind failures, bind a listener and call `incoming()`.
74    pub async fn bind(self, addr: impl Into<SocketAddr>) -> Server<F, tokio::net::TcpListener, R> {
75        let addr = addr.into();
76        let acceptor = tokio::net::TcpListener::bind(addr)
77            .await
78            .expect("failed to bind to address");
79
80        self.incoming(acceptor)
81    }
82
83    /// Configure the server with an acceptor of incoming connections.
84    pub fn incoming<A>(self, acceptor: A) -> Server<F, A, R> {
85        Server {
86            acceptor,
87            filter: self.filter,
88            pipeline: self.pipeline,
89            runner: self.runner,
90        }
91    }
92
93    // pub fn conn
94}
95
96impl<F, A, R> Server<F, A, R>
97where
98    F: Filter + Clone + Send + Sync + 'static,
99    <F::Future as TryFuture>::Ok: Reply,
100    <F::Future as TryFuture>::Error: IsReject,
101    A: accept::Accept,
102    R: run::Run,
103{
104    #[cfg(feature = "tls")]
105    pub fn tls(self) -> Server<F, accept::Tls<A>, R> {}
106
107    /// Add graceful shutdown support to this server.
108    ///
109    /// # Example
110    ///
111    /// ```
112    /// # async fn ex(addr: std::net::SocketAddr) {
113    /// # use warp::Filter;
114    /// # let filter = warp::any().map(|| "ok");
115    /// warp::serve(filter)
116    ///     .bind(addr).await
117    ///     .graceful(async {
118    ///         // some signal in here, such as ctrl_c
119    ///     })
120    ///     .run().await;
121    /// # }
122    /// ```
123    pub fn graceful<Fut>(self, shutdown_signal: Fut) -> Server<F, A, run::Graceful<Fut>>
124    where
125        Fut: Future<Output = ()> + Send + 'static,
126    {
127        Server {
128            acceptor: self.acceptor,
129            filter: self.filter,
130            pipeline: self.pipeline,
131            runner: run::Graceful(shutdown_signal),
132        }
133    }
134
135    /// Run this server.
136    pub async fn run(self) {
137        R::run(self).await;
138    }
139}
140
// ===== impl Tls =====
142
// Setters available on a server whose acceptor has been wrapped in
// `accept::Tls` (see `Server::tls`). Each one forwards to the same-named
// `TlsConfigBuilder` method through `with_tls`.
#[cfg(feature = "tls")]
impl<F, A, R> Server<F, accept::Tls<A>, R>
where
    F: Filter + Clone + Send + Sync + 'static,
    <F::Future as TryFuture>::Ok: Reply,
    <F::Future as TryFuture>::Error: IsReject,
    A: accept::Accept,
    R: run::Run,
{
    // TLS config methods

    /// Specify the file path to read the private key.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn key_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.key_path(path))
    }

    /// Specify the file path to read the certificate.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.cert_path(path))
    }

    /// Specify the file path to read the trust anchor for optional client authentication.
    ///
    /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
    /// of the `client_auth_` methods, then client authentication is disabled by default.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn client_auth_optional_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.client_auth_optional_path(path))
    }

    /// Specify the file path to read the trust anchor for required client authentication.
    ///
    /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
    /// `client_auth_` methods, then client authentication is disabled by default.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn client_auth_required_path(self, path: impl AsRef<Path>) -> Self {
        self.with_tls(|tls| tls.client_auth_required_path(path))
    }

    /// Specify the in-memory contents of the private key.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn key(self, key: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.key(key.as_ref()))
    }

    /// Specify the in-memory contents of the certificate.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.cert(cert.as_ref()))
    }

    /// Specify the in-memory contents of the trust anchor for optional client authentication.
    ///
    /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
    /// of the `client_auth_` methods, then client authentication is disabled by default.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref()))
    }

    /// Specify the in-memory contents of the trust anchor for required client authentication.
    ///
    /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
    /// `client_auth_` methods, then client authentication is disabled by default.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref()))
    }

    /// Specify the DER-encoded OCSP response.
    ///
    /// *This function requires the `"tls"` feature.*
    pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self {
        self.with_tls(|tls| tls.ocsp_resp(resp.as_ref()))
    }

    /// Intended to apply `func` to this server's `TlsConfigBuilder` and return
    /// the updated server; every setter above funnels through here.
    ///
    /// NOTE(review): incomplete and does not compile as written. `tls` is not
    /// bound anywhere in scope, and the body has no tail expression even though
    /// the signature returns `Self`. `accept::Tls` currently holds only the
    /// inner acceptor, so there is no stored builder to pass to `func`. One
    /// needs to be added (e.g. as a field on `accept::Tls`, created in
    /// `Server::tls`) before this can be finished.
    fn with_tls<Func>(self, func: Func) -> Self
    where
        Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
    {
        let tls = func(tls);
    }
}
236
// Connection sources ("acceptors") that a `Server` can pull connections from.
mod accept {
    use std::net::SocketAddr;

    /// A source of incoming connections.
    ///
    /// Accepting happens in two stages. `accept` waits for the listener. An
    /// `Err` there is a listener-level failure that the run loop reports and
    /// may back off on before retrying. On success, `accept` returns an
    /// `Accepting` future. The server awaits that future inside the spawned
    /// per-connection task, so it does not hold up the accept loop. It resolves
    /// to the connection I/O plus the peer address, when one is known.
    pub trait Accept {
        /// Connection type handed to hyper.
        type IO: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static;
        /// Error from the `Accepting` stage; the run loop only logs it and drops the connection.
        type AcceptError: std::fmt::Debug;
        /// Future that completes the second stage of accepting one connection.
        type Accepting: super::Future<Output = Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>
            + Send
            + 'static;
        // The lint warns that callers cannot name a `Send` bound on the future
        // returned by an `async fn` in a public trait; allowed here.
        #[allow(async_fn_in_trait)]
        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error>;
    }

    /// Typestate marker for a server that has not bound a listener yet.
    ///
    /// It has no `Accept` impl in this file. `Server::bind` or
    /// `Server::incoming` replaces it with a real acceptor before the server
    /// can run.
    #[derive(Debug)]
    pub struct LazyTcp;

    // Plain TCP: the connection is complete as soon as the listener returns
    // it, so the second stage is an already-ready future.
    impl Accept for tokio::net::TcpListener {
        type IO = hyper_util::rt::TokioIo<tokio::net::TcpStream>;
        type AcceptError = std::convert::Infallible;
        type Accepting =
            std::future::Ready<Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>;
        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
            let (io, addr) = <tokio::net::TcpListener>::accept(self).await?;
            Ok(std::future::ready(Ok((
                hyper_util::rt::TokioIo::new(io),
                Some(addr),
            ))))
        }
    }

    // Unix domain sockets: the peer address is discarded and reported as
    // `None`, since it is not a `SocketAddr`.
    #[cfg(unix)]
    impl Accept for tokio::net::UnixListener {
        type IO = hyper_util::rt::TokioIo<tokio::net::UnixStream>;
        type AcceptError = std::convert::Infallible;
        type Accepting =
            std::future::Ready<Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>;
        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
            let (io, _addr) = <tokio::net::UnixListener>::accept(self).await?;
            Ok(std::future::ready(Ok((
                hyper_util::rt::TokioIo::new(io),
                None,
            ))))
        }
    }

    /// TLS wrapper around another acceptor, produced by `Server::tls`.
    #[cfg(feature = "tls")]
    #[derive(Debug)]
    pub struct Tls<A>(pub(super) A);

    // NOTE(review): this impl does not type-check and performs no TLS
    // handshake yet.
    // - `self.0.accept().await?` yields `A::Accepting` (a future), not an
    //   `(io, addr)` pair.
    // - `IO` is fixed to a plain `TokioIo<TcpStream>` rather than being derived
    //   from `A::IO`.
    // - No TLS configuration is stored on `Tls` to build a handshake from.
    #[cfg(feature = "tls")]
    impl<A: Accept> Accept for Tls<A> {
        type IO = hyper_util::rt::TokioIo<tokio::net::TcpStream>;
        type AcceptError = std::convert::Infallible;
        type Accepting =
            std::future::Ready<Result<(Self::IO, Option<SocketAddr>), Self::AcceptError>>;
        async fn accept(&mut self) -> Result<Self::Accepting, std::io::Error> {
            let (io, addr) = self.0.accept().await?;
            Ok(std::future::ready(Ok((
                hyper_util::rt::TokioIo::new(io),
                addr,
            ))))
        }
    }
}
301
302mod middleware {
303    use std::net::SocketAddr;
304    use std::task::{Context, Poll};
305    use tower_service::Service;
306
307    use crate::filters::addr::RemoteAddr;
308
309    #[derive(Clone, Debug)]
310    pub(super) struct RemoteAddrService<S> {
311        inner: S,
312        remote_addr: Option<SocketAddr>,
313    }
314
315    impl<S> RemoteAddrService<S> {
316        pub(super) fn new(inner: S, remote_addr: Option<SocketAddr>) -> Self {
317            Self { inner, remote_addr }
318        }
319    }
320
321    impl<S, B> Service<http::Request<B>> for RemoteAddrService<S>
322    where
323        S: Service<http::Request<B>>,
324    {
325        type Response = S::Response;
326        type Error = S::Error;
327        type Future = S::Future;
328
329        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330            self.inner.poll_ready(cx)
331        }
332
333        fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
334            if let Some(addr) = self.remote_addr {
335                req.extensions_mut().insert(RemoteAddr(addr));
336            }
337            self.inner.call(req)
338        }
339    }
340}
341
// Strategies for driving a server's accept loop.
mod run {
    /// Drives a `Server`: accepts connections from its acceptor and serves
    /// each one on its own spawned task.
    pub trait Run {
        // See the lint note on `Accept::accept`: the returned future carries
        // no nameable `Send` bound.
        #[allow(async_fn_in_trait)]
        async fn run<F, A>(server: super::Server<F, A, Self>)
        where
            F: super::Filter + Clone + Send + Sync + 'static,
            <F::Future as super::TryFuture>::Ok: super::Reply,
            <F::Future as super::TryFuture>::Error: super::IsReject,
            A: super::accept::Accept,
            Self: Sized;
    }

    /// Runner that accepts connections forever, with no shutdown hook.
    #[derive(Debug)]
    pub struct Standard;

    impl Run for Standard {
        async fn run<F, A>(mut server: super::Server<F, A, Self>)
        where
            F: super::Filter + Clone + Send + Sync + 'static,
            <F::Future as super::TryFuture>::Ok: super::Reply,
            <F::Future as super::TryFuture>::Error: super::IsReject,
            A: super::accept::Accept,
            Self: Sized,
        {
            let pipeline = server.pipeline;
            loop {
                // Stage one: wait on the listener. Failures here are
                // listener-level; `handle_accept_error` may sleep before the retry.
                let accepting = match server.acceptor.accept().await {
                    Ok(fut) => fut,
                    Err(err) => {
                        handle_accept_error(err).await;
                        continue;
                    }
                };
                // Every connection gets its own clone of the filter as a service.
                let svc = crate::service(server.filter.clone());
                tokio::spawn(async move {
                    // Stage two runs inside the task, so any remaining accept
                    // work does not hold up the loop above.
                    let (io, remote_addr) = match accepting.await {
                        Ok(pair) => pair,
                        Err(err) => {
                            tracing::debug!("server accept error: {:?}", err);
                            return;
                        }
                    };
                    let svc = super::middleware::RemoteAddrService::new(svc, remote_addr);
                    let svc = hyper_util::service::TowerToHyperService::new(svc);
                    // `pipeline` is applied to the HTTP/1 settings only.
                    if let Err(err) = hyper_util::server::conn::auto::Builder::new(
                        hyper_util::rt::TokioExecutor::new(),
                    )
                    .http1()
                    .pipeline_flush(pipeline)
                    .serve_connection_with_upgrades(io, svc)
                    .await
                    {
                        tracing::error!("server connection error: {:?}", err)
                    }
                });
            }
        }
    }

    /// Runner that stops accepting once the wrapped future completes, then
    /// waits for already-accepted connections to shut down.
    #[derive(Debug)]
    pub struct Graceful<Fut>(pub(super) Fut);

    impl<Fut> Run for Graceful<Fut>
    where
        Fut: super::Future<Output = ()> + Send + 'static,
    {
        async fn run<F, A>(mut server: super::Server<F, A, Self>)
        where
            F: super::Filter + Clone + Send + Sync + 'static,
            <F::Future as super::TryFuture>::Ok: super::Reply,
            <F::Future as super::TryFuture>::Error: super::IsReject,
            A: super::accept::Accept,
            Self: Sized,
        {
            use futures_util::future;

            let pipeline = server.pipeline;
            let graceful_util = hyper_util::server::graceful::GracefulShutdown::new();
            // Pinned once, outside the loop, so the same shutdown future is
            // polled across iterations instead of being recreated.
            let mut shutdown_signal = std::pin::pin!(server.runner.0);
            loop {
                // Race the next accept against the shutdown signal. If shutdown
                // wins, the pending accept future is dropped.
                // NOTE(review): the 1 s back-off in `handle_accept_error` is not
                // raced against `shutdown_signal`, so shutdown can be noticed up
                // to a second late after a listener error.
                let accept = std::pin::pin!(server.acceptor.accept());
                let accepting = match future::select(accept, &mut shutdown_signal).await {
                    future::Either::Left((Ok(fut), _)) => fut,
                    future::Either::Left((Err(err), _)) => {
                        handle_accept_error(err).await;
                        continue;
                    }
                    future::Either::Right(((), _)) => {
                        tracing::debug!("shutdown signal received, starting graceful shutdown");
                        break;
                    }
                };
                let svc = crate::service(server.filter.clone());
                // Each connection is registered with `graceful_util` so the
                // final `shutdown()` waits for it.
                let watcher = graceful_util.watcher();
                tokio::spawn(async move {
                    let (io, remote_addr) = match accepting.await {
                        Ok(pair) => pair,
                        Err(err) => {
                            tracing::debug!("server accepting error: {:?}", err);
                            return;
                        }
                    };
                    let svc = super::middleware::RemoteAddrService::new(svc, remote_addr);
                    let svc = hyper_util::service::TowerToHyperService::new(svc);
                    let mut hyper = hyper_util::server::conn::auto::Builder::new(
                        hyper_util::rt::TokioExecutor::new(),
                    );
                    hyper.http1().pipeline_flush(pipeline);
                    let conn = hyper.serve_connection_with_upgrades(io, svc);
                    let conn = watcher.watch(conn);
                    if let Err(err) = conn.await {
                        tracing::error!("server connection error: {:?}", err)
                    }
                });
            }

            // Only reached after the shutdown signal has fired.
            drop(server.acceptor); // close listener
            // Signal every watched connection to shut down and wait for them.
            graceful_util.shutdown().await;
        }
    }

    /// React to a listener-level accept error.
    ///
    /// Errors that `is_connection_error` classifies as per-connection are
    /// ignored. Anything else is logged at `error` level, followed by a 1 s
    /// sleep before the caller retries.
    // TODO: allow providing your own handler
    async fn handle_accept_error(e: std::io::Error) {
        if is_connection_error(&e) {
            return;
        }
        // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
        //
        // > A possible scenario is that the process has hit the max open files
        // > allowed, and so trying to accept a new connection will fail with
        // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
        // > the application will likely close some files (or connections), and try
        // > to accept the connection again. If this option is `true`, the error
        // > will be logged at the `error` level, since it is still a big deal,
        // > and then the listener will sleep for 1 second.
        tracing::error!("accept error: {:?}", e);
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
    }

    /// Returns `true` for refused, aborted, or reset connections.
    fn is_connection_error(e: &std::io::Error) -> bool {
        // some errors that occur on the TCP stream are emitted when
        // accepting, they can be ignored.
        matches!(
            e.kind(),
            std::io::ErrorKind::ConnectionRefused
                | std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
        )
    }
}
487                | std::io::ErrorKind::ConnectionAborted
488                | std::io::ErrorKind::ConnectionReset
489        )
490    }
491}