1use std::future::Future;
2use std::io::{self, BufRead as _};
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, RawFd};
5#[cfg(windows)]
6use std::os::windows::io::{AsRawSocket, RawSocket};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use rustls::server::AcceptedAlert;
12use rustls::{ServerConfig, ServerConnection};
13use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
14
15use crate::common::{IoSession, MidHandshake, Stream, SyncReadAdapter, SyncWriteAdapter, TlsState};
16
17#[derive(Clone)]
19pub struct TlsAcceptor {
20 inner: Arc<ServerConfig>,
21}
22
23impl From<Arc<ServerConfig>> for TlsAcceptor {
24 fn from(inner: Arc<ServerConfig>) -> Self {
25 Self { inner }
26 }
27}
28
29impl TlsAcceptor {
30 #[inline]
31 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
32 where
33 IO: AsyncRead + AsyncWrite + Unpin,
34 {
35 self.accept_with(stream, |_| ())
36 }
37
38 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
39 where
40 IO: AsyncRead + AsyncWrite + Unpin,
41 F: FnOnce(&mut ServerConnection),
42 {
43 let mut session = match ServerConnection::new(self.inner.clone()) {
44 Ok(session) => session,
45 Err(error) => {
46 return Accept(MidHandshake::Error {
47 io: stream,
48 error: io::Error::new(io::ErrorKind::Other, error),
51 });
52 }
53 };
54 f(&mut session);
55
56 Accept(MidHandshake::Handshaking(TlsStream {
57 session,
58 io: stream,
59 state: TlsState::Stream,
60 need_flush: false,
61 }))
62 }
63
64 pub fn config(&self) -> &Arc<ServerConfig> {
66 &self.inner
67 }
68}
69
70pub struct LazyConfigAcceptor<IO> {
71 acceptor: rustls::server::Acceptor,
72 io: Option<IO>,
73 alert: Option<(rustls::Error, AcceptedAlert)>,
74}
75
76impl<IO> LazyConfigAcceptor<IO>
77where
78 IO: AsyncRead + AsyncWrite + Unpin,
79{
80 #[inline]
81 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
82 Self {
83 acceptor,
84 io: Some(io),
85 alert: None,
86 }
87 }
88
89 pub fn take_io(&mut self) -> Option<IO> {
131 self.io.take()
132 }
133}
134
135impl<IO> Future for LazyConfigAcceptor<IO>
136where
137 IO: AsyncRead + AsyncWrite + Unpin,
138{
139 type Output = Result<StartHandshake<IO>, io::Error>;
140
141 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142 let this = self.get_mut();
143 loop {
144 let io = match this.io.as_mut() {
145 Some(io) => io,
146 None => {
147 return Poll::Ready(Err(io::Error::new(
148 io::ErrorKind::Other,
149 "acceptor cannot be polled after acceptance",
150 )))
151 }
152 };
153
154 if let Some((err, mut alert)) = this.alert.take() {
155 match alert.write(&mut SyncWriteAdapter { io, cx }) {
156 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
157 this.alert = Some((err, alert));
158 return Poll::Pending;
159 }
160 Ok(0) | Err(_) => {
161 return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err)))
162 }
163 Ok(_) => {
164 this.alert = Some((err, alert));
165 continue;
166 }
167 };
168 }
169
170 let mut reader = SyncReadAdapter { io, cx };
171 match this.acceptor.read_tls(&mut reader) {
172 Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
173 Ok(_) => {}
174 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
175 Err(e) => return Err(e).into(),
176 }
177
178 match this.acceptor.accept() {
179 Ok(Some(accepted)) => {
180 let io = this.io.take().unwrap();
181 return Poll::Ready(Ok(StartHandshake { accepted, io }));
182 }
183 Ok(None) => {}
184 Err((err, alert)) => {
185 this.alert = Some((err, alert));
186 }
187 }
188 }
189 }
190}
191
192#[non_exhaustive]
199#[derive(Debug)]
200pub struct StartHandshake<IO> {
201 pub accepted: rustls::server::Accepted,
202 pub io: IO,
203}
204
205impl<IO> StartHandshake<IO>
206where
207 IO: AsyncRead + AsyncWrite + Unpin,
208{
209 pub fn from_parts(accepted: rustls::server::Accepted, transport: IO) -> Self {
211 Self {
212 accepted,
213 io: transport,
214 }
215 }
216
217 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
218 self.accepted.client_hello()
219 }
220
221 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
222 self.into_stream_with(config, |_| ())
223 }
224
225 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
226 where
227 F: FnOnce(&mut ServerConnection),
228 {
229 let mut conn = match self.accepted.into_connection(config) {
230 Ok(conn) => conn,
231 Err((error, alert)) => {
232 return Accept(MidHandshake::SendAlert {
233 io: self.io,
234 alert,
235 error: io::Error::new(io::ErrorKind::InvalidData, error),
238 });
239 }
240 };
241 f(&mut conn);
242
243 Accept(MidHandshake::Handshaking(TlsStream {
244 session: conn,
245 io: self.io,
246 state: TlsState::Stream,
247 need_flush: false,
248 }))
249 }
250}
251
252pub struct Accept<IO>(MidHandshake<TlsStream<IO>>);
255
256impl<IO> Accept<IO> {
257 #[inline]
258 pub fn into_fallible(self) -> FallibleAccept<IO> {
259 FallibleAccept(self.0)
260 }
261
262 pub fn get_ref(&self) -> Option<&IO> {
263 match &self.0 {
264 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
265 MidHandshake::SendAlert { io, .. } => Some(io),
266 MidHandshake::Error { io, .. } => Some(io),
267 MidHandshake::End => None,
268 }
269 }
270
271 pub fn get_mut(&mut self) -> Option<&mut IO> {
272 match &mut self.0 {
273 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
274 MidHandshake::SendAlert { io, .. } => Some(io),
275 MidHandshake::Error { io, .. } => Some(io),
276 MidHandshake::End => None,
277 }
278 }
279}
280
281impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
282 type Output = io::Result<TlsStream<IO>>;
283
284 #[inline]
285 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
286 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
287 }
288}
289
290pub struct FallibleAccept<IO>(MidHandshake<TlsStream<IO>>);
292
293impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
294 type Output = Result<TlsStream<IO>, (io::Error, IO)>;
295
296 #[inline]
297 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
298 Pin::new(&mut self.0).poll(cx)
299 }
300}
301
302#[derive(Debug)]
305pub struct TlsStream<IO> {
306 pub(crate) io: IO,
307 pub(crate) session: ServerConnection,
308 pub(crate) state: TlsState,
309 pub(crate) need_flush: bool,
310}
311
312impl<IO> TlsStream<IO> {
313 #[inline]
314 pub fn get_ref(&self) -> (&IO, &ServerConnection) {
315 (&self.io, &self.session)
316 }
317
318 #[inline]
319 pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
320 (&mut self.io, &mut self.session)
321 }
322
323 #[inline]
324 pub fn into_inner(self) -> (IO, ServerConnection) {
325 (self.io, self.session)
326 }
327}
328
329impl<IO> IoSession for TlsStream<IO> {
330 type Io = IO;
331 type Session = ServerConnection;
332
333 #[inline]
334 fn skip_handshake(&self) -> bool {
335 false
336 }
337
338 #[inline]
339 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session, &mut bool) {
340 (
341 &mut self.state,
342 &mut self.io,
343 &mut self.session,
344 &mut self.need_flush,
345 )
346 }
347
348 #[inline]
349 fn into_io(self) -> Self::Io {
350 self.io
351 }
352}
353
354impl<IO> AsyncRead for TlsStream<IO>
355where
356 IO: AsyncRead + AsyncWrite + Unpin,
357{
358 fn poll_read(
359 mut self: Pin<&mut Self>,
360 cx: &mut Context<'_>,
361 buf: &mut ReadBuf<'_>,
362 ) -> Poll<io::Result<()>> {
363 let data = ready!(self.as_mut().poll_fill_buf(cx))?;
364 let len = data.len().min(buf.remaining());
365 buf.put_slice(&data[..len]);
366 self.consume(len);
367 Poll::Ready(Ok(()))
368 }
369}
370
371impl<IO> AsyncBufRead for TlsStream<IO>
372where
373 IO: AsyncRead + AsyncWrite + Unpin,
374{
375 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
376 match self.state {
377 TlsState::Stream | TlsState::WriteShutdown => {
378 let this = self.get_mut();
379 let stream =
380 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
381
382 match stream.poll_fill_buf(cx) {
383 Poll::Ready(Ok(buf)) => {
384 if buf.is_empty() {
385 this.state.shutdown_read();
386 }
387
388 Poll::Ready(Ok(buf))
389 }
390 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
391 this.state.shutdown_read();
392 Poll::Ready(Err(err))
393 }
394 output => output,
395 }
396 }
397 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])),
398 #[cfg(feature = "early-data")]
399 ref s => unreachable!("server TLS can not hit this state: {:?}", s),
400 }
401 }
402
403 fn consume(mut self: Pin<&mut Self>, amt: usize) {
404 self.session.reader().consume(amt);
405 }
406}
407
408impl<IO> AsyncWrite for TlsStream<IO>
409where
410 IO: AsyncRead + AsyncWrite + Unpin,
411{
412 fn poll_write(
415 self: Pin<&mut Self>,
416 cx: &mut Context<'_>,
417 buf: &[u8],
418 ) -> Poll<io::Result<usize>> {
419 let this = self.get_mut();
420 let mut stream =
421 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
422 stream.as_mut_pin().poll_write(cx, buf)
423 }
424
425 fn poll_write_vectored(
428 self: Pin<&mut Self>,
429 cx: &mut Context<'_>,
430 bufs: &[io::IoSlice<'_>],
431 ) -> Poll<io::Result<usize>> {
432 let this = self.get_mut();
433 let mut stream =
434 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
435 stream.as_mut_pin().poll_write_vectored(cx, bufs)
436 }
437
438 #[inline]
439 fn is_write_vectored(&self) -> bool {
440 true
441 }
442
443 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
444 let this = self.get_mut();
445 let mut stream =
446 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
447 stream.as_mut_pin().poll_flush(cx)
448 }
449
450 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
451 if self.state.writeable() {
452 self.session.send_close_notify();
453 self.state.shutdown_write();
454 }
455
456 let this = self.get_mut();
457 let mut stream =
458 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
459 stream.as_mut_pin().poll_shutdown(cx)
460 }
461}
462
463#[cfg(unix)]
464impl<IO> AsRawFd for TlsStream<IO>
465where
466 IO: AsRawFd,
467{
468 fn as_raw_fd(&self) -> RawFd {
469 self.get_ref().0.as_raw_fd()
470 }
471}
472
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Returns the raw socket of the underlying transport.
    fn as_raw_socket(&self) -> RawSocket {
        let (transport, _session) = self.get_ref();
        transport.as_raw_socket()
    }
}