mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-31 03:50:37 +00:00
WIP: avoid impls of trait Buffer
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
use crate::virtual_file::owned_buffers_io;
|
||||
|
||||
pub struct Buf<const N: usize> {
|
||||
allocation: Box<[u8; N]>,
|
||||
written: usize,
|
||||
@@ -23,6 +21,7 @@ impl<const N: usize> Buf<N> {
|
||||
/// Debug-build sanity check of the buffer's internal state.
///
/// Both checks use `debug_assert!`, so they are compiled out of release builds.
#[inline(always)]
|
||||
fn invariants(&self) {
|
||||
// The write cursor never moves past the end of the N-byte allocation.
debug_assert!(self.written <= N, "{}", self.written);
|
||||
// Every byte from the write cursor to the end of the allocation is zero.
debug_assert!(self.allocation[self.written..N].iter().all(|v| *v == 0));
|
||||
}
|
||||
|
||||
pub fn as_zero_padded_slice(&self) -> &[u8; N] {
|
||||
@@ -44,32 +43,25 @@ unsafe impl<const N: usize> tokio_epoll_uring::IoBuf for Buf<N> {
|
||||
}
|
||||
|
||||
fn bytes_total(&self) -> usize {
|
||||
self.written // ?
|
||||
N
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> owned_buffers_io::write::Buffer for Buf<N> {
|
||||
const BUFFER_SIZE: usize = N;
|
||||
|
||||
/// panics if there's not enough capacity left
|
||||
fn extend_from_slice(&mut self, buf: &[u8]) {
|
||||
self.invariants();
|
||||
let can = N - self.written;
|
||||
let want = buf.len();
|
||||
assert!(want <= can, "{:x} {:x}", want, can);
|
||||
self.allocation[self.written..(self.written + want)].copy_from_slice(buf);
|
||||
self.written += want;
|
||||
self.invariants();
|
||||
/// SAFETY:
|
||||
///
|
||||
/// The [`Self::allocation`] is stable because boxes are stable.
|
||||
///
|
||||
unsafe impl<const N: usize> tokio_epoll_uring::IoBufMut for Buf<N> {
|
||||
/// Returns a raw mutable pointer to the first byte of the boxed `allocation`.
fn stable_mut_ptr(&mut self) -> *mut u8 {
|
||||
self.allocation.as_mut_ptr()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.written
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
unsafe fn set_init(&mut self, pos: usize) {
|
||||
self.invariants();
|
||||
self.written = 0;
|
||||
self.allocation[..].fill(0);
|
||||
if pos < self.written {
|
||||
self.allocation[pos..self.written].fill(0);
|
||||
}
|
||||
self.written = pos;
|
||||
self.invariants();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::collections::HashSet;
|
||||
use std::future::Future;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::BytesMut;
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use tokio::fs::{self, File, OpenOptions};
|
||||
@@ -196,7 +197,7 @@ async fn download_object<'a>(
|
||||
let size_tracking = size_tracking_writer::Writer::new(destination_file);
|
||||
let mut buffered = owned_buffers_io::write::BufferedWriter::new(
|
||||
size_tracking,
|
||||
owned_buffers_io::write::BytesMutBuffer::<{ super::BUFFER_SIZE }>::new(),
|
||||
BytesMut::with_capacity(super::BUFFER_SIZE),
|
||||
);
|
||||
while let Some(res) =
|
||||
futures::StreamExt::next(&mut download.download_stream).await
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use bytes::BytesMut;
|
||||
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
|
||||
|
||||
/// A trait for doing owned-buffer write IO.
|
||||
@@ -10,53 +9,35 @@ pub trait OwnedAsyncWriter {
|
||||
) -> std::io::Result<(usize, B::Buf)>;
|
||||
}
|
||||
|
||||
pub trait Buffer {
|
||||
const BUFFER_SIZE: usize;
|
||||
|
||||
fn len(&self) -> usize;
|
||||
pub trait IoBufMutExt: tokio_epoll_uring::IoBufMut {
|
||||
fn clear(&mut self);
|
||||
fn extend_from_slice(&mut self, other: &[u8]);
|
||||
}
|
||||
|
||||
/// Newtype around a [`BytesMut`], parameterised by a compile-time `BUFFER_SIZE`.
pub struct BytesMutBuffer<const BUFFER_SIZE: usize> {
|
||||
// The wrapped byte buffer.
buf: BytesMut,
|
||||
}
|
||||
|
||||
/// SAFETY: just forwards to the pre-existing impl for BytesMut
|
||||
unsafe impl<const BUFFER_SIZE: usize> IoBuf for BytesMutBuffer<BUFFER_SIZE> {
|
||||
// Delegates to the `IoBuf::stable_ptr` of the wrapped `buf`.
fn stable_ptr(&self) -> *const u8 {
|
||||
IoBuf::stable_ptr(&self.buf)
|
||||
}
|
||||
|
||||
// Delegates to the `IoBuf::bytes_init` of the wrapped `buf`.
fn bytes_init(&self) -> usize {
|
||||
IoBuf::bytes_init(&self.buf)
|
||||
}
|
||||
|
||||
// Delegates to the `IoBuf::bytes_total` of the wrapped `buf`.
fn bytes_total(&self) -> usize {
|
||||
IoBuf::bytes_total(&self.buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const BUFFER_SIZE: usize> Buffer for BytesMutBuffer<BUFFER_SIZE> {
|
||||
const BUFFER_SIZE: usize = BUFFER_SIZE;
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.buf.len()
|
||||
}
|
||||
|
||||
impl<T> IoBufMutExt for T
|
||||
where
|
||||
T: tokio_epoll_uring::IoBufMut,
|
||||
{
|
||||
fn clear(&mut self) {
|
||||
self.buf.clear()
|
||||
// SAFETY: setting to 0 is always safe
|
||||
unsafe { self.set_init(0) }
|
||||
}
|
||||
|
||||
fn extend_from_slice(&mut self, other: &[u8]) {
|
||||
self.buf.extend_from_slice(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const BUFFER_SIZE: usize> BytesMutBuffer<BUFFER_SIZE> {
|
||||
pub fn new() -> Self {
|
||||
BytesMutBuffer {
|
||||
buf: BytesMut::with_capacity(BUFFER_SIZE),
|
||||
let remaining = self
|
||||
.bytes_total()
|
||||
.checked_sub(self.bytes_init())
|
||||
.expect("no method on self should allow bytes_init() to exceed bytes_total()");
|
||||
if other.len() > remaining {
|
||||
panic!("extend_from_slice() would extend beyond buffer capacity; remaining={} other.len()={}", remaining, other.len());
|
||||
}
|
||||
// SAFETY: we did bounds-checking above; non-overlapping is guaranteed by Rust borrowing
|
||||
// (self is borrowed mut, so `other` can't reference self or anything in it).
|
||||
unsafe {
|
||||
self.stable_mut_ptr()
|
||||
.add(self.bytes_init())
|
||||
.copy_from_nonoverlapping(other.as_ptr(), other.len());
|
||||
self.set_init(self.bytes_init() + other.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -88,7 +69,7 @@ pub struct BufferedWriter<B, W> {
|
||||
|
||||
impl<B, W> BufferedWriter<B, W>
|
||||
where
|
||||
B: Buffer + IoBuf + Send,
|
||||
B: IoBufMutExt + Send,
|
||||
W: OwnedAsyncWriter,
|
||||
{
|
||||
pub fn new(writer: W, buffer: B) -> Self {
|
||||
@@ -114,38 +95,39 @@ where
|
||||
Ok(writer)
|
||||
}
|
||||
|
||||
/// Borrows the writer's internal buffer.
///
/// # Panics
///
/// Panics with "must not use after we returned an error" when `self.buf` is `None`.
#[inline(always)]
|
||||
fn buf(&self) -> &B {
|
||||
self.buf
|
||||
.as_ref()
|
||||
.expect("must not use after we returned an error")
|
||||
}
|
||||
|
||||
pub async fn write_buffered<S: IoBuf>(&mut self, chunk: Slice<S>) -> std::io::Result<(usize, S)>
|
||||
where
|
||||
S: IoBuf + Send,
|
||||
{
|
||||
let chunk_len = chunk.len();
|
||||
let chunk_len = chunk.bytes_init();
|
||||
// avoid memcpy for the middle of the chunk
|
||||
if chunk.len() >= B::BUFFER_SIZE {
|
||||
if chunk.bytes_init() >= self.buf().bytes_total() {
|
||||
self.flush().await?;
|
||||
// do a big write, bypassing `buf`
|
||||
assert_eq!(
|
||||
self.buf
|
||||
.as_ref()
|
||||
.expect("must not use after an error")
|
||||
.len(),
|
||||
0
|
||||
);
|
||||
assert_eq!(self.buf().bytes_init(), 0);
|
||||
let (nwritten, chunk) = self.writer.write_all(chunk).await?;
|
||||
assert_eq!(nwritten, chunk_len);
|
||||
return Ok((nwritten, chunk));
|
||||
}
|
||||
// in-memory copy the < BUFFER_SIZED tail of the chunk
|
||||
assert!(chunk.len() < B::BUFFER_SIZE);
|
||||
assert!(chunk.len() < self.buf().bytes_total());
|
||||
let mut slice = &chunk[..];
|
||||
while !slice.is_empty() {
|
||||
let buf = self.buf.as_mut().expect("must not use after an error");
|
||||
let need = B::BUFFER_SIZE - buf.len();
|
||||
let need = buf.bytes_total() - buf.bytes_init();
|
||||
let have = slice.len();
|
||||
let n = std::cmp::min(need, have);
|
||||
buf.extend_from_slice(&slice[..n]);
|
||||
slice = &slice[n..];
|
||||
if buf.len() >= B::BUFFER_SIZE {
|
||||
assert_eq!(buf.len(), B::BUFFER_SIZE);
|
||||
if buf.bytes_init() >= buf.bytes_total() {
|
||||
assert_eq!(buf.bytes_init(), buf.bytes_total());
|
||||
self.flush().await?;
|
||||
}
|
||||
}
|
||||
@@ -162,13 +144,13 @@ where
|
||||
let chunk_len = chunk.len();
|
||||
while !chunk.is_empty() {
|
||||
let buf = self.buf.as_mut().expect("must not use after an error");
|
||||
let need = B::BUFFER_SIZE - buf.len();
|
||||
let need = buf.bytes_total() - buf.bytes_init();
|
||||
let have = chunk.len();
|
||||
let n = std::cmp::min(need, have);
|
||||
buf.extend_from_slice(&chunk[..n]);
|
||||
chunk = &chunk[n..];
|
||||
if buf.len() >= B::BUFFER_SIZE {
|
||||
assert_eq!(buf.len(), B::BUFFER_SIZE);
|
||||
if buf.bytes_init() >= buf.bytes_total() {
|
||||
assert_eq!(buf.bytes_init(), buf.bytes_total());
|
||||
self.flush().await?;
|
||||
}
|
||||
}
|
||||
@@ -181,7 +163,7 @@ where
|
||||
self.buf = Some(buf);
|
||||
return std::io::Result::Ok(());
|
||||
}
|
||||
let buf_len = buf.len();
|
||||
let buf_len = buf.bytes_init();
|
||||
let (nwritten, mut buf) = self.writer.write_all(buf).await?;
|
||||
assert_eq!(nwritten, buf_len);
|
||||
buf.clear();
|
||||
@@ -207,6 +189,8 @@ impl OwnedAsyncWriter for Vec<u8> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::BytesMut;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -240,7 +224,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_buffered_writes_only() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new());
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
|
||||
write!(writer, b"a");
|
||||
write!(writer, b"b");
|
||||
write!(writer, b"c");
|
||||
@@ -257,7 +241,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_passthrough_writes_only() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new());
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
|
||||
write!(writer, b"abc");
|
||||
write!(writer, b"de");
|
||||
write!(writer, b"");
|
||||
@@ -273,7 +257,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new());
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
|
||||
write!(writer, b"a");
|
||||
write!(writer, b"bc");
|
||||
write!(writer, b"d");
|
||||
@@ -289,7 +273,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMutBuffer::<2>::new());
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
|
||||
|
||||
writer.write_buffered_borrowed(b"abc").await?;
|
||||
writer.write_buffered_borrowed(b"d").await?;
|
||||
|
||||
Reference in New Issue
Block a user