1use std::io::{self, BufRead as _, IoSlice, Read, Write};
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use rustls::{ConnectionCommon, SideData};
7use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
8
9mod handshake;
10pub(crate) use handshake::{IoSession, MidHandshake};
11
/// Lifecycle state of a TLS stream: which halves (read / write) are still
/// open, plus an early-data phase when the `early-data` feature is enabled.
#[derive(Debug)]
pub(crate) enum TlsState {
    /// Early-data phase (feature `early-data` only). The meaning of the two
    /// fields is not visible in this module — presumably a position and a
    /// buffer of early data; verify against the code that constructs it.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both halves open.
    Stream,
    /// Read half closed; writing is still permitted.
    ReadShutdown,
    /// Write half closed; reading is still permitted.
    WriteShutdown,
    /// Both halves closed.
    FullyShutdown,
}

impl TlsState {
    /// Closes the read half.
    ///
    /// Becomes `FullyShutdown` if the write half was already closed,
    /// otherwise `ReadShutdown`.
    #[inline]
    pub(crate) fn shutdown_read(&mut self) {
        let next = if self.writeable() {
            Self::ReadShutdown
        } else {
            Self::FullyShutdown
        };
        *self = next;
    }

    /// Closes the write half.
    ///
    /// Becomes `FullyShutdown` if the read half was already closed,
    /// otherwise `WriteShutdown`.
    #[inline]
    pub(crate) fn shutdown_write(&mut self) {
        let next = if self.readable() {
            Self::WriteShutdown
        } else {
            Self::FullyShutdown
        };
        *self = next;
    }

    /// `true` unless the write half has been shut down.
    #[inline]
    pub(crate) fn writeable(&self) -> bool {
        let closed = matches!(self, Self::WriteShutdown | Self::FullyShutdown);
        !closed
    }

    /// `true` unless the read half has been shut down.
    #[inline]
    pub(crate) fn readable(&self) -> bool {
        let closed = matches!(self, Self::ReadShutdown | Self::FullyShutdown);
        !closed
    }

    /// `true` while in the `EarlyData` state.
    #[inline]
    #[cfg(feature = "early-data")]
    pub(crate) fn is_early_data(&self) -> bool {
        matches!(*self, Self::EarlyData(..))
    }

    /// Always `false`: without the `early-data` feature the state cannot exist.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub(crate) const fn is_early_data(&self) -> bool {
        false
    }
}
61
/// Borrowed pairing of an async transport (`io`) with a rustls connection
/// (`session`), plus the flags used while polling the two together.
pub(crate) struct Stream<'a, IO, C> {
    /// Underlying async transport that carries the TLS records.
    pub(crate) io: &'a mut IO,
    /// rustls connection state that encrypts/decrypts the records.
    pub(crate) session: &'a mut C,
    /// Set once a transport read returned 0 bytes; while set, no further
    /// transport reads are attempted.
    pub(crate) eof: bool,
    /// Set after TLS bytes were written to `io` and cleared once
    /// `io.poll_flush` completes, so an interrupted flush is retried.
    pub(crate) need_flush: bool,
}
68
/// Poll-based plumbing that moves TLS records between the rustls connection
/// and the async transport. Transport `Pending` is surfaced to rustls as
/// `io::ErrorKind::WouldBlock` by the sync adapters and mapped back to
/// `Poll::Pending` here.
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Wraps `io` and `session` with `eof` and `need_flush` both cleared.
    pub(crate) fn new(io: &'a mut IO, session: &'a mut C) -> Self {
        Stream {
            io,
            session,
            eof: false,
            need_flush: false,
        }
    }

    /// Builder-style setter for `eof`; consumes and returns `self`.
    pub(crate) fn set_eof(mut self, eof: bool) -> Self {
        self.eof = eof;
        self
    }

    /// Builder-style setter for `need_flush`; consumes and returns `self`.
    pub(crate) fn set_need_flush(mut self, need_flush: bool) -> Self {
        self.need_flush = need_flush;
        self
    }

    /// Pins `self` in place. `Pin::new` requires `Self: Unpin`, which holds
    /// because every field is a reference or a `bool`.
    pub(crate) fn as_mut_pin(&mut self) -> Pin<&mut Self> {
        Pin::new(self)
    }

    /// Reads TLS bytes from the transport into rustls and processes them.
    ///
    /// Returns the number of bytes read from the transport; callers treat
    /// `Ok(0)` as transport EOF. Returns `Pending` if the transport is not
    /// ready. A rustls processing error is reported as `InvalidData`.
    pub(crate) fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut reader = SyncReadAdapter { io: self.io, cx };

        let n = match self.session.read_tls(&mut reader) {
            Ok(n) => n,
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
            Err(err) => return Poll::Ready(Err(err)),
        };

        self.session.process_new_packets().map_err(|err| {
            // Best-effort attempt to send whatever rustls queued in response
            // to the failure (e.g. an alert). Its result is deliberately
            // ignored so the processing error below is the one reported.
            let _ = self.write_io(cx);

            io::Error::new(io::ErrorKind::InvalidData, err)
        })?;

        Poll::Ready(Ok(n))
    }

    /// Writes pending TLS records from rustls to the transport.
    ///
    /// Returns the number of bytes the transport accepted, or `Pending` if
    /// the transport is not ready.
    pub(crate) fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut writer = SyncWriteAdapter { io: self.io, cx };

        match self.session.write_tls(&mut writer) {
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }

    /// Drives the TLS handshake as far as the transport currently allows.
    ///
    /// Each round:
    /// - writes every pending record, failing with `WriteZero` if the
    ///   transport accepts 0 bytes;
    /// - flushes the transport while `need_flush` is set;
    /// - reads while rustls wants input and EOF has not been seen.
    ///
    /// Returns:
    /// - `Ok((rdlen, wrlen))`, the transport bytes read and written, once
    ///   the handshake completes;
    /// - `Ok((rdlen, wrlen))` earlier, when the transport blocked after some
    ///   progress. `is_handshaking()` may then still be true.
    /// - `Pending` when the transport blocked with no progress;
    /// - `UnexpectedEof` when the transport reached EOF mid-handshake.
    pub(crate) fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
        let mut wrlen = 0;
        let mut rdlen = 0;

        loop {
            let mut write_would_block = false;
            let mut read_would_block = false;

            while self.session.wants_write() {
                match self.write_io(cx) {
                    Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
                    Poll::Ready(Ok(n)) => {
                        wrlen += n;
                        self.need_flush = true;
                    }
                    Poll::Pending => {
                        write_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            // `need_flush` lives on `self`, so a flush interrupted by
            // `Pending` is retried on the next call.
            if self.need_flush {
                match Pin::new(&mut self.io).poll_flush(cx) {
                    Poll::Ready(Ok(())) => self.need_flush = false,
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => write_would_block = true,
                }
            }

            while !self.eof && self.session.wants_read() {
                match self.read_io(cx) {
                    Poll::Ready(Ok(0)) => self.eof = true,
                    Poll::Ready(Ok(n)) => rdlen += n,
                    Poll::Pending => {
                        read_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (self.eof, self.session.is_handshaking()) {
                (true, true) => {
                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
                    Poll::Ready(Err(err))
                }
                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
                // Report partial progress rather than `Pending` so the caller
                // observes it; `Pending` only when nothing moved at all.
                (_, true) if write_would_block || read_would_block => {
                    if rdlen != 0 || wrlen != 0 {
                        Poll::Ready(Ok((rdlen, wrlen)))
                    } else {
                        Poll::Pending
                    }
                }
                // Neither side blocked and the handshake is still running:
                // go around again.
                (..) => continue,
            };
        }
    }

    /// Reads TLS records, then returns the first chunk of decrypted plaintext.
    ///
    /// Reading stops when rustls no longer wants input, the transport
    /// blocks, or it returns 0 bytes. The returned bytes are NOT consumed;
    /// the caller must consume them from `session.reader()`. Takes `self` by
    /// value so the returned slice can borrow the session for `'a`.
    pub(crate) fn poll_fill_buf(mut self, cx: &mut Context<'_>) -> Poll<io::Result<&'a [u8]>>
    where
        SD: 'a,
    {
        let mut io_pending = false;

        // Read as many TLS records as rustls asks for.
        while !self.eof && self.session.wants_read() {
            match self.read_io(cx) {
                Poll::Ready(Ok(0)) => {
                    break;
                }
                Poll::Ready(Ok(_)) => (),
                Poll::Pending => {
                    io_pending = true;
                    break;
                }
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            }
        }

        match self.session.reader().into_first_chunk() {
            Ok(buf) => {
                // May be empty; whether that means a clean EOF is decided
                // by rustls' reader, not here.
                Poll::Ready(Ok(buf))
            }
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                if !io_pending {
                    // The transport did not return `Pending`, so it may not
                    // have scheduled a wakeup. Wake ourselves so the task is
                    // polled again instead of stalling.
                    cx.waker().wake_by_ref();
                }

                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}
234
235impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
236where
237 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
238 SD: SideData + 'a,
239{
240 fn poll_read(
241 mut self: Pin<&mut Self>,
242 cx: &mut Context<'_>,
243 buf: &mut ReadBuf<'_>,
244 ) -> Poll<io::Result<()>> {
245 let data = ready!(self.as_mut().poll_fill_buf(cx))?;
246 let amount = buf.remaining().min(data.len());
247 buf.put_slice(&data[..amount]);
248 self.session.reader().consume(amount);
249 Poll::Ready(Ok(()))
250 }
251}
252
253impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncBufRead for Stream<'a, IO, C>
254where
255 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
256 SD: SideData + 'a,
257{
258 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
259 let this = self.get_mut();
260 Stream {
261 io: this.io,
263 session: this.session,
264 ..*this
265 }
266 .poll_fill_buf(cx)
267 }
268
269 fn consume(mut self: Pin<&mut Self>, amt: usize) {
270 self.session.reader().consume(amt);
271 }
272}
273
274impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'_, IO, C>
275where
276 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
277 SD: SideData,
278{
279 fn poll_write(
280 mut self: Pin<&mut Self>,
281 cx: &mut Context,
282 buf: &[u8],
283 ) -> Poll<io::Result<usize>> {
284 let mut pos = 0;
285
286 while pos != buf.len() {
287 let mut would_block = false;
288
289 match self.session.writer().write(&buf[pos..]) {
290 Ok(n) => pos += n,
291 Err(err) => return Poll::Ready(Err(err)),
292 };
293
294 while self.session.wants_write() {
295 match self.write_io(cx) {
296 Poll::Ready(Ok(0)) | Poll::Pending => {
297 would_block = true;
298 break;
299 }
300 Poll::Ready(Ok(_)) => (),
301 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
302 }
303 }
304
305 return match (pos, would_block) {
306 (0, true) => Poll::Pending,
307 (n, true) => Poll::Ready(Ok(n)),
308 (_, false) => continue,
309 };
310 }
311
312 Poll::Ready(Ok(pos))
313 }
314
315 fn poll_write_vectored(
316 mut self: Pin<&mut Self>,
317 cx: &mut Context<'_>,
318 bufs: &[IoSlice<'_>],
319 ) -> Poll<io::Result<usize>> {
320 if bufs.iter().all(|buf| buf.is_empty()) {
321 return Poll::Ready(Ok(0));
322 }
323
324 loop {
325 let mut would_block = false;
326 let written = self.session.writer().write_vectored(bufs)?;
327
328 while self.session.wants_write() {
329 match self.write_io(cx) {
330 Poll::Ready(Ok(0)) | Poll::Pending => {
331 would_block = true;
332 break;
333 }
334 Poll::Ready(Ok(_)) => (),
335 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
336 }
337 }
338
339 return match (written, would_block) {
340 (0, true) => Poll::Pending,
341 (0, false) => continue,
342 (n, _) => Poll::Ready(Ok(n)),
343 };
344 }
345 }
346
347 #[inline]
348 fn is_write_vectored(&self) -> bool {
349 true
350 }
351
352 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
353 self.session.writer().flush()?;
354 while self.session.wants_write() {
355 if ready!(self.write_io(cx))? == 0 {
356 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
357 }
358 }
359 Pin::new(&mut self.io).poll_flush(cx)
360 }
361
362 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
363 while self.session.wants_write() {
364 if ready!(self.write_io(cx))? == 0 {
365 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
366 }
367 }
368
369 Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
370 Ok(()) => Ok(()),
371 Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
373 Err(err) => Err(err),
374 })
375 }
376}
377
378pub(crate) struct SyncReadAdapter<'a, 'b, T> {
383 pub(crate) io: &'a mut T,
384 pub(crate) cx: &'a mut Context<'b>,
385}
386
387impl<T: AsyncRead + Unpin> Read for SyncReadAdapter<'_, '_, T> {
388 #[inline]
389 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
390 let mut buf = ReadBuf::new(buf);
391 match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
392 Poll::Ready(Ok(())) => Ok(buf.filled().len()),
393 Poll::Ready(Err(err)) => Err(err),
394 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
395 }
396 }
397}
398
399pub(crate) struct SyncWriteAdapter<'a, 'b, T> {
404 pub(crate) io: &'a mut T,
405 pub(crate) cx: &'a mut Context<'b>,
406}
407
408impl<T: Unpin> SyncWriteAdapter<'_, '_, T> {
409 #[inline]
410 fn poll_with<U>(
411 &mut self,
412 f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
413 ) -> io::Result<U> {
414 match f(Pin::new(self.io), self.cx) {
415 Poll::Ready(result) => result,
416 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
417 }
418 }
419}
420
421impl<T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'_, '_, T> {
422 #[inline]
423 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
424 self.poll_with(|io, cx| io.poll_write(cx, buf))
425 }
426
427 #[inline]
428 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
429 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
430 }
431
432 fn flush(&mut self) -> io::Result<()> {
433 self.poll_with(|io, cx| io.poll_flush(cx))
434 }
435}
436
437#[cfg(test)]
438mod test_stream;