// tokio_rustls/common/mod.rs

1use std::io::{self, BufRead as _, IoSlice, Read, Write};
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use rustls::{ConnectionCommon, SideData};
7use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
8
9mod handshake;
10pub(crate) use handshake::{IoSession, MidHandshake};
11
/// Lifecycle of a TLS stream, tracking which halves of the connection are
/// still open.
#[derive(Debug)]
pub(crate) enum TlsState {
    /// Early-data state, only present with the `early-data` feature.
    ///
    /// NOTE(review): the meaning of the `usize` and the byte buffer is defined
    /// by the code that constructs this variant — confirm there.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both the read and the write half are open.
    Stream,
    /// The read half has been shut down; writing is still allowed.
    ReadShutdown,
    /// The write half has been shut down; reading is still allowed.
    WriteShutdown,
    /// Both halves have been shut down.
    FullyShutdown,
}
21
22impl TlsState {
23    #[inline]
24    pub(crate) fn shutdown_read(&mut self) {
25        match *self {
26            Self::WriteShutdown | Self::FullyShutdown => *self = Self::FullyShutdown,
27            _ => *self = Self::ReadShutdown,
28        }
29    }
30
31    #[inline]
32    pub(crate) fn shutdown_write(&mut self) {
33        match *self {
34            Self::ReadShutdown | Self::FullyShutdown => *self = Self::FullyShutdown,
35            _ => *self = Self::WriteShutdown,
36        }
37    }
38
39    #[inline]
40    pub(crate) fn writeable(&self) -> bool {
41        !matches!(*self, Self::WriteShutdown | Self::FullyShutdown)
42    }
43
44    #[inline]
45    pub(crate) fn readable(&self) -> bool {
46        !matches!(*self, Self::ReadShutdown | Self::FullyShutdown)
47    }
48
49    #[inline]
50    #[cfg(feature = "early-data")]
51    pub(crate) fn is_early_data(&self) -> bool {
52        matches!(self, Self::EarlyData(..))
53    }
54
55    #[inline]
56    #[cfg(not(feature = "early-data"))]
57    pub(crate) const fn is_early_data(&self) -> bool {
58        false
59    }
60}
61
/// A short-lived view that pairs a borrowed transport with a borrowed TLS
/// session, plus two flags carried between polls.
pub(crate) struct Stream<'a, IO, C> {
    /// The underlying transport that TLS records are read from and written to.
    pub(crate) io: &'a mut IO,
    /// The TLS connection state.
    pub(crate) session: &'a mut C,
    /// Whether end-of-file has been observed on `io`.
    pub(crate) eof: bool,
    /// Whether `io` still needs flushing: a previous flush returned pending,
    /// or data was written without a subsequent flush.
    pub(crate) need_flush: bool,
}
68
69impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
70where
71    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
72    SD: SideData,
73{
74    pub(crate) fn new(io: &'a mut IO, session: &'a mut C) -> Self {
75        Stream {
76            io,
77            session,
78            // The state so far is only used to detect EOF, so either Stream
79            // or EarlyData state should both be all right.
80            eof: false,
81            // Whether a previous flush returned pending, or a write occured without a flush.
82            need_flush: false,
83        }
84    }
85
86    pub(crate) fn set_eof(mut self, eof: bool) -> Self {
87        self.eof = eof;
88        self
89    }
90
91    pub(crate) fn set_need_flush(mut self, need_flush: bool) -> Self {
92        self.need_flush = need_flush;
93        self
94    }
95
96    pub(crate) fn as_mut_pin(&mut self) -> Pin<&mut Self> {
97        Pin::new(self)
98    }
99
100    pub(crate) fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
101        let mut reader = SyncReadAdapter { io: self.io, cx };
102
103        let n = match self.session.read_tls(&mut reader) {
104            Ok(n) => n,
105            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
106            Err(err) => return Poll::Ready(Err(err)),
107        };
108
109        self.session.process_new_packets().map_err(|err| {
110            // In case we have an alert to send describing this error,
111            // try a last-gasp write -- but don't predate the primary
112            // error.
113            let _ = self.write_io(cx);
114
115            io::Error::new(io::ErrorKind::InvalidData, err)
116        })?;
117
118        Poll::Ready(Ok(n))
119    }
120
121    pub(crate) fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
122        let mut writer = SyncWriteAdapter { io: self.io, cx };
123
124        match self.session.write_tls(&mut writer) {
125            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
126            result => Poll::Ready(result),
127        }
128    }
129
130    pub(crate) fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
131        let mut wrlen = 0;
132        let mut rdlen = 0;
133
134        loop {
135            let mut write_would_block = false;
136            let mut read_would_block = false;
137
138            while self.session.wants_write() {
139                match self.write_io(cx) {
140                    Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
141                    Poll::Ready(Ok(n)) => {
142                        wrlen += n;
143                        self.need_flush = true;
144                    }
145                    Poll::Pending => {
146                        write_would_block = true;
147                        break;
148                    }
149                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
150                }
151            }
152
153            if self.need_flush {
154                match Pin::new(&mut self.io).poll_flush(cx) {
155                    Poll::Ready(Ok(())) => self.need_flush = false,
156                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
157                    Poll::Pending => write_would_block = true,
158                }
159            }
160
161            while !self.eof && self.session.wants_read() {
162                match self.read_io(cx) {
163                    Poll::Ready(Ok(0)) => self.eof = true,
164                    Poll::Ready(Ok(n)) => rdlen += n,
165                    Poll::Pending => {
166                        read_would_block = true;
167                        break;
168                    }
169                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
170                }
171            }
172
173            return match (self.eof, self.session.is_handshaking()) {
174                (true, true) => {
175                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
176                    Poll::Ready(Err(err))
177                }
178                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
179                (_, true) if write_would_block || read_would_block => {
180                    if rdlen != 0 || wrlen != 0 {
181                        Poll::Ready(Ok((rdlen, wrlen)))
182                    } else {
183                        Poll::Pending
184                    }
185                }
186                (..) => continue,
187            };
188        }
189    }
190
191    pub(crate) fn poll_fill_buf(mut self, cx: &mut Context<'_>) -> Poll<io::Result<&'a [u8]>>
192    where
193        SD: 'a,
194    {
195        let mut io_pending = false;
196
197        // read a packet
198        while !self.eof && self.session.wants_read() {
199            match self.read_io(cx) {
200                Poll::Ready(Ok(0)) => {
201                    break;
202                }
203                Poll::Ready(Ok(_)) => (),
204                Poll::Pending => {
205                    io_pending = true;
206                    break;
207                }
208                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
209            }
210        }
211
212        match self.session.reader().into_first_chunk() {
213            Ok(buf) => {
214                // Note that this could be empty (i.e. EOF) if a `CloseNotify` has been
215                // received and there is no more buffered data.
216                Poll::Ready(Ok(buf))
217            }
218            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
219                if !io_pending {
220                    // If `wants_read()` is satisfied, rustls will not return `WouldBlock`.
221                    // but if it does, we can try again.
222                    //
223                    // If the rustls state is abnormal, it may cause a cyclic wakeup.
224                    // but tokio's cooperative budget will prevent infinite wakeup.
225                    cx.waker().wake_by_ref();
226                }
227
228                Poll::Pending
229            }
230            Err(e) => Poll::Ready(Err(e)),
231        }
232    }
233}
234
235impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
236where
237    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
238    SD: SideData + 'a,
239{
240    fn poll_read(
241        mut self: Pin<&mut Self>,
242        cx: &mut Context<'_>,
243        buf: &mut ReadBuf<'_>,
244    ) -> Poll<io::Result<()>> {
245        let data = ready!(self.as_mut().poll_fill_buf(cx))?;
246        let amount = buf.remaining().min(data.len());
247        buf.put_slice(&data[..amount]);
248        self.session.reader().consume(amount);
249        Poll::Ready(Ok(()))
250    }
251}
252
253impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncBufRead for Stream<'a, IO, C>
254where
255    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
256    SD: SideData + 'a,
257{
258    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
259        let this = self.get_mut();
260        Stream {
261            // reborrow
262            io: this.io,
263            session: this.session,
264            ..*this
265        }
266        .poll_fill_buf(cx)
267    }
268
269    fn consume(mut self: Pin<&mut Self>, amt: usize) {
270        self.session.reader().consume(amt);
271    }
272}
273
274impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'_, IO, C>
275where
276    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
277    SD: SideData,
278{
279    fn poll_write(
280        mut self: Pin<&mut Self>,
281        cx: &mut Context,
282        buf: &[u8],
283    ) -> Poll<io::Result<usize>> {
284        let mut pos = 0;
285
286        while pos != buf.len() {
287            let mut would_block = false;
288
289            match self.session.writer().write(&buf[pos..]) {
290                Ok(n) => pos += n,
291                Err(err) => return Poll::Ready(Err(err)),
292            };
293
294            while self.session.wants_write() {
295                match self.write_io(cx) {
296                    Poll::Ready(Ok(0)) | Poll::Pending => {
297                        would_block = true;
298                        break;
299                    }
300                    Poll::Ready(Ok(_)) => (),
301                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
302                }
303            }
304
305            return match (pos, would_block) {
306                (0, true) => Poll::Pending,
307                (n, true) => Poll::Ready(Ok(n)),
308                (_, false) => continue,
309            };
310        }
311
312        Poll::Ready(Ok(pos))
313    }
314
315    fn poll_write_vectored(
316        mut self: Pin<&mut Self>,
317        cx: &mut Context<'_>,
318        bufs: &[IoSlice<'_>],
319    ) -> Poll<io::Result<usize>> {
320        if bufs.iter().all(|buf| buf.is_empty()) {
321            return Poll::Ready(Ok(0));
322        }
323
324        loop {
325            let mut would_block = false;
326            let written = self.session.writer().write_vectored(bufs)?;
327
328            while self.session.wants_write() {
329                match self.write_io(cx) {
330                    Poll::Ready(Ok(0)) | Poll::Pending => {
331                        would_block = true;
332                        break;
333                    }
334                    Poll::Ready(Ok(_)) => (),
335                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
336                }
337            }
338
339            return match (written, would_block) {
340                (0, true) => Poll::Pending,
341                (0, false) => continue,
342                (n, _) => Poll::Ready(Ok(n)),
343            };
344        }
345    }
346
347    #[inline]
348    fn is_write_vectored(&self) -> bool {
349        true
350    }
351
352    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
353        self.session.writer().flush()?;
354        while self.session.wants_write() {
355            if ready!(self.write_io(cx))? == 0 {
356                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
357            }
358        }
359        Pin::new(&mut self.io).poll_flush(cx)
360    }
361
362    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
363        while self.session.wants_write() {
364            if ready!(self.write_io(cx))? == 0 {
365                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
366            }
367        }
368
369        Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
370            Ok(()) => Ok(()),
371            // When trying to shutdown, not being connected seems fine
372            Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
373            Err(err) => Err(err),
374        })
375    }
376}
377
/// Bridges an [`AsyncRead`] transport to the synchronous [`Read`] trait for
/// the duration of one poll, using a borrowed task [`Context`].
///
/// When the transport returns `Poll::Pending`, the read fails with an
/// [`io::ErrorKind::WouldBlock`] error.
pub(crate) struct SyncReadAdapter<'a, 'b, T> {
    /// The transport being read from.
    pub(crate) io: &'a mut T,
    /// The task context passed to every poll of `io`.
    pub(crate) cx: &'a mut Context<'b>,
}
386
387impl<T: AsyncRead + Unpin> Read for SyncReadAdapter<'_, '_, T> {
388    #[inline]
389    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
390        let mut buf = ReadBuf::new(buf);
391        match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
392            Poll::Ready(Ok(())) => Ok(buf.filled().len()),
393            Poll::Ready(Err(err)) => Err(err),
394            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
395        }
396    }
397}
398
/// Bridges an [`AsyncWrite`] transport to the synchronous [`Write`] trait for
/// the duration of one poll, using a borrowed task [`Context`].
///
/// When the transport returns `Poll::Pending`, the operation fails with an
/// [`io::ErrorKind::WouldBlock`] error.
pub(crate) struct SyncWriteAdapter<'a, 'b, T> {
    /// The transport being written to.
    pub(crate) io: &'a mut T,
    /// The task context passed to every poll of `io`.
    pub(crate) cx: &'a mut Context<'b>,
}
407
408impl<T: Unpin> SyncWriteAdapter<'_, '_, T> {
409    #[inline]
410    fn poll_with<U>(
411        &mut self,
412        f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
413    ) -> io::Result<U> {
414        match f(Pin::new(self.io), self.cx) {
415            Poll::Ready(result) => result,
416            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
417        }
418    }
419}
420
421impl<T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'_, '_, T> {
422    #[inline]
423    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
424        self.poll_with(|io, cx| io.poll_write(cx, buf))
425    }
426
427    #[inline]
428    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
429        self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
430    }
431
432    fn flush(&mut self) -> io::Result<()> {
433        self.poll_with(|io, cx| io.poll_flush(cx))
434    }
435}
436
437#[cfg(test)]
438mod test_stream;