tokio_rustls/
server.rs

1use std::future::Future;
2use std::io::{self, BufRead as _};
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, RawFd};
5#[cfg(windows)]
6use std::os::windows::io::{AsRawSocket, RawSocket};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use rustls::server::AcceptedAlert;
12use rustls::{ServerConfig, ServerConnection};
13use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
14
15use crate::common::{IoSession, MidHandshake, Stream, SyncReadAdapter, SyncWriteAdapter, TlsState};
16
17/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
18#[derive(Clone)]
19pub struct TlsAcceptor {
20    inner: Arc<ServerConfig>,
21}
22
23impl From<Arc<ServerConfig>> for TlsAcceptor {
24    fn from(inner: Arc<ServerConfig>) -> Self {
25        Self { inner }
26    }
27}
28
29impl TlsAcceptor {
30    #[inline]
31    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
32    where
33        IO: AsyncRead + AsyncWrite + Unpin,
34    {
35        self.accept_with(stream, |_| ())
36    }
37
38    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
39    where
40        IO: AsyncRead + AsyncWrite + Unpin,
41        F: FnOnce(&mut ServerConnection),
42    {
43        let mut session = match ServerConnection::new(self.inner.clone()) {
44            Ok(session) => session,
45            Err(error) => {
46                return Accept(MidHandshake::Error {
47                    io: stream,
48                    // TODO(eliza): should this really return an `io::Error`?
49                    // Probably not...
50                    error: io::Error::new(io::ErrorKind::Other, error),
51                });
52            }
53        };
54        f(&mut session);
55
56        Accept(MidHandshake::Handshaking(TlsStream {
57            session,
58            io: stream,
59            state: TlsState::Stream,
60            need_flush: false,
61        }))
62    }
63
64    /// Get a read-only reference to underlying config
65    pub fn config(&self) -> &Arc<ServerConfig> {
66        &self.inner
67    }
68}
69
70pub struct LazyConfigAcceptor<IO> {
71    acceptor: rustls::server::Acceptor,
72    io: Option<IO>,
73    alert: Option<(rustls::Error, AcceptedAlert)>,
74}
75
76impl<IO> LazyConfigAcceptor<IO>
77where
78    IO: AsyncRead + AsyncWrite + Unpin,
79{
80    #[inline]
81    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
82        Self {
83            acceptor,
84            io: Some(io),
85            alert: None,
86        }
87    }
88
89    /// Takes back the client connection. Will return `None` if called more than once or if the
90    /// connection has been accepted.
91    ///
92    /// # Example
93    ///
94    /// ```no_run
95    /// # fn choose_server_config(
96    /// #     _: rustls::server::ClientHello,
97    /// # ) -> std::sync::Arc<rustls::ServerConfig> {
98    /// #     unimplemented!();
99    /// # }
100    /// # #[allow(unused_variables)]
101    /// # async fn listen() {
102    /// use tokio::io::AsyncWriteExt;
103    /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
104    /// let (stream, _) = listener.accept().await.unwrap();
105    ///
106    /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
107    /// tokio::pin!(acceptor);
108    ///
109    /// match acceptor.as_mut().await {
110    ///     Ok(start) => {
111    ///         let clientHello = start.client_hello();
112    ///         let config = choose_server_config(clientHello);
113    ///         let stream = start.into_stream(config).await.unwrap();
114    ///         // Proceed with handling the ServerConnection...
115    ///     }
116    ///     Err(err) => {
117    ///         if let Some(mut stream) = acceptor.take_io() {
118    ///             stream
119    ///                 .write_all(
120    ///                     format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
121    ///                         .as_bytes()
122    ///                 )
123    ///                 .await
124    ///                 .unwrap();
125    ///         }
126    ///     }
127    /// }
128    /// # }
129    /// ```
130    pub fn take_io(&mut self) -> Option<IO> {
131        self.io.take()
132    }
133}
134
135impl<IO> Future for LazyConfigAcceptor<IO>
136where
137    IO: AsyncRead + AsyncWrite + Unpin,
138{
139    type Output = Result<StartHandshake<IO>, io::Error>;
140
141    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142        let this = self.get_mut();
143        loop {
144            let io = match this.io.as_mut() {
145                Some(io) => io,
146                None => {
147                    return Poll::Ready(Err(io::Error::new(
148                        io::ErrorKind::Other,
149                        "acceptor cannot be polled after acceptance",
150                    )))
151                }
152            };
153
154            if let Some((err, mut alert)) = this.alert.take() {
155                match alert.write(&mut SyncWriteAdapter { io, cx }) {
156                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
157                        this.alert = Some((err, alert));
158                        return Poll::Pending;
159                    }
160                    Ok(0) | Err(_) => {
161                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
162                    }
163                    Ok(_) => {
164                        this.alert = Some((err, alert));
165                        continue;
166                    }
167                };
168            }
169
170            let mut reader = SyncReadAdapter { io, cx };
171            match this.acceptor.read_tls(&mut reader) {
172                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
173                Ok(_) => {}
174                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
175                Err(e) => return Err(e).into(),
176            }
177
178            match this.acceptor.accept() {
179                Ok(Some(accepted)) => {
180                    let io = this.io.take().unwrap();
181                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
182                }
183                Ok(None) => {}
184                Err((err, alert)) => {
185                    this.alert = Some((err, alert));
186                }
187            }
188        }
189    }
190}
191
192/// An incoming connection received through [`LazyConfigAcceptor`].
193///
194/// This contains the generic `IO` asynchronous transport,
195/// [`ClientHello`](rustls::server::ClientHello) data,
196/// and all the state required to continue the TLS handshake (e.g. via
197/// [`StartHandshake::into_stream`]).
198#[non_exhaustive]
199#[derive(Debug)]
200pub struct StartHandshake<IO> {
201    pub accepted: rustls::server::Accepted,
202    pub io: IO,
203}
204
205impl<IO> StartHandshake<IO>
206where
207    IO: AsyncRead + AsyncWrite + Unpin,
208{
209    /// Create a new object from an `IO` transport and prior TLS metadata.
210    pub fn from_parts(accepted: rustls::server::Accepted, transport: IO) -> Self {
211        Self {
212            accepted,
213            io: transport,
214        }
215    }
216
217    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
218        self.accepted.client_hello()
219    }
220
221    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
222        self.into_stream_with(config, |_| ())
223    }
224
225    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
226    where
227        F: FnOnce(&mut ServerConnection),
228    {
229        let mut conn = match self.accepted.into_connection(config) {
230            Ok(conn) => conn,
231            Err((error, alert)) => {
232                return Accept(MidHandshake::SendAlert {
233                    io: self.io,
234                    alert,
235                    // TODO(eliza): should this really return an `io::Error`?
236                    // Probably not...
237                    error: io::Error::new(io::ErrorKind::InvalidData, error),
238                });
239            }
240        };
241        f(&mut conn);
242
243        Accept(MidHandshake::Handshaking(TlsStream {
244            session: conn,
245            io: self.io,
246            state: TlsState::Stream,
247            need_flush: false,
248        }))
249    }
250}
251
252/// Future returned from `TlsAcceptor::accept` which will resolve
253/// once the accept handshake has finished.
254pub struct Accept<IO>(MidHandshake<TlsStream<IO>>);
255
256impl<IO> Accept<IO> {
257    #[inline]
258    pub fn into_fallible(self) -> FallibleAccept<IO> {
259        FallibleAccept(self.0)
260    }
261
262    pub fn get_ref(&self) -> Option<&IO> {
263        match &self.0 {
264            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
265            MidHandshake::SendAlert { io, .. } => Some(io),
266            MidHandshake::Error { io, .. } => Some(io),
267            MidHandshake::End => None,
268        }
269    }
270
271    pub fn get_mut(&mut self) -> Option<&mut IO> {
272        match &mut self.0 {
273            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
274            MidHandshake::SendAlert { io, .. } => Some(io),
275            MidHandshake::Error { io, .. } => Some(io),
276            MidHandshake::End => None,
277        }
278    }
279}
280
281impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
282    type Output = io::Result<TlsStream<IO>>;
283
284    #[inline]
285    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
286        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
287    }
288}
289
290/// Like [Accept], but returns `IO` on failure.
291pub struct FallibleAccept<IO>(MidHandshake<TlsStream<IO>>);
292
293impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
294    type Output = Result<TlsStream<IO>, (io::Error, IO)>;
295
296    #[inline]
297    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
298        Pin::new(&mut self.0).poll(cx)
299    }
300}
301
302/// A wrapper around an underlying raw stream which implements the TLS or SSL
303/// protocol.
304#[derive(Debug)]
305pub struct TlsStream<IO> {
306    pub(crate) io: IO,
307    pub(crate) session: ServerConnection,
308    pub(crate) state: TlsState,
309    pub(crate) need_flush: bool,
310}
311
312impl<IO> TlsStream<IO> {
313    #[inline]
314    pub fn get_ref(&self) -> (&IO, &ServerConnection) {
315        (&self.io, &self.session)
316    }
317
318    #[inline]
319    pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
320        (&mut self.io, &mut self.session)
321    }
322
323    #[inline]
324    pub fn into_inner(self) -> (IO, ServerConnection) {
325        (self.io, self.session)
326    }
327}
328
329impl<IO> IoSession for TlsStream<IO> {
330    type Io = IO;
331    type Session = ServerConnection;
332
333    #[inline]
334    fn skip_handshake(&self) -> bool {
335        false
336    }
337
338    #[inline]
339    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) {
340        (
341            &mut self.state,
342            &mut self.io,
343            &mut self.session,
344            &mut self.need_flush,
345        )
346    }
347
348    #[inline]
349    fn into_io(self) -> Self::Io {
350        self.io
351    }
352}
353
354impl<IO> AsyncRead for TlsStream<IO>
355where
356    IO: AsyncRead + AsyncWrite + Unpin,
357{
358    fn poll_read(
359        mut self: Pin<&mut Self>,
360        cx: &mut Context<'_>,
361        buf: &mut ReadBuf<'_>,
362    ) -> Poll<io::Result<()>> {
363        let data = ready!(self.as_mut().poll_fill_buf(cx))?;
364        let len = data.len().min(buf.remaining());
365        buf.put_slice(&data[..len]);
366        self.consume(len);
367        Poll::Ready(Ok(()))
368    }
369}
370
371impl<IO> AsyncBufRead for TlsStream<IO>
372where
373    IO: AsyncRead + AsyncWrite + Unpin,
374{
375    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
376        match self.state {
377            TlsState::Stream | TlsState::WriteShutdown => {
378                let this = self.get_mut();
379                let stream =
380                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
381
382                match stream.poll_fill_buf(cx) {
383                    Poll::Ready(Ok(buf)) => {
384                        if buf.is_empty() {
385                            this.state.shutdown_read();
386                        }
387
388                        Poll::Ready(Ok(buf))
389                    }
390                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
391                        this.state.shutdown_read();
392                        Poll::Ready(Err(err))
393                    }
394                    output => output,
395                }
396            }
397            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
398            #[cfg(feature = "early-data")]
399            ref s => unreachable!("server TLS can not hit this state: {:?}", s),
400        }
401    }
402
403    fn consume(mut self: Pin<&mut Self>, amt: usize) {
404        self.session.reader().consume(amt);
405    }
406}
407
408impl<IO> AsyncWrite for TlsStream<IO>
409where
410    IO: AsyncRead + AsyncWrite + Unpin,
411{
412    /// Note: that it does not guarantee the final data to be sent.
413    /// To be cautious, you must manually call `flush`.
414    fn poll_write(
415        self: Pin<&mut Self>,
416        cx: &mut Context<'_>,
417        buf: &[u8],
418    ) -> Poll<io::Result<usize>> {
419        let this = self.get_mut();
420        let mut stream =
421            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
422        stream.as_mut_pin().poll_write(cx, buf)
423    }
424
425    /// Note: that it does not guarantee the final data to be sent.
426    /// To be cautious, you must manually call `flush`.
427    fn poll_write_vectored(
428        self: Pin<&mut Self>,
429        cx: &mut Context<'_>,
430        bufs: &[io::IoSlice<'_>],
431    ) -> Poll<io::Result<usize>> {
432        let this = self.get_mut();
433        let mut stream =
434            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
435        stream.as_mut_pin().poll_write_vectored(cx, bufs)
436    }
437
438    #[inline]
439    fn is_write_vectored(&self) -> bool {
440        true
441    }
442
443    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
444        let this = self.get_mut();
445        let mut stream =
446            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
447        stream.as_mut_pin().poll_flush(cx)
448    }
449
450    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
451        if self.state.writeable() {
452            self.session.send_close_notify();
453            self.state.shutdown_write();
454        }
455
456        let this = self.get_mut();
457        let mut stream =
458            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
459        stream.as_mut_pin().poll_shutdown(cx)
460    }
461}
462
463#[cfg(unix)]
464impl<IO> AsRawFd for TlsStream<IO>
465where
466    IO: AsRawFd,
467{
468    fn as_raw_fd(&self) -> RawFd {
469        self.get_ref().0.as_raw_fd()
470    }
471}
472
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Delegates to the wrapped transport's raw socket.
    fn as_raw_socket(&self) -> RawSocket {
        let transport = &self.io;
        transport.as_raw_socket()
    }
}