async_compression/generic/write/
buf_writer.rs

// Originally sourced from `futures_util::io::buf_writer`. It is redefined locally so that the
// `AsyncBufWrite` impl can access its internals, and it has been adjusted slightly to make those
// methods more efficient.
4
5use super::AsyncBufWrite;
6use compression_core::util::WriteBuffer;
7use futures_core::ready;
8use std::{
9    fmt, io,
10    pin::Pin,
11    task::{Context, Poll},
12};
13
/// Buffer capacity, in bytes, used by `BufWriter::new` (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
15
/// Runtime-agnostic write buffer, used as the `inner` state of the `BufWriter<W>` types
/// generated by `impl_buf_writer!`.
///
/// Layout of `buf`: `buf[..written]` has already been accepted by the sink,
/// `buf[written..buffered]` is still pending, and `buf[buffered..]` is free space. The methods
/// slice with these bounds, so they keep `written <= buffered <= buf.len()`.
pub struct BufWriter {
    /// Number of leading bytes of `buf` that hold data (flushed or not yet flushed).
    buffered: usize,
    /// Number of leading buffered bytes the sink has already accepted.
    written: usize,
    /// Fixed-capacity backing storage.
    buf: Box<[u8]>,
}
21
22impl fmt::Debug for BufWriter {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_struct("GenericBufWriter")
25            .field(
26                "buffer",
27                &format_args!("{}/{}", self.buffered, self.buf.len()),
28            )
29            .field("written", &self.written)
30            .finish()
31    }
32}
33
34impl BufWriter {
35    /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
36    /// but may change in the future.
37    pub fn new() -> Self {
38        Self::with_capacity(DEFAULT_BUF_SIZE)
39    }
40
41    /// Creates a new `BufWriter` with the specified buffer capacity.
42    pub fn with_capacity(cap: usize) -> Self {
43        Self {
44            buf: vec![0; cap].into(),
45            written: 0,
46            buffered: 0,
47        }
48    }
49
50    /// Remove the already written data
51    fn remove_written(&mut self) {
52        self.buf.copy_within(self.written..self.buffered, 0);
53        self.buffered -= self.written;
54        self.written = 0;
55    }
56
57    fn do_flush(
58        &mut self,
59        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
60    ) -> Poll<io::Result<()>> {
61        while self.written < self.buffered {
62            let bytes_written = ready!(poll_write(&self.buf[self.written..self.buffered]))?;
63            if bytes_written == 0 {
64                return Poll::Ready(Err(io::Error::new(
65                    io::ErrorKind::WriteZero,
66                    "failed to write the buffered data",
67                )));
68            }
69
70            self.written += bytes_written;
71        }
72
73        Poll::Ready(Ok(()))
74    }
75
76    fn partial_flush_buf(
77        &mut self,
78        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
79    ) -> Poll<io::Result<()>> {
80        let ret = if let Poll::Ready(res) = self.do_flush(poll_write) {
81            res
82        } else {
83            Ok(())
84        };
85
86        if self.written > 0 || self.buffered < self.buf.len() {
87            Poll::Ready(ret)
88        } else {
89            ret?;
90            Poll::Pending
91        }
92    }
93
94    pub fn flush_buf(
95        &mut self,
96        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
97    ) -> Poll<io::Result<()>> {
98        let ret = ready!(self.do_flush(poll_write));
99
100        debug_assert_eq!(self.buffered, self.written);
101        self.buffered = 0;
102        self.written = 0;
103
104        Poll::Ready(ret)
105    }
106
107    pub fn poll_write(
108        &mut self,
109        buf: &[u8],
110        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
111    ) -> Poll<io::Result<usize>> {
112        if buf.len() >= self.buf.len() {
113            ready!(self.flush_buf(poll_write))?;
114            poll_write(buf)
115        } else if (self.buf.len() - self.buffered) >= buf.len() {
116            self.buf[self.buffered..].copy_from_slice(buf);
117            self.buffered += buf.len();
118
119            Poll::Ready(Ok(buf.len()))
120        } else {
121            ready!(self.partial_flush_buf(poll_write))?;
122            if self.written > 0 {
123                self.remove_written();
124            }
125
126            let len = buf.len().min(self.buf.len() - self.buffered);
127            self.buf[self.buffered..self.buffered + len].copy_from_slice(&buf[..len]);
128            self.buffered += len;
129
130            Poll::Ready(Ok(len))
131        }
132    }
133
134    pub fn poll_partial_flush_buf(
135        &mut self,
136        poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
137    ) -> Poll<io::Result<Buffer<'_>>> {
138        ready!(self.partial_flush_buf(poll_write))?;
139
140        // when the flushed data is larger than or equal to half of yet-to-be-flushed data,
141        // the copyback could use version of memcpy that do copies from the head of the buffer.
142        // Anything smaller than that, an overlap would happen that forces use of memmove.
143        if self.written >= (self.buffered / 3)
144            || self.written >= 512
145            || self.buffered == self.buf.len()
146        {
147            self.remove_written();
148        }
149
150        Poll::Ready(Ok(Buffer {
151            write_buffer: WriteBuffer::new_initialized(&mut self.buf[self.buffered..]),
152            buffered: &mut self.buffered,
153        }))
154    }
155}
156
157pub struct Buffer<'a> {
158    buffered: &'a mut usize,
159    pub write_buffer: WriteBuffer<'a>,
160}
161
162impl Drop for Buffer<'_> {
163    fn drop(&mut self) {
164        *self.buffered += self.write_buffer.written_len();
165    }
166}
167
/// Generates a concrete `BufWriter<W>` that wraps the generic `BufWriter` (imported as
/// `GenericBufWriter`) around a writer `W: AsyncWrite`, implementing both `AsyncWrite` and
/// `AsyncBufWrite` for it.
///
/// `$poll_close` is the name of the `AsyncWrite` method that shuts the writer down; it is
/// forwarded to the inner writer after flushing. NOTE(review): presumably `poll_close` or
/// `poll_shutdown` depending on the I/O trait in use — confirm at the call sites.
///
/// The expansion site must already have `AsyncWrite`, `io`, `Pin`, `Context` and `Poll` in
/// scope; the macro body uses those names without importing them.
macro_rules! impl_buf_writer {
    ($poll_close: tt) => {
        use crate::generic::write::{AsyncBufWrite, BufWriter as GenericBufWriter, Buffer};
        use futures_core::ready;
        use pin_project_lite::pin_project;

        pin_project! {
            #[derive(Debug)]
            pub struct BufWriter<W> {
                #[pin]
                writer: W,
                inner: GenericBufWriter,
            }
        }

        impl<W> BufWriter<W> {
            /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
            /// but may change in the future.
            pub fn new(writer: W) -> Self {
                Self {
                    writer,
                    inner: GenericBufWriter::new(),
                }
            }

            /// Creates a new `BufWriter` with the specified buffer capacity.
            pub fn with_capacity(cap: usize, writer: W) -> Self {
                Self {
                    writer,
                    inner: GenericBufWriter::with_capacity(cap),
                }
            }

            /// Gets a reference to the underlying writer.
            pub fn get_ref(&self) -> &W {
                &self.writer
            }

            /// Gets a mutable reference to the underlying writer.
            ///
            /// It is inadvisable to directly write to the underlying writer.
            pub fn get_mut(&mut self) -> &mut W {
                &mut self.writer
            }

            /// Gets a pinned mutable reference to the underlying writer.
            ///
            /// It is inadvisable to directly write to the underlying writer.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
                self.project().writer
            }

            /// Consumes this `BufWriter`, returning the underlying writer.
            ///
            /// Note that any leftover data in the internal buffer is lost.
            pub fn into_inner(self) -> W {
                self.writer
            }
        }

        // Adapts a pinned writer plus the task context into the `poll_write` callback shape
        // that the generic buffer's methods take.
        fn get_poll_write<'a, 'b, W: AsyncWrite>(
            mut writer: Pin<&'a mut W>,
            cx: &'a mut Context<'b>,
        ) -> impl for<'buf> FnMut(&'buf [u8]) -> Poll<io::Result<usize>> + use<'a, 'b, W> {
            move |buf| writer.as_mut().poll_write(cx, buf)
        }

        impl<W: AsyncWrite> BufWriter<W> {
            /// Writes all buffered data to the inner writer, without flushing the inner writer.
            fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let this = self.project();
                this.inner.flush_buf(&mut get_poll_write(this.writer, cx))
            }
        }

        impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
            // `self` is consumed by `project()`, so the binding does not need to be `mut`.
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let this = self.project();
                this.inner
                    .poll_write(buf, &mut get_poll_write(this.writer, cx))
            }

            // Drain the local buffer first, then flush the inner writer itself.
            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                ready!(self.as_mut().flush_buf(cx))?;
                self.project().writer.poll_flush(cx)
            }

            // Drain the local buffer first, then shut the inner writer down.
            fn $poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                ready!(self.as_mut().flush_buf(cx))?;
                self.project().writer.$poll_close(cx)
            }
        }

        impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> {
            fn poll_partial_flush_buf(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<Buffer<'_>>> {
                let this = self.project();
                this.inner
                    .poll_partial_flush_buf(&mut get_poll_write(this.writer, cx))
            }
        }
    };
}
276pub(crate) use impl_buf_writer;