use super::AsyncBufWrite;
use futures_core::ready;
use pin_project_lite::pin_project;
use std::{
cmp::min,
fmt, io,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::AsyncWrite;
/// Capacity, in bytes, of the internal buffer allocated by `BufWriter::new`.
const DEFAULT_BUF_SIZE: usize = 8192;
pin_project! {
/// A writer wrapper that stages bytes in a fixed-size in-memory buffer before
/// handing them to the wrapped writer.
///
/// The buffer is allocated once at construction and never resized. The bytes
/// in `buf[written..buffered]` are the ones that have been accepted from the
/// caller but not yet written to `inner`.
pub struct BufWriter<W> {
// The wrapped writer. Marked `#[pin]` so the projection exposes it as
// `Pin<&mut W>`.
#[pin]
inner: W,
// Fixed-size staging area; its length is the buffer capacity.
buf: Box<[u8]>,
// Number of bytes at the front of `buf` already written to `inner` but not
// yet compacted away.
written: usize,
// Number of bytes at the front of `buf` that hold data (including the
// `written` prefix); free space starts at this offset.
buffered: usize,
}
}
impl<W: AsyncWrite> BufWriter<W> {
    /// Creates a `BufWriter` around `inner` using the default buffer capacity
    /// (`DEFAULT_BUF_SIZE` bytes).
    pub fn new(inner: W) -> Self {
        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
    }

    /// Creates a `BufWriter` around `inner` whose internal buffer holds exactly
    /// `cap` bytes. The buffer starts out empty.
    pub fn with_capacity(cap: usize, inner: W) -> Self {
        let buf = vec![0u8; cap].into_boxed_slice();
        Self {
            inner,
            buf,
            written: 0,
            buffered: 0,
        }
    }

    /// Writes as much buffered data to the inner writer as it will currently
    /// accept, then compacts the buffer so the remaining data starts at offset 0.
    ///
    /// Returns:
    /// * `Poll::Ready(Ok(()))` if the buffer was already empty or at least one
    ///   byte was written out (so space has been freed), even if the inner
    ///   writer then reported `Pending`.
    /// * `Poll::Ready(Err(_))` if the inner writer failed or accepted zero bytes
    ///   (`WriteZero`). Any bytes written before the failure are still removed
    ///   from the buffer.
    /// * `Poll::Pending` only when data is buffered and the inner writer
    ///   returned `Pending` before accepting anything.
    fn partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let mut me = self.project();

        // Push data until the buffer is drained, the inner writer stalls, or an
        // error occurs. A stall is not an error here: progress is judged below.
        let outcome = loop {
            if *me.written >= *me.buffered {
                break Ok(());
            }
            let unsent = &me.buf[*me.written..*me.buffered];
            match me.inner.as_mut().poll_write(cx, unsent) {
                Poll::Ready(Ok(0)) => {
                    break Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write the buffered data",
                    ));
                }
                Poll::Ready(Ok(n)) => *me.written += n,
                Poll::Ready(Err(e)) => break Err(e),
                Poll::Pending => break Ok(()),
            }
        };

        match (*me.written, *me.buffered) {
            // Nothing was buffered to begin with.
            (0, 0) => Poll::Ready(outcome),
            // Data is buffered but none of it went out: surface an error if
            // there was one, otherwise the inner writer is pending.
            (0, _) => {
                outcome?;
                Poll::Pending
            }
            // Some bytes went out: slide the unsent tail to the front.
            (sent, filled) => {
                me.buf.copy_within(sent..filled, 0);
                *me.buffered = filled - sent;
                *me.written = 0;
                Poll::Ready(outcome)
            }
        }
    }

    /// Writes *all* buffered data to the inner writer.
    ///
    /// Returns `Poll::Pending` as soon as the inner writer does; progress made
    /// so far is remembered in `written`, so the next call resumes where this
    /// one stopped. Once the loop ends (buffer drained, inner error, or a
    /// zero-length write reported as `WriteZero`), the buffer is compacted and
    /// the result is returned as `Poll::Ready`.
    fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let mut me = self.project();

        let outcome = loop {
            if *me.written >= *me.buffered {
                break Ok(());
            }
            let unsent = &me.buf[*me.written..*me.buffered];
            match ready!(me.inner.as_mut().poll_write(cx, unsent)) {
                Ok(0) => {
                    break Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write the buffered data",
                    ));
                }
                Ok(n) => *me.written += n,
                Err(e) => break Err(e),
            }
        };

        // Drop the bytes that were written and move any remainder (only
        // possible on error) to the front of the buffer.
        me.buf.copy_within(*me.written..*me.buffered, 0);
        *me.buffered -= *me.written;
        *me.written = 0;
        Poll::Ready(outcome)
    }
}
impl<W> BufWriter<W> {
    /// Returns a shared reference to the wrapped writer.
    pub fn get_ref(&self) -> &W {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped writer.
    pub fn get_mut(&mut self) -> &mut W {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a pinned mutable reference to the wrapped writer.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
        let projected = self.project();
        projected.inner
    }

    /// Consumes the `BufWriter` and returns the wrapped writer.
    ///
    /// The internal buffer is dropped without being flushed, so any data still
    /// held in it is lost.
    pub fn into_inner(self) -> W {
        let Self { inner, .. } = self;
        inner
    }
}
impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
    /// Accepts bytes from `buf`, either by copying them into the internal
    /// buffer or, for inputs at least as large as the buffer, by passing them
    /// straight to the inner writer.
    ///
    /// When `buf` does not fit in the remaining space, a partial flush is
    /// attempted first. Inputs smaller than the buffer capacity are then copied
    /// in as far as space allows, and the number of bytes copied is returned
    /// (possibly fewer than `buf.len()`).
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let incoming = buf.len();

        // Make room first if the new data would overflow the buffer.
        let overflows = {
            let me = self.as_mut().project();
            *me.buffered + incoming > me.buf.len()
        };
        if overflows {
            ready!(self.as_mut().partial_flush_buf(cx))?;
        }

        let me = self.as_mut().project();
        let capacity = me.buf.len();

        if incoming < capacity {
            // Small write: stage as much as fits behind the existing data.
            let start = *me.buffered;
            let len = min(capacity - start, incoming);
            me.buf[start..start + len].copy_from_slice(&buf[..len]);
            *me.buffered = start + len;
            return Poll::Ready(Ok(len));
        }

        // Large write: bypass the buffer, but only once it is empty so bytes
        // reach the inner writer in order.
        if *me.buffered == 0 {
            me.inner.poll_write(cx, buf)
        } else {
            // Data is still buffered, which can only be the case here if the
            // partial flush above ran and the inner writer returned `Pending`
            // part-way through; that call is expected to have registered the
            // waker.
            Poll::Pending
        }
    }

    /// Drains the internal buffer into the inner writer, then flushes the
    /// inner writer.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let drained = self.as_mut().flush_buf(cx);
        match drained {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => self.project().inner.poll_flush(cx),
        }
    }

    /// Drains the internal buffer into the inner writer, then shuts the inner
    /// writer down.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let drained = self.as_mut().flush_buf(cx);
        match drained {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => self.project().inner.poll_shutdown(cx),
        }
    }
}
impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> {
    /// Attempts a partial flush and, once it reports ready without error,
    /// returns the unused tail of the internal buffer for the caller to fill.
    ///
    /// Propagates `Pending` and errors from the partial flush unchanged.
    fn poll_partial_flush_buf(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<&mut [u8]>> {
        let flushed = self.as_mut().partial_flush_buf(cx);
        match flushed {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {}
        }

        let me = self.project();
        let start = *me.buffered;
        let free = &mut me.buf[start..];
        Poll::Ready(Ok(free))
    }

    /// Marks `amt` more bytes of the buffer as holding data by advancing the
    /// `buffered` offset. No bounds check against the buffer length is
    /// performed here.
    fn produce(self: Pin<&mut Self>, amt: usize) {
        let me = self.project();
        *me.buffered += amt;
    }
}
impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
    /// Formats the writer as a `BufWriter` struct showing the inner writer,
    /// buffer occupancy as `buffered/capacity`, and the `written` offset.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("BufWriter");
        out.field("writer", &self.inner);
        out.field(
            "buffer",
            &format_args!("{}/{}", self.buffered, self.buf.len()),
        );
        out.field("written", &self.written);
        out.finish()
    }
}