async_compression/generic/write/
buf_writer.rs1use super::AsyncBufWrite;
6use compression_core::util::WriteBuffer;
7use std::{
8 fmt, io,
9 pin::Pin,
10 task::{ready, Context, Poll},
11};
12
/// Capacity, in bytes, of a buffer created with `BufWriter::new` (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
14
/// Runtime-agnostic write buffer used by the `BufWriter` adapters that
/// `impl_buf_writer!` generates.
///
/// The methods on this type are written to keep `written <= buffered <= buf.len()`:
/// `buf[..written]` has already been handed to the underlying writer,
/// `buf[written..buffered]` is still pending, and `buf[buffered..]` is free space.
pub struct BufWriter {
    /// Number of leading bytes of `buf` that hold data (flushed or still pending).
    buffered: usize,
    /// Number of leading bytes of `buf` already accepted by the underlying writer.
    written: usize,
    /// Backing storage; its length is the buffer's capacity.
    buf: Box<[u8]>,
}
20
21impl fmt::Debug for BufWriter {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 f.debug_struct("GenericBufWriter")
24 .field(
25 "buffer",
26 &format_args!("{}/{}", self.buffered, self.buf.len()),
27 )
28 .field("written", &self.written)
29 .finish()
30 }
31}
32
33impl BufWriter {
34 pub fn new() -> Self {
37 Self::with_capacity(DEFAULT_BUF_SIZE)
38 }
39
40 pub fn with_capacity(cap: usize) -> Self {
42 Self {
43 buf: vec![0; cap].into(),
44 written: 0,
45 buffered: 0,
46 }
47 }
48
49 fn remove_written(&mut self) {
51 self.buf.copy_within(self.written..self.buffered, 0);
52 self.buffered -= self.written;
53 self.written = 0;
54 }
55
56 fn do_flush(
57 &mut self,
58 poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
59 ) -> Poll<io::Result<()>> {
60 while self.written < self.buffered {
61 let bytes_written = ready!(poll_write(&self.buf[self.written..self.buffered]))?;
62 if bytes_written == 0 {
63 return Poll::Ready(Err(io::Error::new(
64 io::ErrorKind::WriteZero,
65 "failed to write the buffered data",
66 )));
67 }
68
69 self.written += bytes_written;
70 }
71
72 Poll::Ready(Ok(()))
73 }
74
75 fn partial_flush_buf(
76 &mut self,
77 poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
78 ) -> Poll<io::Result<()>> {
79 let ret = if let Poll::Ready(res) = self.do_flush(poll_write) {
80 res
81 } else {
82 Ok(())
83 };
84
85 if self.written > 0 || self.buffered < self.buf.len() {
86 Poll::Ready(ret)
87 } else {
88 ret?;
89 Poll::Pending
90 }
91 }
92
93 pub fn flush_buf(
94 &mut self,
95 poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
96 ) -> Poll<io::Result<()>> {
97 ready!(self.do_flush(poll_write))?;
98
99 debug_assert_eq!(self.buffered, self.written);
100 self.buffered = 0;
101 self.written = 0;
102
103 Poll::Ready(Ok(()))
104 }
105
106 pub fn poll_write(
107 &mut self,
108 buf: &[u8],
109 poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
110 ) -> Poll<io::Result<usize>> {
111 if buf.len() >= self.buf.len() {
112 ready!(self.flush_buf(poll_write))?;
113 poll_write(buf)
114 } else if (self.buf.len() - self.buffered) >= buf.len() {
115 self.buf[self.buffered..].copy_from_slice(buf);
116 self.buffered += buf.len();
117
118 Poll::Ready(Ok(buf.len()))
119 } else {
120 ready!(self.partial_flush_buf(poll_write))?;
121 if self.written > 0 {
122 self.remove_written();
123 }
124
125 let len = buf.len().min(self.buf.len() - self.buffered);
126 self.buf[self.buffered..self.buffered + len].copy_from_slice(&buf[..len]);
127 self.buffered += len;
128
129 Poll::Ready(Ok(len))
130 }
131 }
132
133 pub fn poll_partial_flush_buf(
134 &mut self,
135 poll_write: &mut dyn FnMut(&[u8]) -> Poll<io::Result<usize>>,
136 ) -> Poll<io::Result<Buffer<'_>>> {
137 ready!(self.partial_flush_buf(poll_write))?;
138
139 if self.written >= (self.buffered / 3)
143 || self.written >= 512
144 || self.buffered == self.buf.len()
145 {
146 self.remove_written();
147 }
148
149 Poll::Ready(Ok(Buffer {
150 write_buffer: WriteBuffer::new_initialized(&mut self.buf[self.buffered..]),
151 buffered: &mut self.buffered,
152 }))
153 }
154}
155
156pub struct Buffer<'a> {
157 buffered: &'a mut usize,
158 pub write_buffer: WriteBuffer<'a>,
159}
160
161impl Drop for Buffer<'_> {
162 fn drop(&mut self) {
163 *self.buffered += self.write_buffer.written_len();
164 }
165}
166
/// Expands to a `BufWriter<W>` adapter that buffers writes to an inner `W: AsyncWrite`,
/// using `crate::generic::write::BufWriter` for the buffering logic.
///
/// `$poll_close` is the identifier of the `AsyncWrite` method that closes the writer.
/// The generated `AsyncWrite` impl defines a method with that name. It flushes the buffer
/// and then forwards to the inner writer's method of the same name.
///
/// The invoking module must already have `AsyncWrite`, `Pin`, `Context`, `Poll` and `io`
/// in scope; the expansion uses them without importing them.
macro_rules! impl_buf_writer {
    ($poll_close: tt) => {
        use crate::generic::write::{AsyncBufWrite, BufWriter as GenericBufWriter, Buffer};
        use pin_project_lite::pin_project;
        use std::task::ready;

        // `writer` is structurally pinned; the buffer state is not.
        pin_project! {
            #[derive(Debug)]
            pub struct BufWriter<W> {
                #[pin]
                writer: W,
                inner: GenericBufWriter,
            }
        }

        impl<W> BufWriter<W> {
            /// Wraps `writer` with a buffer of the default capacity.
            pub fn new(writer: W) -> Self {
                Self {
                    writer,
                    inner: GenericBufWriter::new(),
                }
            }

            /// Wraps `writer` with a buffer of `cap` bytes.
            pub fn with_capacity(cap: usize, writer: W) -> Self {
                Self {
                    writer,
                    inner: GenericBufWriter::with_capacity(cap),
                }
            }

            /// Gets a reference to the underlying writer.
            pub fn get_ref(&self) -> &W {
                &self.writer
            }

            /// Gets a mutable reference to the underlying writer.
            ///
            /// Data written through it is not ordered with respect to bytes still held
            /// in the buffer.
            pub fn get_mut(&mut self) -> &mut W {
                &mut self.writer
            }

            /// Gets a pinned mutable reference to the underlying writer.
            ///
            /// Data written through it is not ordered with respect to bytes still held
            /// in the buffer.
            pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
                self.project().writer
            }

            /// Consumes this adapter and returns the underlying writer.
            ///
            /// Any data still held in the buffer is discarded.
            pub fn into_inner(self) -> W {
                self.writer
            }
        }

        /// Adapts a pinned writer and task context into the `poll_write` callback that
        /// `GenericBufWriter`'s methods take.
        fn get_poll_write<'a, 'b, W: AsyncWrite>(
            mut writer: Pin<&'a mut W>,
            cx: &'a mut Context<'b>,
        ) -> impl for<'buf> FnMut(&'buf [u8]) -> Poll<io::Result<usize>> + use<'a, 'b, W> {
            move |buf| writer.as_mut().poll_write(cx, buf)
        }

        impl<W: AsyncWrite> BufWriter<W> {
            /// Writes all buffered bytes to the underlying writer.
            fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                let this = self.project();
                this.inner.flush_buf(&mut get_poll_write(this.writer, cx))
            }
        }

        impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
            // `project()` consumes the pin by value, so `self` needs no `mut` binding.
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                let this = self.project();
                this.inner
                    .poll_write(buf, &mut get_poll_write(this.writer, cx))
            }

            // Flush our buffer first, then ask the inner writer to flush its own.
            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                ready!(self.as_mut().flush_buf(cx))?;
                self.project().writer.poll_flush(cx)
            }

            // Buffered bytes must reach the inner writer before it is closed.
            fn $poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                ready!(self.as_mut().flush_buf(cx))?;
                self.project().writer.$poll_close(cx)
            }
        }

        impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> {
            // `project()` consumes the pin by value, so `self` needs no `mut` binding.
            fn poll_partial_flush_buf(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<Buffer<'_>>> {
                let this = self.project();
                this.inner
                    .poll_partial_flush_buf(&mut get_poll_write(this.writer, cx))
            }
        }
    };
}
275pub(crate) use impl_buf_writer;