From 0f7f743f376fcec38dd1f3c039c2a902aa077bed Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 23 Apr 2024 13:09:00 +0000 Subject: [PATCH] refactor(owned_buffers_io::BufferedWriter): be generic over the type of buffer --- .../tenant/remote_timeline_client/download.rs | 9 +- .../virtual_file/owned_buffers_io/write.rs | 147 +++++++++++++----- 2 files changed, 110 insertions(+), 46 deletions(-) diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index 6ee8ad7155..517d02f29a 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -7,6 +7,7 @@ use std::collections::HashSet; use std::future::Future; use anyhow::{anyhow, Context}; +use bytes::BytesMut; use camino::{Utf8Path, Utf8PathBuf}; use pageserver_api::shard::TenantShardId; use tokio::fs::{self, File, OpenOptions}; @@ -194,10 +195,10 @@ async fn download_object<'a>( // There's chunks_vectored() on the stream. let (bytes_amount, destination_file) = async { let size_tracking = size_tracking_writer::Writer::new(destination_file); - let mut buffered = owned_buffers_io::write::BufferedWriter::< - { super::BUFFER_SIZE }, - _, - >::new(size_tracking); + let mut buffered = owned_buffers_io::write::BufferedWriter::::new( + size_tracking, + BytesMut::with_capacity(super::BUFFER_SIZE), + ); while let Some(res) = futures::StreamExt::next(&mut download.download_stream).await { diff --git a/pageserver/src/virtual_file/owned_buffers_io/write.rs b/pageserver/src/virtual_file/owned_buffers_io/write.rs index f1812d9b51..6b3a02c71a 100644 --- a/pageserver/src/virtual_file/owned_buffers_io/write.rs +++ b/pageserver/src/virtual_file/owned_buffers_io/write.rs @@ -10,14 +10,14 @@ pub trait OwnedAsyncWriter { ) -> std::io::Result<(usize, B::Buf)>; } -/// A wrapper aorund an [`OwnedAsyncWriter`] that batches smaller writers -/// into `BUFFER_SIZE`-sized writes. 
+/// A wrapper around an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch +/// small writes into larger writes of size [`Buffer::cap`]. /// /// # Passthrough Of Large Writers /// -/// Buffered writes larger than the `BUFFER_SIZE` cause the internal -/// buffer to be flushed, even if it is not full yet. Then, the large -/// buffered write is passed through to the unerlying [`OwnedAsyncWriter`]. +/// Calls to [`BufferedWriter::write_buffered`] that are larger than [`Buffer::cap`] +/// cause the internal buffer to be flushed prematurely so that the large +/// buffered write is passed through to the underlying [`OwnedAsyncWriter`]. /// /// This pass-through is generally beneficial for throughput, but if /// the storage backend of the [`OwnedAsyncWriter`] is a shared resource, @@ -25,24 +25,25 @@ pub trait OwnedAsyncWriter { /// /// In such cases, a different implementation that always buffers in memory /// may be preferable. -pub struct BufferedWriter { +pub struct BufferedWriter { writer: W, - // invariant: always remains Some(buf) - // with buf.capacity() == BUFFER_SIZE except - // - while IO is ongoing => goes back to Some() once the IO completed successfully - // - after an IO error => stays `None` forever - // In these exceptional cases, it's `None`. + /// invariant: always remains Some(buf) except + /// - while IO is ongoing => goes back to Some() once the IO completed successfully + /// - after an IO error => stays `None` forever + /// In these exceptional cases, it's `None`. 
+ buf: Option, } -impl BufferedWriter +impl BufferedWriter where + B: Buffer + Send, + Buf: IoBuf + Send, W: OwnedAsyncWriter, { - pub fn new(writer: W) -> Self { + pub fn new(writer: W, buf: B) -> Self { Self { writer, - buf: Some(BytesMut::with_capacity(BUFFER_SIZE)), + buf: Some(buf), } } @@ -53,61 +54,121 @@ where Ok(writer) } - pub async fn write_buffered(&mut self, chunk: Slice) -> std::io::Result<()> + #[inline(always)] + fn buf(&self) -> &B { + self.buf + .as_ref() + .expect("must not use after we returned an error") + } + + pub async fn write_buffered(&mut self, chunk: Slice) -> std::io::Result<(usize, S)> where - B: IoBuf + Send, + S: IoBuf + Send, { + let chunk_len = chunk.len(); // avoid memcpy for the middle of the chunk - if chunk.len() >= BUFFER_SIZE { + if chunk.len() >= self.buf().cap() { self.flush().await?; // do a big write, bypassing `buf` assert_eq!( self.buf .as_ref() .expect("must not use after an error") - .len(), + .pending(), 0 ); - let chunk_len = chunk.len(); let (nwritten, chunk) = self.writer.write_all(chunk).await?; assert_eq!(nwritten, chunk_len); - drop(chunk); - return Ok(()); + return Ok((nwritten, chunk)); } // in-memory copy the < BUFFER_SIZED tail of the chunk - assert!(chunk.len() < BUFFER_SIZE); - let mut chunk = &chunk[..]; - while !chunk.is_empty() { + assert!(chunk.len() < self.buf().cap()); + let mut slice = &chunk[..]; + while !slice.is_empty() { let buf = self.buf.as_mut().expect("must not use after an error"); - let need = BUFFER_SIZE - buf.len(); - let have = chunk.len(); + let need = buf.cap() - buf.pending(); + let have = slice.len(); let n = std::cmp::min(need, have); - buf.extend_from_slice(&chunk[..n]); - chunk = &chunk[n..]; - if buf.len() >= BUFFER_SIZE { - assert_eq!(buf.len(), BUFFER_SIZE); + buf.extend_from_slice(&slice[..n]); + slice = &slice[n..]; + if buf.pending() >= buf.cap() { + assert_eq!(buf.pending(), buf.cap()); self.flush().await?; } } - assert!(chunk.is_empty(), "by now we should have drained 
the chunk"); - Ok(()) + assert!(slice.is_empty(), "by now we should have drained the chunk"); + Ok((chunk_len, chunk.into_inner())) } async fn flush(&mut self) -> std::io::Result<()> { let buf = self.buf.take().expect("must not use after an error"); - if buf.is_empty() { + let buf_len = buf.pending(); + if buf_len == 0 { self.buf = Some(buf); - return std::io::Result::Ok(()); + return Ok(()); } - let buf_len = buf.len(); - let (nwritten, mut buf) = self.writer.write_all(buf).await?; + let (nwritten, io_buf) = self.writer.write_all(buf.flush()).await?; assert_eq!(nwritten, buf_len); - buf.clear(); - self.buf = Some(buf); + self.buf = Some(Buffer::reuse_after_flush(io_buf)); Ok(()) } } +/// A [`Buffer`] is used by [`BufferedWriter`] to batch smaller writes into larger ones. +pub trait Buffer { + type IoBuf: IoBuf; + + /// Capacity of the buffer. Must not change over the lifetime of `self`. + fn cap(&self) -> usize; + + /// Add data to the buffer. + /// Panics if there is not enough room to accommodate `other`'s content, i.e., + /// panics if `other.len() > self.cap() - self.pending()`. + fn extend_from_slice(&mut self, other: &[u8]); + + /// Number of bytes in the buffer. + fn pending(&self) -> usize; + + /// Turns `self` into a [`tokio_epoll_uring::Slice`] of the pending data + /// so we can use [`tokio_epoll_uring`] to write it to disk. + fn flush(self) -> Slice; + + /// After the write to disk is done and we have gotten back the slice, + /// [`BufferedWriter`] uses this method to re-use the io buffer. 
+ fn reuse_after_flush(iobuf: Self::IoBuf) -> Self; +} + +impl Buffer for BytesMut { + type IoBuf = BytesMut; + + #[inline(always)] + fn cap(&self) -> usize { + self.capacity() + } + + fn extend_from_slice(&mut self, other: &[u8]) { + BytesMut::extend_from_slice(self, other) + } + + #[inline(always)] + fn pending(&self) -> usize { + self.len() + } + + fn flush(self) -> Slice { + if self.is_empty() { + return self.slice_full(); + } + let len = self.len(); + self.slice(0..len) + } + + fn reuse_after_flush(mut iobuf: BytesMut) -> Self { + iobuf.clear(); + iobuf + } +} + impl OwnedAsyncWriter for Vec { async fn write_all, Buf: IoBuf + Send>( &mut self, @@ -125,6 +186,8 @@ impl OwnedAsyncWriter for Vec { #[cfg(test)] mod tests { + use bytes::BytesMut; + use super::*; #[derive(Default)] @@ -158,7 +221,7 @@ mod tests { #[tokio::test] async fn test_buffered_writes_only() -> std::io::Result<()> { let recorder = RecorderWriter::default(); - let mut writer = BufferedWriter::<2, _>::new(recorder); + let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2)); write!(writer, b"a"); write!(writer, b"b"); write!(writer, b"c"); @@ -175,7 +238,7 @@ mod tests { #[tokio::test] async fn test_passthrough_writes_only() -> std::io::Result<()> { let recorder = RecorderWriter::default(); - let mut writer = BufferedWriter::<2, _>::new(recorder); + let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2)); write!(writer, b"abc"); write!(writer, b"de"); write!(writer, b""); @@ -191,7 +254,7 @@ mod tests { #[tokio::test] async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> { let recorder = RecorderWriter::default(); - let mut writer = BufferedWriter::<2, _>::new(recorder); + let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2)); write!(writer, b"a"); write!(writer, b"bc"); write!(writer, b"d");