use pin_project_lite::pin_project; use std::pin::Pin; use std::{io, task}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pin_project! { /// This stream tracks all writes and calls user provided /// callback when the underlying stream is flushed. pub struct MeasuredStream { #[pin] stream: S, write_count: usize, inc_read_count: R, inc_write_count: W, } } impl MeasuredStream { pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self { Self { stream, write_count: 0, inc_read_count, inc_write_count, } } } impl AsyncRead for MeasuredStream { fn poll_read( self: Pin<&mut Self>, context: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> task::Poll> { let this = self.project(); let filled = buf.filled().len(); this.stream.poll_read(context, buf).map_ok(|()| { let cnt = buf.filled().len() - filled; // Increment the read count. (this.inc_read_count)(cnt); }) } } impl AsyncWrite for MeasuredStream { fn poll_write( self: Pin<&mut Self>, context: &mut task::Context<'_>, buf: &[u8], ) -> task::Poll> { let this = self.project(); this.stream.poll_write(context, buf).map_ok(|cnt| { // Increment the write count. *this.write_count += cnt; cnt }) } fn poll_flush( self: Pin<&mut Self>, context: &mut task::Context<'_>, ) -> task::Poll> { let this = self.project(); this.stream.poll_flush(context).map_ok(|()| { // Call the user provided callback and reset the write count. (this.inc_write_count)(*this.write_count); *this.write_count = 0; }) } fn poll_shutdown( self: Pin<&mut Self>, context: &mut task::Context<'_>, ) -> task::Poll> { self.project().stream.poll_shutdown(context) } }