diff --git a/pageserver/src/tenant/ephemeral_file/zero_padded_buffer.rs b/pageserver/src/tenant/ephemeral_file/zero_padded_buffer.rs index fc83377e92..36b66fa65b 100644 --- a/pageserver/src/tenant/ephemeral_file/zero_padded_buffer.rs +++ b/pageserver/src/tenant/ephemeral_file/zero_padded_buffer.rs @@ -1,7 +1,5 @@ use std::mem::MaybeUninit; -use crate::virtual_file::owned_buffers_io; - pub struct Buf { allocation: Box<[u8; N]>, written: usize, @@ -23,6 +21,7 @@ impl Buf { #[inline(always)] fn invariants(&self) { debug_assert!(self.written <= N, "{}", self.written); + debug_assert!(self.allocation[self.written..N].iter().all(|v| *v == 0)); } pub fn as_zero_padded_slice(&self) -> &[u8; N] { @@ -44,32 +43,25 @@ unsafe impl tokio_epoll_uring::IoBuf for Buf { } fn bytes_total(&self) -> usize { - self.written // ? + N } } -impl owned_buffers_io::write::Buffer for Buf { - const BUFFER_SIZE: usize = N; - - /// panics if there's not enough capacity left - fn extend_from_slice(&mut self, buf: &[u8]) { - self.invariants(); - let can = N - self.written; - let want = buf.len(); - assert!(want <= can, "{:x} {:x}", want, can); - self.allocation[self.written..(self.written + want)].copy_from_slice(buf); - self.written += want; - self.invariants(); +/// SAFETY: +/// +/// The [`Self::allocation`] is stable because boxes are stable.
+/// +unsafe impl tokio_epoll_uring::IoBufMut for Buf { + fn stable_mut_ptr(&mut self) -> *mut u8 { + self.allocation.as_mut_ptr() } - fn len(&self) -> usize { - self.written - } - - fn clear(&mut self) { + unsafe fn set_init(&mut self, pos: usize) { self.invariants(); - self.written = 0; - self.allocation[..].fill(0); + if pos < self.written { + self.allocation[pos..self.written].fill(0); + } + self.written = pos; self.invariants(); } } diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index 56ddb9aa9a..5a005127ec 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -7,6 +7,7 @@ use std::collections::HashSet; use std::future::Future; use anyhow::{anyhow, Context}; +use bytes::BytesMut; use camino::{Utf8Path, Utf8PathBuf}; use pageserver_api::shard::TenantShardId; use tokio::fs::{self, File, OpenOptions}; @@ -196,7 +197,7 @@ async fn download_object<'a>( let size_tracking = size_tracking_writer::Writer::new(destination_file); let mut buffered = owned_buffers_io::write::BufferedWriter::new( size_tracking, - owned_buffers_io::write::BytesMutBuffer::<{ super::BUFFER_SIZE }>::new(), + BytesMut::with_capacity(super::BUFFER_SIZE), ); while let Some(res) = futures::StreamExt::next(&mut download.download_stream).await diff --git a/pageserver/src/virtual_file/owned_buffers_io/write.rs b/pageserver/src/virtual_file/owned_buffers_io/write.rs index 1a711aeb22..6efd332849 100644 --- a/pageserver/src/virtual_file/owned_buffers_io/write.rs +++ b/pageserver/src/virtual_file/owned_buffers_io/write.rs @@ -1,4 +1,3 @@ -use bytes::BytesMut; use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice}; /// A trait for doing owned-buffer write IO. 
@@ -10,53 +9,35 @@ pub trait OwnedAsyncWriter { ) -> std::io::Result<(usize, B::Buf)>; } -pub trait Buffer { - const BUFFER_SIZE: usize; - - fn len(&self) -> usize; +pub trait IoBufMutExt: tokio_epoll_uring::IoBufMut { fn clear(&mut self); fn extend_from_slice(&mut self, other: &[u8]); } -pub struct BytesMutBuffer { - buf: BytesMut, -} - -/// SAFETY: just forwards to the pre-existing impl for BytesMut -unsafe impl IoBuf for BytesMutBuffer { - fn stable_ptr(&self) -> *const u8 { - IoBuf::stable_ptr(&self.buf) - } - - fn bytes_init(&self) -> usize { - IoBuf::bytes_init(&self.buf) - } - - fn bytes_total(&self) -> usize { - IoBuf::bytes_total(&self.buf) - } -} - -impl Buffer for BytesMutBuffer { - const BUFFER_SIZE: usize = BUFFER_SIZE; - - fn len(&self) -> usize { - self.buf.len() - } - +impl IoBufMutExt for T +where + T: tokio_epoll_uring::IoBufMut, +{ fn clear(&mut self) { - self.buf.clear() + // SAFETY: setting to 0 is always safe + unsafe { self.set_init(0) } } fn extend_from_slice(&mut self, other: &[u8]) { - self.buf.extend_from_slice(other) - } -} - -impl BytesMutBuffer { - pub fn new() -> Self { - BytesMutBuffer { - buf: BytesMut::with_capacity(BUFFER_SIZE), + let remaining = self + .bytes_total() + .checked_sub(self.bytes_init()) + .expect("no method on self should allow bytes_init() to exceed bytes_total()"); + if other.len() > remaining { + panic!("extend_from_slice() would extend beyond buffer capacity; remaining={} other.len()={}", remaining, other.len()); + } + // SAFETY: we did bounds-checking above; non-overlapping is guaranteed by Rust borrowing + // (self is borrowed mut, so `other` can't reference self or anything in it). 
+ unsafe { + self.stable_mut_ptr() + .add(self.bytes_init()) + .copy_from_nonoverlapping(other.as_ptr(), other.len()); + self.set_init(self.bytes_init() + other.len()); } } } @@ -88,7 +69,7 @@ pub struct BufferedWriter { impl BufferedWriter where - B: Buffer + IoBuf + Send, + B: IoBufMutExt + Send, W: OwnedAsyncWriter, { pub fn new(writer: W, buffer: B) -> Self { @@ -114,38 +95,39 @@ where Ok(writer) } + #[inline(always)] + fn buf(&self) -> &B { + self.buf + .as_ref() + .expect("must not use after we returned an error") + } + pub async fn write_buffered(&mut self, chunk: Slice) -> std::io::Result<(usize, S)> where S: IoBuf + Send, { - let chunk_len = chunk.len(); + let chunk_len = chunk.bytes_init(); // avoid memcpy for the middle of the chunk - if chunk.len() >= B::BUFFER_SIZE { + if chunk.bytes_init() >= self.buf().bytes_total() { self.flush().await?; // do a big write, bypassing `buf` - assert_eq!( - self.buf - .as_ref() - .expect("must not use after an error") - .len(), - 0 - ); + assert_eq!(self.buf().bytes_init(), 0); let (nwritten, chunk) = self.writer.write_all(chunk).await?; assert_eq!(nwritten, chunk_len); return Ok((nwritten, chunk)); } // in-memory copy the < BUFFER_SIZED tail of the chunk - assert!(chunk.len() < B::BUFFER_SIZE); + assert!(chunk.len() < self.buf().bytes_total()); let mut slice = &chunk[..]; while !slice.is_empty() { let buf = self.buf.as_mut().expect("must not use after an error"); - let need = B::BUFFER_SIZE - buf.len(); + let need = buf.bytes_total() - buf.bytes_init(); let have = slice.len(); let n = std::cmp::min(need, have); buf.extend_from_slice(&slice[..n]); slice = &slice[n..]; - if buf.len() >= B::BUFFER_SIZE { - assert_eq!(buf.len(), B::BUFFER_SIZE); + if buf.bytes_init() >= buf.bytes_total() { + assert_eq!(buf.bytes_init(), buf.bytes_total()); self.flush().await?; } } @@ -162,13 +144,13 @@ where let chunk_len = chunk.len(); while !chunk.is_empty() { let buf = self.buf.as_mut().expect("must not use after an error"); - let need 
= B::BUFFER_SIZE - buf.len(); + let need = buf.bytes_total() - buf.bytes_init(); let have = chunk.len(); let n = std::cmp::min(need, have); buf.extend_from_slice(&chunk[..n]); chunk = &chunk[n..]; - if buf.len() >= B::BUFFER_SIZE { - assert_eq!(buf.len(), B::BUFFER_SIZE); + if buf.bytes_init() >= buf.bytes_total() { + assert_eq!(buf.bytes_init(), buf.bytes_total()); self.flush().await?; } } @@ -181,7 +163,7 @@ where self.buf = Some(buf); return std::io::Result::Ok(()); } - let buf_len = buf.len(); + let buf_len = buf.bytes_init(); let (nwritten, mut buf) = self.writer.write_all(buf).await?; assert_eq!(nwritten, buf_len); buf.clear(); @@ -207,6 +189,8 @@ impl OwnedAsyncWriter for Vec { #[cfg(test)] mod tests { + use bytes::BytesMut; + use super::*; #[derive(Default)] @@ -240,7 +224,7 @@ mod tests { #[tokio::test] async fn test_buffered_writes_only() -> std::io::Result<()> { let recorder = RecorderWriter::default(); - let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new()); + let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2)); write!(writer, b"a"); write!(writer, b"b"); write!(writer, b"c"); @@ -257,7 +241,7 @@ mod tests { #[tokio::test] async fn test_passthrough_writes_only() -> std::io::Result<()> { let recorder = RecorderWriter::default(); - let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new()); + let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2)); write!(writer, b"abc"); write!(writer, b"de"); write!(writer, b""); @@ -273,7 +257,7 @@ mod tests { #[tokio::test] async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> { let recorder = RecorderWriter::default(); - let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new()); + let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2)); write!(writer, b"a"); write!(writer, b"bc"); write!(writer, b"d"); @@ -289,7 +273,7 @@ mod tests { #[tokio::test] async fn 
test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> { let recorder = RecorderWriter::default(); - let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new()); + let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2)); writer.write_buffered_borrowed(b"abc").await?; writer.write_buffered_borrowed(b"d").await?;