tokio/io/util/
buf_writer.rs1use crate::io::util::DEFAULT_BUF_SIZE;
2use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
3
4use pin_project_lite::pin_project;
5use std::fmt;
6use std::io::{self, IoSlice, SeekFrom, Write};
7use std::pin::Pin;
8use std::task::{ready, Context, Poll};
9
10pin_project! {
11    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
32    pub struct BufWriter<W> {
33        #[pin]
34        pub(super) inner: W,
35        pub(super) buf: Vec<u8>,
36        pub(super) written: usize,
37        pub(super) seek_state: SeekState,
38    }
39}
40
41impl<W: AsyncWrite> BufWriter<W> {
42    pub fn new(inner: W) -> Self {
45        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
46    }
47
48    pub fn with_capacity(cap: usize, inner: W) -> Self {
50        Self {
51            inner,
52            buf: Vec::with_capacity(cap),
53            written: 0,
54            seek_state: SeekState::Init,
55        }
56    }
57
58    fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
59        let mut me = self.project();
60
61        let len = me.buf.len();
62        let mut ret = Ok(());
63        while *me.written < len {
64            match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) {
65                Ok(0) => {
66                    ret = Err(io::Error::new(
67                        io::ErrorKind::WriteZero,
68                        "failed to write the buffered data",
69                    ));
70                    break;
71                }
72                Ok(n) => *me.written += n,
73                Err(e) => {
74                    ret = Err(e);
75                    break;
76                }
77            }
78        }
79        if *me.written > 0 {
80            me.buf.drain(..*me.written);
81        }
82        *me.written = 0;
83        Poll::Ready(ret)
84    }
85
86    pub fn get_ref(&self) -> &W {
88        &self.inner
89    }
90
91    pub fn get_mut(&mut self) -> &mut W {
95        &mut self.inner
96    }
97
98    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
102        self.project().inner
103    }
104
105    pub fn into_inner(self) -> W {
109        self.inner
110    }
111
112    pub fn buffer(&self) -> &[u8] {
114        &self.buf
115    }
116}
117
118impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
119    fn poll_write(
120        mut self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122        buf: &[u8],
123    ) -> Poll<io::Result<usize>> {
124        if self.buf.len() + buf.len() > self.buf.capacity() {
125            ready!(self.as_mut().flush_buf(cx))?;
126        }
127
128        let me = self.project();
129        if buf.len() >= me.buf.capacity() {
130            me.inner.poll_write(cx, buf)
131        } else {
132            Poll::Ready(me.buf.write(buf))
133        }
134    }
135
136    fn poll_write_vectored(
137        mut self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        mut bufs: &[IoSlice<'_>],
140    ) -> Poll<io::Result<usize>> {
141        if self.inner.is_write_vectored() {
142            let total_len = bufs
143                .iter()
144                .fold(0usize, |acc, b| acc.saturating_add(b.len()));
145            if total_len > self.buf.capacity() - self.buf.len() {
146                ready!(self.as_mut().flush_buf(cx))?;
147            }
148            let me = self.as_mut().project();
149            if total_len >= me.buf.capacity() {
150                me.inner.poll_write_vectored(cx, bufs)
155            } else {
156                bufs.iter().for_each(|b| me.buf.extend_from_slice(b));
157                Poll::Ready(Ok(total_len))
158            }
159        } else {
160            while bufs.first().map(|buf| buf.len()) == Some(0) {
162                bufs = &bufs[1..];
163            }
164            if bufs.is_empty() {
165                return Poll::Ready(Ok(0));
166            }
167            let first_len = bufs[0].len();
169            if first_len > self.buf.capacity() - self.buf.len() {
170                ready!(self.as_mut().flush_buf(cx))?;
171                debug_assert!(self.buf.is_empty());
172            }
173            let me = self.as_mut().project();
174            if first_len >= me.buf.capacity() {
175                debug_assert!(me.buf.is_empty());
178                return me.inner.poll_write(cx, &bufs[0]);
179            } else {
180                me.buf.extend_from_slice(&bufs[0]);
181                bufs = &bufs[1..];
182            }
183            let mut total_written = first_len;
184            debug_assert!(total_written != 0);
185            for buf in bufs {
187                if buf.len() > me.buf.capacity() - me.buf.len() {
188                    break;
189                } else {
190                    me.buf.extend_from_slice(buf);
191                    total_written += buf.len();
192                }
193            }
194            Poll::Ready(Ok(total_written))
195        }
196    }
197
198    fn is_write_vectored(&self) -> bool {
199        true
200    }
201
202    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
203        ready!(self.as_mut().flush_buf(cx))?;
204        self.get_pin_mut().poll_flush(cx)
205    }
206
207    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208        ready!(self.as_mut().flush_buf(cx))?;
209        self.get_pin_mut().poll_shutdown(cx)
210    }
211}
212
213#[derive(Debug, Clone, Copy)]
214pub(super) enum SeekState {
215    Init,
217    Start(SeekFrom),
219    Pending,
221}
222
223impl<W: AsyncWrite + AsyncSeek> AsyncSeek for BufWriter<W> {
227    fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
228        *self.project().seek_state = SeekState::Start(pos);
232        Ok(())
233    }
234
235    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
236        let pos = match self.seek_state {
237            SeekState::Init => {
238                return self.project().inner.poll_complete(cx);
239            }
240            SeekState::Start(pos) => Some(pos),
241            SeekState::Pending => None,
242        };
243
244        ready!(self.as_mut().flush_buf(cx))?;
246
247        let mut me = self.project();
248        if let Some(pos) = pos {
249            ready!(me.inner.as_mut().poll_complete(cx))?;
251            if let Err(e) = me.inner.as_mut().start_seek(pos) {
252                *me.seek_state = SeekState::Init;
253                return Poll::Ready(Err(e));
254            }
255        }
256        match me.inner.poll_complete(cx) {
257            Poll::Ready(res) => {
258                *me.seek_state = SeekState::Init;
259                Poll::Ready(res)
260            }
261            Poll::Pending => {
262                *me.seek_state = SeekState::Pending;
263                Poll::Pending
264            }
265        }
266    }
267}
268
269impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
270    fn poll_read(
271        self: Pin<&mut Self>,
272        cx: &mut Context<'_>,
273        buf: &mut ReadBuf<'_>,
274    ) -> Poll<io::Result<()>> {
275        self.get_pin_mut().poll_read(cx, buf)
276    }
277}
278
279impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
280    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
281        self.get_pin_mut().poll_fill_buf(cx)
282    }
283
284    fn consume(self: Pin<&mut Self>, amt: usize) {
285        self.get_pin_mut().consume(amt);
286    }
287}
288
289impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        f.debug_struct("BufWriter")
292            .field("writer", &self.inner)
293            .field(
294                "buffer",
295                &format_args!("{}/{}", self.buf.len(), self.buf.capacity()),
296            )
297            .field("written", &self.written)
298            .finish()
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::BufWriter;

    /// Compile-time check that `BufWriter<()>` is `Unpin`.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufWriter<()>>();
    }
}