WIP: avoid impls of trait Buffer

This commit is contained in:
Christian Schwarz
2024-04-23 09:36:47 +00:00
parent 585a522ff4
commit 06cdbdff3d
3 changed files with 62 additions and 85 deletions

View File

@@ -1,7 +1,5 @@
use std::mem::MaybeUninit;
use crate::virtual_file::owned_buffers_io;
pub struct Buf<const N: usize> {
allocation: Box<[u8; N]>,
written: usize,
@@ -23,6 +21,7 @@ impl<const N: usize> Buf<N> {
/// Debug-only sanity checks on the buffer state.
///
/// Asserts that `written` does not exceed the capacity `N`, and that every
/// byte from `written` up to `N` is zero.
#[inline(always)]
fn invariants(&self) {
debug_assert!(self.written <= N, "{}", self.written);
// The tail past `written` must be all zeroes.
debug_assert!(self.allocation[self.written..N].iter().all(|v| *v == 0));
}
pub fn as_zero_padded_slice(&self) -> &[u8; N] {
@@ -44,32 +43,25 @@ unsafe impl<const N: usize> tokio_epoll_uring::IoBuf for Buf<N> {
}
fn bytes_total(&self) -> usize {
self.written // ?
N
}
}
impl<const N: usize> owned_buffers_io::write::Buffer for Buf<N> {
const BUFFER_SIZE: usize = N;
/// panics if there's not enough capacity left
fn extend_from_slice(&mut self, buf: &[u8]) {
self.invariants();
let can = N - self.written;
let want = buf.len();
assert!(want <= can, "{:x} {:x}", want, can);
self.allocation[self.written..(self.written + want)].copy_from_slice(buf);
self.written += want;
self.invariants();
/// SAFETY:
///
/// The [`Self::allocation`] is stable because boxes are stable.
///
unsafe impl<const N: usize> tokio_epoll_uring::IoBufMut for Buf<N> {
fn stable_mut_ptr(&mut self) -> *mut u8 {
self.allocation.as_mut_ptr()
}
fn len(&self) -> usize {
self.written
}
fn clear(&mut self) {
unsafe fn set_init(&mut self, pos: usize) {
self.invariants();
self.written = 0;
self.allocation[..].fill(0);
if pos < self.written {
self.allocation[pos..self.written].fill(0);
}
self.written = pos;
self.invariants();
}
}

View File

@@ -7,6 +7,7 @@ use std::collections::HashSet;
use std::future::Future;
use anyhow::{anyhow, Context};
use bytes::BytesMut;
use camino::{Utf8Path, Utf8PathBuf};
use pageserver_api::shard::TenantShardId;
use tokio::fs::{self, File, OpenOptions};
@@ -196,7 +197,7 @@ async fn download_object<'a>(
let size_tracking = size_tracking_writer::Writer::new(destination_file);
let mut buffered = owned_buffers_io::write::BufferedWriter::new(
size_tracking,
owned_buffers_io::write::BytesMutBuffer::<{ super::BUFFER_SIZE }>::new(),
BytesMut::with_capacity(super::BUFFER_SIZE),
);
while let Some(res) =
futures::StreamExt::next(&mut download.download_stream).await

View File

@@ -1,4 +1,3 @@
use bytes::BytesMut;
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
/// A trait for doing owned-buffer write IO.
@@ -10,53 +9,35 @@ pub trait OwnedAsyncWriter {
) -> std::io::Result<(usize, B::Buf)>;
}
pub trait Buffer {
const BUFFER_SIZE: usize;
fn len(&self) -> usize;
/// Convenience methods layered on top of [`tokio_epoll_uring::IoBufMut`].
pub trait IoBufMutExt: tokio_epoll_uring::IoBufMut {
/// Resets the buffer so that it reports zero initialized bytes
/// (the blanket implementation does this via `set_init(0)`).
fn clear(&mut self);
/// Appends `other` directly after the currently initialized bytes.
///
/// # Panics
///
/// The blanket implementation panics if `other` does not fit into the
/// remaining capacity (`bytes_total() - bytes_init()`).
fn extend_from_slice(&mut self, other: &[u8]);
}
pub struct BytesMutBuffer<const BUFFER_SIZE: usize> {
buf: BytesMut,
}
/// SAFETY: just forwards to the pre-existing impl for BytesMut
unsafe impl<const BUFFER_SIZE: usize> IoBuf for BytesMutBuffer<BUFFER_SIZE> {
fn stable_ptr(&self) -> *const u8 {
IoBuf::stable_ptr(&self.buf)
}
fn bytes_init(&self) -> usize {
IoBuf::bytes_init(&self.buf)
}
fn bytes_total(&self) -> usize {
IoBuf::bytes_total(&self.buf)
}
}
impl<const BUFFER_SIZE: usize> Buffer for BytesMutBuffer<BUFFER_SIZE> {
const BUFFER_SIZE: usize = BUFFER_SIZE;
fn len(&self) -> usize {
self.buf.len()
}
impl<T> IoBufMutExt for T
where
T: tokio_epoll_uring::IoBufMut,
{
fn clear(&mut self) {
self.buf.clear()
// SAFETY: setting to 0 is always safe
unsafe { self.set_init(0) }
}
fn extend_from_slice(&mut self, other: &[u8]) {
self.buf.extend_from_slice(other)
}
}
impl<const BUFFER_SIZE: usize> BytesMutBuffer<BUFFER_SIZE> {
pub fn new() -> Self {
BytesMutBuffer {
buf: BytesMut::with_capacity(BUFFER_SIZE),
let remaining = self
.bytes_total()
.checked_sub(self.bytes_init())
.expect("no method on self should allow bytes_init() to exceed bytes_total()");
if other.len() > remaining {
panic!("extend_from_slice() would extend beyond buffer capacity; remaining={} other.len()={}", remaining, other.len());
}
// SAFETY: we did bounds-checking above; non-overlapping is guaranteed by Rust borrowing
// (self is borrowed mut, so `other` can't reference self or anything in it).
unsafe {
self.stable_mut_ptr()
.add(self.bytes_init())
.copy_from_nonoverlapping(other.as_ptr(), other.len());
self.set_init(self.bytes_init() + other.len());
}
}
}
@@ -88,7 +69,7 @@ pub struct BufferedWriter<B, W> {
impl<B, W> BufferedWriter<B, W>
where
B: Buffer + IoBuf + Send,
B: IoBufMutExt + Send,
W: OwnedAsyncWriter,
{
pub fn new(writer: W, buffer: B) -> Self {
@@ -114,38 +95,39 @@ where
Ok(writer)
}
/// Returns a shared reference to the internal buffer.
///
/// # Panics
///
/// Panics if the buffer slot is empty; per the `expect` message, that is
/// the state the writer is in after it has returned an error.
#[inline(always)]
fn buf(&self) -> &B {
self.buf
.as_ref()
.expect("must not use after we returned an error")
}
pub async fn write_buffered<S: IoBuf>(&mut self, chunk: Slice<S>) -> std::io::Result<(usize, S)>
where
S: IoBuf + Send,
{
let chunk_len = chunk.len();
let chunk_len = chunk.bytes_init();
// avoid memcpy for the middle of the chunk
if chunk.len() >= B::BUFFER_SIZE {
if chunk.bytes_init() >= self.buf().bytes_total() {
self.flush().await?;
// do a big write, bypassing `buf`
assert_eq!(
self.buf
.as_ref()
.expect("must not use after an error")
.len(),
0
);
assert_eq!(self.buf().bytes_init(), 0);
let (nwritten, chunk) = self.writer.write_all(chunk).await?;
assert_eq!(nwritten, chunk_len);
return Ok((nwritten, chunk));
}
// copy the tail of the chunk, which is smaller than the buffer's capacity, into the in-memory buffer
assert!(chunk.len() < B::BUFFER_SIZE);
assert!(chunk.len() < self.buf().bytes_total());
let mut slice = &chunk[..];
while !slice.is_empty() {
let buf = self.buf.as_mut().expect("must not use after an error");
let need = B::BUFFER_SIZE - buf.len();
let need = buf.bytes_total() - buf.bytes_init();
let have = slice.len();
let n = std::cmp::min(need, have);
buf.extend_from_slice(&slice[..n]);
slice = &slice[n..];
if buf.len() >= B::BUFFER_SIZE {
assert_eq!(buf.len(), B::BUFFER_SIZE);
if buf.bytes_init() >= buf.bytes_total() {
assert_eq!(buf.bytes_init(), buf.bytes_total());
self.flush().await?;
}
}
@@ -162,13 +144,13 @@ where
let chunk_len = chunk.len();
while !chunk.is_empty() {
let buf = self.buf.as_mut().expect("must not use after an error");
let need = B::BUFFER_SIZE - buf.len();
let need = buf.bytes_total() - buf.bytes_init();
let have = chunk.len();
let n = std::cmp::min(need, have);
buf.extend_from_slice(&chunk[..n]);
chunk = &chunk[n..];
if buf.len() >= B::BUFFER_SIZE {
assert_eq!(buf.len(), B::BUFFER_SIZE);
if buf.bytes_init() >= buf.bytes_total() {
assert_eq!(buf.bytes_init(), buf.bytes_total());
self.flush().await?;
}
}
@@ -181,7 +163,7 @@ where
self.buf = Some(buf);
return std::io::Result::Ok(());
}
let buf_len = buf.len();
let buf_len = buf.bytes_init();
let (nwritten, mut buf) = self.writer.write_all(buf).await?;
assert_eq!(nwritten, buf_len);
buf.clear();
@@ -207,6 +189,8 @@ impl OwnedAsyncWriter for Vec<u8> {
#[cfg(test)]
mod tests {
use bytes::BytesMut;
use super::*;
#[derive(Default)]
@@ -240,7 +224,7 @@ mod tests {
#[tokio::test]
async fn test_buffered_writes_only() -> std::io::Result<()> {
let recorder = RecorderWriter::default();
let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new());
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
write!(writer, b"a");
write!(writer, b"b");
write!(writer, b"c");
@@ -257,7 +241,7 @@ mod tests {
#[tokio::test]
async fn test_passthrough_writes_only() -> std::io::Result<()> {
let recorder = RecorderWriter::default();
let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new());
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
write!(writer, b"abc");
write!(writer, b"de");
write!(writer, b"");
@@ -273,7 +257,7 @@ mod tests {
#[tokio::test]
async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
let recorder = RecorderWriter::default();
let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new());
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
write!(writer, b"a");
write!(writer, b"bc");
write!(writer, b"d");
@@ -289,7 +273,7 @@ mod tests {
#[tokio::test]
async fn test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> {
let recorder = RecorderWriter::default();
let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new());
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
writer.write_buffered_borrowed(b"abc").await?;
writer.write_buffered_borrowed(b"d").await?;