async_compression/generic/write/
buf_writer.rs

// Originally sourced from `futures_util::io::buf_writer`. It is redefined locally so that the
// `AsyncBufWrite` impl can access its internals, and it has been adjusted to make those methods
// more efficient.
4
5use super::AsyncBufWrite;
6use compression_core::util::WriteBuffer;
7use std::{
8    fmt, io,
9    pin::Pin,
10    task::{ready, Context, Poll},
11};
12
/// Buffer capacity, in bytes, used by `BufWriter::new` (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
14
/// Runtime-agnostic write buffer; the `BufWriter<W>` wrappers generated by `impl_buf_writer!`
/// hold one of these as their `inner` field.
///
/// The methods expect `written <= buffered <= buf.len()`: `buf[..written]` has already been
/// handed to the underlying writer (and can be reclaimed), while `buf[written..buffered]` is
/// still pending.
pub struct BufWriter {
    /// Backing storage; its length is the buffer capacity.
    buf: Box<[u8]>,
    /// Count of leading bytes of `buf` already written to the underlying writer.
    written: usize,
    /// Count of leading bytes of `buf` holding data (written or still pending).
    buffered: usize,
}
20
21impl fmt::Debug for BufWriter {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        f.debug_struct("GenericBufWriter")
24            .field(
25                "buffer",
26                &format_args!("{}/{}", self.buffered, self.buf.len()),
27            )
28            .field("written", &self.written)
29            .finish()
30    }
31}
32
33impl BufWriter {
34    /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
35    /// but may change in the future.
36    pub fn new() -> Self {
37        Self::with_capacity(DEFAULT_BUF_SIZE)
38    }
39
40    /// Creates a new `BufWriter` with the specified buffer capacity.
41    pub fn with_capacity(cap: usize) -> Self {
42        Self {
43            buf: vec![0; cap].into(),
44            written: 0,
45            buffered: 0,
46        }
47    }
48
49    /// Remove the already written data
50    fn remove_written(&mut self) {
51        self.buf.copy_within(self.written..self.buffered, 0);
52        self.buffered -= self.written;
53        self.written = 0;
54    }
55
56    fn do_flush(
57        &mut self,
58        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
59    ) -> Poll<io::Result<()>> {
60        while self.written < self.buffered {
61            let bytes_written = ready!(poll_write(&self.buf[self.written..self.buffered]))?;
62            if bytes_written == 0 {
63                return Poll::Ready(Err(io::Error::new(
64                    io::ErrorKind::WriteZero,
65                    "failed to write the buffered data",
66                )));
67            }
68
69            self.written += bytes_written;
70        }
71
72        Poll::Ready(Ok(()))
73    }
74
75    fn partial_flush_buf(
76        &mut self,
77        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
78    ) -> Poll<io::Result<()>> {
79        let ret = if let Poll::Ready(res) = self.do_flush(poll_write) {
80            res
81        } else {
82            Ok(())
83        };
84
85        if self.written > 0 || self.buffered < self.buf.len() {
86            Poll::Ready(ret)
87        } else {
88            ret?;
89            Poll::Pending
90        }
91    }
92
93    pub fn flush_buf(
94        &mut self,
95        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
96    ) -> Poll<io::Result<()>> {
97        ready!(self.do_flush(poll_write))?;
98
99        debug_assert_eq!(self.buffered, self.written);
100        self.buffered = 0;
101        self.written = 0;
102
103        Poll::Ready(Ok(()))
104    }
105
106    pub fn poll_write(
107        &mut self,
108        buf: &[u8],
109        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
110    ) -> Poll<io::Result<usize>> {
111        if buf.len() >= self.buf.len() {
112            ready!(self.flush_buf(poll_write))?;
113            poll_write(buf)
114        } else if (self.buf.len() - self.buffered) >= buf.len() {
115            self.buf[self.buffered..].copy_from_slice(buf);
116            self.buffered += buf.len();
117
118            Poll::Ready(Ok(buf.len()))
119        } else {
120            ready!(self.partial_flush_buf(poll_write))?;
121            if self.written > 0 {
122                self.remove_written();
123            }
124
125            let len = buf.len().min(self.buf.len() - self.buffered);
126            self.buf[self.buffered..self.buffered + len].copy_from_slice(&buf[..len]);
127            self.buffered += len;
128
129            Poll::Ready(Ok(len))
130        }
131    }
132
133    pub fn poll_partial_flush_buf(
134        &mut self,
135        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
136    ) -> Poll<io::Result<Buffer<'_>>> {
137        ready!(self.partial_flush_buf(poll_write))?;
138
139        // when the flushed data is larger than or equal to half of yet-to-be-flushed data,
140        // the copyback could use version of memcpy that do copies from the head of the buffer.
141        // Anything smaller than that, an overlap would happen that forces use of memmove.
142        if self.written >= (self.buffered / 3)
143            || self.written >= 512
144            || self.buffered == self.buf.len()
145        {
146            self.remove_written();
147        }
148
149        Poll::Ready(Ok(Buffer {
150            write_buffer: WriteBuffer::new_initialized(&mut self.buf[self.buffered..]),
151            buffered: &mut self.buffered,
152        }))
153    }
154}
155
156pub struct Buffer<'a> {
157    buffered: &'a mut usize,
158    pub write_buffer: WriteBuffer<'a>,
159}
160
161impl Drop for Buffer<'_> {
162    fn drop(&mut self) {
163        *self.buffered += self.write_buffer.written_len();
164    }
165}
166
/// Defines a `BufWriter<W>` wrapper that pairs a writer `W` with this module's generic
/// `BufWriter` buffer, and implements `AsyncWrite` and `AsyncBufWrite` for it.
///
/// `$poll_close` is the name of the `AsyncWrite` close method to implement and forward to
/// (presumably `poll_close` or `poll_shutdown` depending on the runtime's trait — confirm at
/// the call sites). The generated code refers to `AsyncWrite`, `Context`, `Pin`, `Poll` and
/// `io` unqualified, so the invoking module must have them in scope.
macro_rules! impl_buf_writer {
    ($poll_close:ident) => {
        use crate::generic::write::{AsyncBufWrite, BufWriter as GenericBufWriter, Buffer};
        use pin_project_lite::pin_project;
        use std::task::ready;

        pin_project! {
            #[derive(Debug)]
            pub struct BufWriter<W> {
                #[pin]
                writer: W,
                inner: GenericBufWriter,
            }
        }

        impl<W> BufWriter<W> {
            /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
            /// but may change in the future.
            pub fn new(writer: W) -> Self {
                Self {
                    writer,
                    inner: GenericBufWriter::new(),
                }
            }

            /// Creates a new `BufWriter` with the specified buffer capacity.
            pub fn with_capacity(cap: usize, writer: W) -> Self {
                Self {
                    writer,
                    inner: GenericBufWriter::with_capacity(cap),
                }
            }

            /// Gets a reference to the underlying writer.
            pub fn get_ref(&self) -> &W {
                &self.writer
            }

            /// Gets a mutable reference to the underlying writer.
            ///
            /// It is inadvisable to directly write to the underlying writer.
            pub fn get_mut(&mut self) -> &mut W {
                &mut self.writer
            }

            /// Gets a pinned mutable reference to the underlying writer.
            ///
            /// It is inadvisable to directly write to the underlying writer.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
                self.project().writer
            }

            /// Consumes this `BufWriter`, returning the underlying writer.
            ///
            /// Note that any leftover data in the internal buffer is lost.
            pub fn into_inner(self) -> W {
                self.writer
            }
        }

        /// Adapts a pinned writer plus task context into the
        /// `FnMut(&[u8]) -> Poll<io::Result<usize>>` callback that the generic buffer drives.
        fn get_poll_write<'a, 'b, W: AsyncWrite>(
            mut writer: Pin<&'a mut W>,
            cx: &'a mut Context<'b>,
        ) -> impl for<'buf> FnMut(&'buf [u8]) -> Poll<io::Result<usize>> + use<'a, 'b, W> {
            move |buf| writer.as_mut().poll_write(cx, buf)
        }

        impl<W: AsyncWrite> BufWriter<W> {
            /// Writes out everything currently buffered to the underlying writer.
            fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let this = self.project();
                this.inner.flush_buf(&mut get_poll_write(this.writer, cx))
            }
        }

        impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
            // `self` is consumed by `project()` and never reborrowed, so no `mut` binding.
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let this = self.project();
                this.inner
                    .poll_write(buf, &mut get_poll_write(this.writer, cx))
            }

            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                // Drain our buffer first, then flush the underlying writer.
                ready!(self.as_mut().flush_buf(cx))?;
                self.project().writer.poll_flush(cx)
            }

            fn $poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                // Buffered bytes must reach the writer before it is closed.
                ready!(self.as_mut().flush_buf(cx))?;
                self.project().writer.$poll_close(cx)
            }
        }

        impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> {
            fn poll_partial_flush_buf(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<Buffer<'_>>> {
                let this = self.project();
                this.inner
                    .poll_partial_flush_buf(&mut get_poll_write(this.writer, cx))
            }
        }
    };
}
275pub(crate) use impl_buf_writer;