rustls/
stream.rs

1use core::ops::{Deref, DerefMut};
2use std::io::{BufRead, IoSlice, Read, Result, Write};
3
4use crate::conn::{ConnectionCommon, SideData};
5
/// A stream adaptor over a borrowed Connection `C` and a borrowed transport
/// `T` (for example a socket), implementing `io::Read` and `io::Write`.
///
/// All transport I/O is performed through [`ConnectionCommon::complete_io()`].
///
/// Use this to drive a rustls Connection as if it were an ordinary stream.
#[derive(Debug)]
pub struct Stream<'a, C, T>
where
    C: 'a + ?Sized,
    T: 'a + Read + Write + ?Sized,
{
    /// The TLS connection being driven.
    pub conn: &'a mut C,

    /// The transport underneath the connection, such as a socket.
    pub sock: &'a mut T,
}
20
21impl<'a, C, T, S> Stream<'a, C, T>
22where
23    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
24    T: 'a + Read + Write,
25    S: SideData,
26{
27    /// Make a new Stream using the Connection `conn` and socket-like object
28    /// `sock`.  This does not fail and does no IO.
29    pub fn new(conn: &'a mut C, sock: &'a mut T) -> Self {
30        Self { conn, sock }
31    }
32
33    /// If we're handshaking, complete all the IO for that.
34    /// If we have data to write, write it all.
35    fn complete_prior_io(&mut self) -> Result<()> {
36        if self.conn.is_handshaking() {
37            self.conn.complete_io(self.sock)?;
38        }
39
40        if self.conn.wants_write() {
41            self.conn.complete_io(self.sock)?;
42        }
43
44        Ok(())
45    }
46
47    fn prepare_read(&mut self) -> Result<()> {
48        self.complete_prior_io()?;
49
50        // We call complete_io() in a loop since a single call may read only
51        // a partial packet from the underlying transport. A full packet is
52        // needed to get more plaintext, which we must do if EOF has not been
53        // hit.
54        while self.conn.wants_read() {
55            if self.conn.complete_io(self.sock)?.0 == 0 {
56                break;
57            }
58        }
59
60        Ok(())
61    }
62
63    // Implements `BufRead::fill_buf` but with more flexible lifetimes, so StreamOwned can reuse it
64    fn fill_buf(mut self) -> Result<&'a [u8]>
65    where
66        S: 'a,
67    {
68        self.prepare_read()?;
69        self.conn.reader().into_first_chunk()
70    }
71}
72
73impl<'a, C, T, S> Read for Stream<'a, C, T>
74where
75    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
76    T: 'a + Read + Write,
77    S: SideData,
78{
79    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
80        self.prepare_read()?;
81        self.conn.reader().read(buf)
82    }
83
84    #[cfg(read_buf)]
85    fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> {
86        self.prepare_read()?;
87        self.conn.reader().read_buf(cursor)
88    }
89}
90
91impl<'a, C, T, S> BufRead for Stream<'a, C, T>
92where
93    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
94    T: 'a + Read + Write,
95    S: 'a + SideData,
96{
97    fn fill_buf(&mut self) -> Result<&[u8]> {
98        // reborrow to get an owned `Stream`
99        Stream {
100            conn: self.conn,
101            sock: self.sock,
102        }
103        .fill_buf()
104    }
105
106    fn consume(&mut self, amt: usize) {
107        self.conn.reader().consume(amt)
108    }
109}
110
111impl<'a, C, T, S> Write for Stream<'a, C, T>
112where
113    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
114    T: 'a + Read + Write,
115    S: SideData,
116{
117    fn write(&mut self, buf: &[u8]) -> Result<usize> {
118        self.complete_prior_io()?;
119
120        let len = self.conn.writer().write(buf)?;
121
122        // Try to write the underlying transport here, but don't let
123        // any errors mask the fact we've consumed `len` bytes.
124        // Callers will learn of permanent errors on the next call.
125        let _ = self.conn.complete_io(self.sock);
126
127        Ok(len)
128    }
129
130    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
131        self.complete_prior_io()?;
132
133        let len = self
134            .conn
135            .writer()
136            .write_vectored(bufs)?;
137
138        // Try to write the underlying transport here, but don't let
139        // any errors mask the fact we've consumed `len` bytes.
140        // Callers will learn of permanent errors on the next call.
141        let _ = self.conn.complete_io(self.sock);
142
143        Ok(len)
144    }
145
146    fn flush(&mut self) -> Result<()> {
147        self.complete_prior_io()?;
148
149        self.conn.writer().flush()?;
150        if self.conn.wants_write() {
151            self.conn.complete_io(self.sock)?;
152        }
153        Ok(())
154    }
155}
156
/// A stream adaptor that owns both a Connection `C` and a transport `T`
/// (for example a socket), implementing `io::Read` and `io::Write`.
///
/// All transport I/O is performed through [`ConnectionCommon::complete_io()`].
///
/// Use this to drive a rustls Connection as if it were an ordinary stream.
#[derive(Debug)]
pub struct StreamOwned<C, T>
where
    C: Sized,
    T: Read + Write + Sized,
{
    /// The owned connection.
    pub conn: C,

    /// The owned transport underneath the connection, such as a socket.
    pub sock: T,
}
171
172impl<C, T, S> StreamOwned<C, T>
173where
174    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
175    T: Read + Write,
176    S: SideData,
177{
178    /// Make a new StreamOwned taking the Connection `conn` and socket-like
179    /// object `sock`.  This does not fail and does no IO.
180    ///
181    /// This is the same as `Stream::new` except `conn` and `sock` are
182    /// moved into the StreamOwned.
183    pub fn new(conn: C, sock: T) -> Self {
184        Self { conn, sock }
185    }
186
187    /// Get a reference to the underlying socket
188    pub fn get_ref(&self) -> &T {
189        &self.sock
190    }
191
192    /// Get a mutable reference to the underlying socket
193    pub fn get_mut(&mut self) -> &mut T {
194        &mut self.sock
195    }
196
197    /// Extract the `conn` and `sock` parts from the `StreamOwned`
198    pub fn into_parts(self) -> (C, T) {
199        (self.conn, self.sock)
200    }
201}
202
203impl<'a, C, T, S> StreamOwned<C, T>
204where
205    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
206    T: Read + Write,
207    S: SideData,
208{
209    fn as_stream(&'a mut self) -> Stream<'a, C, T> {
210        Stream {
211            conn: &mut self.conn,
212            sock: &mut self.sock,
213        }
214    }
215}
216
217impl<C, T, S> Read for StreamOwned<C, T>
218where
219    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
220    T: Read + Write,
221    S: SideData,
222{
223    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
224        self.as_stream().read(buf)
225    }
226
227    #[cfg(read_buf)]
228    fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> {
229        self.as_stream().read_buf(cursor)
230    }
231}
232
233impl<C, T, S> BufRead for StreamOwned<C, T>
234where
235    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
236    T: Read + Write,
237    S: 'static + SideData,
238{
239    fn fill_buf(&mut self) -> Result<&[u8]> {
240        self.as_stream().fill_buf()
241    }
242
243    fn consume(&mut self, amt: usize) {
244        self.as_stream().consume(amt)
245    }
246}
247
248impl<C, T, S> Write for StreamOwned<C, T>
249where
250    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
251    T: Read + Write,
252    S: SideData,
253{
254    fn write(&mut self, buf: &[u8]) -> Result<usize> {
255        self.as_stream().write(buf)
256    }
257
258    fn flush(&mut self) -> Result<()> {
259        self.as_stream().flush()
260    }
261}
262
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Read, Write};
    use std::net::TcpStream;

    use super::{Stream, StreamOwned};
    use crate::client::ClientConnection;
    use crate::server::ServerConnection;

    // A `type` alias only names the type. It does not prove that the I/O trait
    // impls (and their `Deref<Target = ConnectionCommon<_>>` bounds) apply.
    // Calling these helpers forces the compiler to check that they do.

    /// Compiles only if `X` implements both `Read` and `Write`.
    fn assert_read_write<X: Read + Write>() {}

    /// Compiles only if `X` implements `BufRead`.
    fn assert_buf_read<X: BufRead>() {}

    #[test]
    fn stream_can_be_created_for_connection_and_tcpstream() {
        assert_read_write::<Stream<'_, ClientConnection, TcpStream>>();
        assert_buf_read::<Stream<'_, ClientConnection, TcpStream>>();
    }

    #[test]
    fn streamowned_can_be_created_for_client_and_tcpstream() {
        assert_read_write::<StreamOwned<ClientConnection, TcpStream>>();
        assert_buf_read::<StreamOwned<ClientConnection, TcpStream>>();
    }

    #[test]
    fn streamowned_can_be_created_for_server_and_tcpstream() {
        assert_read_write::<StreamOwned<ServerConnection, TcpStream>>();
        assert_buf_read::<StreamOwned<ServerConnection, TcpStream>>();
    }
}