implement blob-writer IO functionality on top of BufferedWriter

Signed-off-by: Yuchen Liang <yuchen@neon.tech>
Yuchen Liang
2024-12-06 21:50:53 +00:00
parent 0bab7e3086
commit d079bf1d48
10 changed files with 281 additions and 197 deletions

View File

@@ -1084,7 +1084,7 @@ pub mod virtual_file {
impl IoMode {
pub const fn preferred() -> Self {
Self::Buffered
Self::Direct
}
}

View File

@@ -18,16 +18,18 @@ use async_compression::Level;
use bytes::{BufMut, BytesMut};
use pageserver_api::models::ImageCompressionAlgorithm;
use tokio::io::AsyncWriteExt;
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
use tokio_epoll_uring::IoBuf;
use tracing::warn;
use crate::context::RequestContext;
use crate::page_cache::PAGE_SZ;
use crate::tenant::block_io::BlockCursor;
use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt};
use crate::virtual_file::VirtualFile;
use crate::virtual_file::owned_buffers_io::write::BufferedWriter;
use crate::virtual_file::{IoBufferMut, VirtualFile};
use std::cmp::min;
use std::io::{Error, ErrorKind};
use std::sync::Arc;
#[derive(Copy, Clone, Debug)]
pub struct CompressionInfo {
@@ -158,135 +160,160 @@ pub(super) const BYTE_ZSTD: u8 = BYTE_UNCOMPRESSED | 0x10;
/// If a `BlobWriter` is dropped, the internal buffer will be
/// discarded. You need to call [`flush_buffer`](Self::flush_buffer)
/// manually before dropping.
pub struct BlobWriter<const BUFFERED: bool> {
inner: VirtualFile,
offset: u64,
/// A buffer to save on write calls, only used if BUFFERED=true
buf: Vec<u8>,
pub struct BlobWriter {
/// We do tiny writes for the length headers; they need to be in an owned buffer;
io_buf: Option<BytesMut>,
writer: BufferedWriter<IoBufferMut, VirtualFile>,
offset: u64,
}
impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
pub fn new(inner: VirtualFile, start_offset: u64) -> Self {
Self {
inner,
offset: start_offset,
buf: Vec::with_capacity(Self::CAPACITY),
impl BlobWriter {
pub fn new(
file: Arc<VirtualFile>,
start_offset: u64,
gate: &utils::sync::gate::Gate,
ctx: &RequestContext,
) -> anyhow::Result<Self> {
Ok(Self {
io_buf: Some(BytesMut::new()),
}
writer: BufferedWriter::new(
file,
start_offset,
|| IoBufferMut::with_capacity(Self::CAPACITY),
gate.enter()?,
ctx,
),
offset: start_offset,
})
}
pub fn size(&self) -> u64 {
self.offset
}
const CAPACITY: usize = if BUFFERED { 64 * 1024 } else { 0 };
const CAPACITY: usize = 64 * 1024;
/// Writes the given buffer directly to the underlying `VirtualFile`.
/// You need to make sure that the internal buffer is empty, otherwise
/// data will be written in wrong order.
#[inline(always)]
async fn write_all_unbuffered<Buf: IoBuf + Send>(
&mut self,
src_buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), Error>) {
let (src_buf, res) = self.inner.write_all(src_buf, ctx).await;
let nbytes = match res {
Ok(nbytes) => nbytes,
Err(e) => return (src_buf, Err(e)),
};
self.offset += nbytes as u64;
(src_buf, Ok(()))
}
// #[inline(always)]
// async fn write_all_unbuffered<Buf: IoBuf + Send>(
// &mut self,
// src_buf: FullSlice<Buf>,
// ctx: &RequestContext,
// ) -> (FullSlice<Buf>, Result<(), Error>) {
// let (src_buf, res) = self.inner.write_all_at(src_buf, self.offset, ctx).await;
// let nbytes = match res {
// Ok(nbytes) => nbytes,
// Err(e) => return (src_buf, Err(e)),
// };
// self.offset += nbytes as u64;
// (src_buf, Ok(()))
// }
#[inline(always)]
/// Flushes the internal buffer to the underlying `VirtualFile`.
pub async fn flush_buffer(&mut self, ctx: &RequestContext) -> Result<(), Error> {
let buf = std::mem::take(&mut self.buf);
let (slice, res) = self.inner.write_all(buf.slice_len(), ctx).await;
res?;
let mut buf = slice.into_raw_slice().into_inner();
buf.clear();
self.buf = buf;
Ok(())
}
// #[inline(always)]
// /// Flushes the internal buffer to the underlying `VirtualFile`.
// async fn flush_buffer(&mut self, ctx: &RequestContext) -> Result<(), Error> {
// let buf = std::mem::take(&mut self.buf);
// let (slice, res) = self.inner.write_all(buf.slice_len(), ctx).await;
// res?;
// let mut buf = slice.into_raw_slice().into_inner();
// buf.clear();
// self.buf = buf;
// Ok(())
// }
#[inline(always)]
/// Writes as much of `src_buf` into the internal buffer as it fits
fn write_into_buffer(&mut self, src_buf: &[u8]) -> usize {
let remaining = Self::CAPACITY - self.buf.len();
let to_copy = src_buf.len().min(remaining);
self.buf.extend_from_slice(&src_buf[..to_copy]);
self.offset += to_copy as u64;
to_copy
}
// #[inline(always)]
// /// Writes as much of `src_buf` into the internal buffer as it fits
// fn write_into_buffer(&mut self, src_buf: &[u8]) -> usize {
// let remaining = Self::CAPACITY - self.buf.len();
// let to_copy = src_buf.len().min(remaining);
// self.buf.extend_from_slice(&src_buf[..to_copy]);
// self.offset += to_copy as u64;
// to_copy
// }
/// Internal, possibly buffered, write function
async fn write_all<Buf: IoBuf + Send>(
&mut self,
src_buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), Error>) {
let src_buf = src_buf.into_raw_slice();
let src_buf_bounds = src_buf.bounds();
let restore = move |src_buf_slice: Slice<_>| {
FullSlice::must_new(Slice::from_buf_bounds(
src_buf_slice.into_inner(),
src_buf_bounds,
))
};
let res = self
.writer
.write_buffered_borrowed(&src_buf, ctx)
.await
.map(|len| {
self.offset += len as u64;
()
});
if !BUFFERED {
assert!(self.buf.is_empty());
return self
.write_all_unbuffered(FullSlice::must_new(src_buf), ctx)
.await;
}
let remaining = Self::CAPACITY - self.buf.len();
let src_buf_len = src_buf.bytes_init();
if src_buf_len == 0 {
return (restore(src_buf), Ok(()));
}
let mut src_buf = src_buf.slice(0..src_buf_len);
// First try to copy as much as we can into the buffer
if remaining > 0 {
let copied = self.write_into_buffer(&src_buf);
src_buf = src_buf.slice(copied..);
}
// Then, if the buffer is full, flush it out
if self.buf.len() == Self::CAPACITY {
if let Err(e) = self.flush_buffer(ctx).await {
return (restore(src_buf), Err(e));
}
}
// Finally, write the tail of src_buf:
// If it wholly fits into the buffer without
// completely filling it, then put it there.
// If not, write it out directly.
let src_buf = if !src_buf.is_empty() {
assert_eq!(self.buf.len(), 0);
if src_buf.len() < Self::CAPACITY {
let copied = self.write_into_buffer(&src_buf);
// We just verified above that src_buf fits into our internal buffer.
assert_eq!(copied, src_buf.len());
restore(src_buf)
} else {
let (src_buf, res) = self
.write_all_unbuffered(FullSlice::must_new(src_buf), ctx)
.await;
if let Err(e) = res {
return (src_buf, Err(e));
}
src_buf
}
} else {
restore(src_buf)
};
(src_buf, Ok(()))
(src_buf, res)
}
// /// Internal, possibly buffered, write function
// async fn write_all_old<Buf: IoBuf + Send>(
// &mut self,
// src_buf: FullSlice<Buf>,
// ctx: &RequestContext,
// ) -> (FullSlice<Buf>, Result<(), Error>) {
// let src_buf = src_buf.into_raw_slice();
// let src_buf_bounds = src_buf.bounds();
// let restore = move |src_buf_slice: Slice<_>| {
// FullSlice::must_new(Slice::from_buf_bounds(
// src_buf_slice.into_inner(),
// src_buf_bounds,
// ))
// };
// if !BUFFERED {
// assert!(self.buf.is_empty());
// return self
// .write_all_unbuffered(FullSlice::must_new(src_buf), ctx)
// .await;
// }
// let remaining = Self::CAPACITY - self.buf.len();
// let src_buf_len = src_buf.bytes_init();
// if src_buf_len == 0 {
// return (restore(src_buf), Ok(()));
// }
// let mut src_buf = src_buf.slice(0..src_buf_len);
// // First try to copy as much as we can into the buffer
// if remaining > 0 {
// let copied = self.write_into_buffer(&src_buf);
// src_buf = src_buf.slice(copied..);
// }
// // Then, if the buffer is full, flush it out
// if self.buf.len() == Self::CAPACITY {
// if let Err(e) = self.flush_buffer(ctx).await {
// return (restore(src_buf), Err(e));
// }
// }
// // Finally, write the tail of src_buf:
// // If it wholly fits into the buffer without
// // completely filling it, then put it there.
// // If not, write it out directly.
// let src_buf = if !src_buf.is_empty() {
// assert_eq!(self.buf.len(), 0);
// if src_buf.len() < Self::CAPACITY {
// let copied = self.write_into_buffer(&src_buf);
// // We just verified above that src_buf fits into our internal buffer.
// assert_eq!(copied, src_buf.len());
// restore(src_buf)
// } else {
// let (src_buf, res) = self
// .write_all_unbuffered(FullSlice::must_new(src_buf), ctx)
// .await;
// if let Err(e) = res {
// return (src_buf, Err(e));
// }
// src_buf
// }
// } else {
// restore(src_buf)
// };
// (src_buf, Ok(()))
// }
/// Write a blob of data. Returns the offset that it was written to,
/// which can be used to retrieve the data later.
pub async fn write_blob<Buf: IoBuf + Send>(
@@ -308,7 +335,7 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
ctx: &RequestContext,
algorithm: ImageCompressionAlgorithm,
) -> (FullSlice<Buf>, Result<(u64, CompressionInfo), Error>) {
let offset = self.offset;
let offset = self.size();
let mut compression_info = CompressionInfo {
written_compressed: false,
compressed_size: None,
@@ -384,16 +411,15 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
};
(srcbuf, res.map(|_| (offset, compression_info)))
}
}
impl BlobWriter<true> {
/// Access the underlying `VirtualFile`.
///
/// This function flushes the internal buffer before giving access
/// to the underlying `VirtualFile`.
pub async fn into_inner(mut self, ctx: &RequestContext) -> Result<VirtualFile, Error> {
self.flush_buffer(ctx).await?;
Ok(self.inner)
pub async fn into_inner(self, ctx: &RequestContext) -> Result<VirtualFile, Error> {
let (_, file) = self.writer.shutdown(ctx).await?;
Ok(file)
}
/// Access the underlying `VirtualFile`.
@@ -401,14 +427,7 @@ impl BlobWriter<true> {
/// Unlike [`into_inner`](Self::into_inner), this doesn't flush
/// the internal buffer before giving access.
pub fn into_inner_no_flush(self) -> VirtualFile {
self.inner
}
}
impl BlobWriter<false> {
/// Access the underlying `VirtualFile`.
pub fn into_inner(self) -> VirtualFile {
self.inner
self.writer.shutdown_no_flush()
}
}
@@ -420,23 +439,24 @@ pub(crate) mod tests {
use camino_tempfile::Utf8TempDir;
use rand::{Rng, SeedableRng};
async fn round_trip_test<const BUFFERED: bool>(blobs: &[Vec<u8>]) -> Result<(), Error> {
round_trip_test_compressed::<BUFFERED>(blobs, false).await
async fn round_trip_test(blobs: &[Vec<u8>]) -> Result<(), Error> {
round_trip_test_compressed(blobs, false).await
}
pub(crate) async fn write_maybe_compressed<const BUFFERED: bool>(
pub(crate) async fn write_maybe_compressed(
blobs: &[Vec<u8>],
compression: bool,
ctx: &RequestContext,
) -> Result<(Utf8TempDir, Utf8PathBuf, Vec<u64>), Error> {
let temp_dir = camino_tempfile::tempdir()?;
let pathbuf = temp_dir.path().join("file");
let gate = utils::sync::gate::Gate::default();
// Write part (in block to drop the file)
let mut offsets = Vec::new();
{
let file = VirtualFile::create(pathbuf.as_path(), ctx).await?;
let mut wtr = BlobWriter::<BUFFERED>::new(file, 0);
let file = Arc::new(VirtualFile::create_v2(pathbuf.as_path(), ctx).await?);
let mut wtr = BlobWriter::new(file, 0, &gate, ctx).unwrap();
for blob in blobs.iter() {
let (_, res) = if compression {
let res = wtr
@@ -458,20 +478,18 @@ pub(crate) mod tests {
let (_, res) = wtr.write_blob(vec![0; PAGE_SZ].slice_len(), ctx).await;
let offs = res?;
println!("Writing final blob at offs={offs}");
wtr.flush_buffer(ctx).await?;
wtr.into_inner(ctx).await?;
}
Ok((temp_dir, pathbuf, offsets))
}
async fn round_trip_test_compressed<const BUFFERED: bool>(
blobs: &[Vec<u8>],
compression: bool,
) -> Result<(), Error> {
async fn round_trip_test_compressed(blobs: &[Vec<u8>], compression: bool) -> Result<(), Error> {
let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error);
let (_temp_dir, pathbuf, offsets) =
write_maybe_compressed::<BUFFERED>(blobs, compression, &ctx).await?;
write_maybe_compressed(blobs, compression, &ctx).await?;
let file = VirtualFile::open(pathbuf, &ctx).await?;
println!("Done writing!");
let file = VirtualFile::open_v2(pathbuf, &ctx).await?;
let rdr = BlockReaderRef::VirtualFile(&file);
let rdr = BlockCursor::new_with_compression(rdr, compression);
for (idx, (blob, offset)) in blobs.iter().zip(offsets.iter()).enumerate() {
@@ -492,8 +510,7 @@ pub(crate) mod tests {
#[tokio::test]
async fn test_one() -> Result<(), Error> {
let blobs = &[vec![12, 21, 22]];
round_trip_test::<false>(blobs).await?;
round_trip_test::<true>(blobs).await?;
round_trip_test(blobs).await?;
Ok(())
}
@@ -505,10 +522,8 @@ pub(crate) mod tests {
Vec::new(),
b"foobar".to_vec(),
];
round_trip_test::<false>(blobs).await?;
round_trip_test::<true>(blobs).await?;
round_trip_test_compressed::<false>(blobs, true).await?;
round_trip_test_compressed::<true>(blobs, true).await?;
round_trip_test(blobs).await?;
round_trip_test_compressed(blobs, true).await?;
Ok(())
}
@@ -522,10 +537,8 @@ pub(crate) mod tests {
vec![0xf3; 24 * PAGE_SZ],
b"foobar".to_vec(),
];
round_trip_test::<false>(blobs).await?;
round_trip_test::<true>(blobs).await?;
round_trip_test_compressed::<false>(blobs, true).await?;
round_trip_test_compressed::<true>(blobs, true).await?;
round_trip_test(blobs).await?;
round_trip_test_compressed(blobs, true).await?;
Ok(())
}
@@ -534,8 +547,7 @@ pub(crate) mod tests {
let blobs = (0..PAGE_SZ / 8)
.map(|v| random_array(v * 16))
.collect::<Vec<_>>();
round_trip_test::<false>(&blobs).await?;
round_trip_test::<true>(&blobs).await?;
round_trip_test(&blobs).await?;
Ok(())
}
@@ -552,8 +564,7 @@ pub(crate) mod tests {
random_array(sz.into())
})
.collect::<Vec<_>>();
round_trip_test::<false>(&blobs).await?;
round_trip_test::<true>(&blobs).await?;
round_trip_test(&blobs).await?;
Ok(())
}
@@ -564,8 +575,7 @@ pub(crate) mod tests {
random_array(PAGE_SZ - 4),
random_array(PAGE_SZ - 4),
];
round_trip_test::<false>(blobs).await?;
round_trip_test::<true>(blobs).await?;
round_trip_test(blobs).await?;
Ok(())
}
}
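The round-trip tests above exercise the new API end to end: construct the writer over an Arc&lt;VirtualFile&gt; with a gate guard, write blobs, then flush and finish via into_inner. A minimal usage sketch along the same lines (module paths, create_v2, and slice_len follow what appears in this diff and the surrounding crate layout; treat it as illustrative, not the exact final API):

    use std::sync::Arc;

    use crate::context::RequestContext;
    use crate::tenant::blob_io::BlobWriter;
    use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
    use crate::virtual_file::VirtualFile;

    // Illustrative helper; not part of this commit.
    async fn write_two_blobs(
        path: &camino::Utf8Path,
        ctx: &RequestContext,
    ) -> anyhow::Result<(u64, u64)> {
        // The file is held in an Arc because the background flush task needs shared access.
        let file = Arc::new(VirtualFile::create_v2(path, ctx).await?);
        // new() enters the gate, so shutdown can wait for in-flight flushes.
        let gate = utils::sync::gate::Gate::default();
        let mut wtr = BlobWriter::new(file, 0, &gate, ctx)?;

        let (_, res) = wtr.write_blob(b"first".to_vec().slice_len(), ctx).await;
        let off1 = res?;
        let (_, res) = wtr.write_blob(b"second".to_vec().slice_len(), ctx).await;
        let off2 = res?;

        // into_inner() now drives BufferedWriter::shutdown: the tail is flushed
        // by the background task and the VirtualFile is handed back.
        let _file = wtr.into_inner(ctx).await?;
        Ok((off1, off2))
    }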

View File

@@ -72,6 +72,7 @@ impl EphemeralFile {
bytes_written: 0,
buffered_writer: owned_buffers_io::write::BufferedWriter::new(
file,
0,
|| IoBufferMut::with_capacity(TAIL_SZ),
gate.enter()?,
ctx,
@@ -180,7 +181,7 @@ impl super::storage_layer::inmemory_layer::vectored_dio_read::File for Ephemeral
dst: tokio_epoll_uring::Slice<B>,
ctx: &'a RequestContext,
) -> std::io::Result<(tokio_epoll_uring::Slice<B>, usize)> {
let submitted_offset = self.buffered_writer.bytes_submitted();
let submitted_offset = self.buffered_writer.submit_offset();
let mutable = self.buffered_writer.inspect_mutable();
let mutable = &mutable[0..mutable.pending()];

View File

@@ -227,6 +227,7 @@ async fn download_object<'a>(
let mut buffered = owned_buffers_io::write::BufferedWriter::<IoBufferMut, _>::new(
destination_file,
0,
|| IoBufferMut::with_capacity(super::BUFFER_SIZE),
gate.enter().map_err(|_| DownloadError::Cancelled)?,
ctx,
@@ -244,7 +245,7 @@ async fn download_object<'a>(
};
buffered.write_buffered_borrowed(&chunk, ctx).await?;
}
let inner = buffered.flush_and_into_inner(ctx).await?;
let inner = buffered.shutdown(ctx).await?;
Ok(inner)
}
.await?;

View File

@@ -392,7 +392,7 @@ struct DeltaLayerWriterInner {
tree: DiskBtreeBuilder<BlockBuf, DELTA_KEY_SIZE>,
blob_writer: BlobWriter<true>,
blob_writer: BlobWriter,
// Number of key-lsns in the layer.
num_keys: usize,
@@ -419,10 +419,13 @@ impl DeltaLayerWriterInner {
let path =
DeltaLayer::temp_path_for(conf, &tenant_shard_id, &timeline_id, key_start, &lsn_range);
let mut file = VirtualFile::create(&path, ctx).await?;
// make room for the header block
file.seek(SeekFrom::Start(PAGE_SZ as u64)).await?;
let blob_writer = BlobWriter::new(file, PAGE_SZ as u64);
let file = Arc::new(VirtualFile::create(&path, ctx).await?);
// FIXME(yuchen): propagate &gate from parent
let gate = utils::sync::gate::Gate::default();
// Start at PAGE_SZ, make room for the header block
let blob_writer = BlobWriter::new(file, PAGE_SZ as u64, &gate, ctx)?;
// Initialize the b-tree index builder
let block_buf = BlockBuf::new();

View File

@@ -62,6 +62,7 @@ use std::io::SeekFrom;
use std::ops::Range;
use std::os::unix::prelude::FileExt;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::OnceCell;
use tokio_stream::StreamExt;
use tracing::*;
@@ -724,7 +725,7 @@ struct ImageLayerWriterInner {
// Number of keys in the layer.
num_keys: usize,
blob_writer: BlobWriter<false>,
blob_writer: BlobWriter,
tree: DiskBtreeBuilder<BlockBuf, KEY_SIZE>,
#[cfg(feature = "testing")]
@@ -755,19 +756,24 @@ impl ImageLayerWriterInner {
},
);
trace!("creating image layer {}", path);
let mut file = {
VirtualFile::open_with_options(
&path,
virtual_file::OpenOptions::new()
.write(true)
.create_new(true),
ctx,
let file = {
Arc::new(
VirtualFile::open_with_options(
&path,
virtual_file::OpenOptions::new()
.write(true)
.create_new(true),
ctx,
)
.await?,
)
.await?
};
// make room for the header block
file.seek(SeekFrom::Start(PAGE_SZ as u64)).await?;
let blob_writer = BlobWriter::new(file, PAGE_SZ as u64);
// FIXME(yuchen): propagate &gate from parent
let gate = utils::sync::gate::Gate::default();
// Start at `PAGE_SZ` to make room for the header block.
let blob_writer = BlobWriter::new(file, PAGE_SZ as u64, &gate, ctx)?;
// Initialize the b-tree index builder
let block_buf = BlockBuf::new();
@@ -873,7 +879,7 @@ impl ImageLayerWriterInner {
crate::metrics::COMPRESSION_IMAGE_INPUT_BYTES_CHOSEN.inc_by(self.uncompressed_bytes_chosen);
crate::metrics::COMPRESSION_IMAGE_OUTPUT_BYTES.inc_by(compressed_size);
let mut file = self.blob_writer.into_inner();
let mut file = self.blob_writer.into_inner(ctx).await?;
// Write out the index
file.seek(SeekFrom::Start(index_start_blk as u64 * PAGE_SZ as u64))
@@ -1038,7 +1044,7 @@ impl ImageLayerWriter {
impl Drop for ImageLayerWriter {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
inner.blob_writer.into_inner().remove();
inner.blob_writer.into_inner_no_flush().remove();
}
}
}

View File

@@ -910,9 +910,9 @@ mod tests {
async fn round_trip_test_compressed(blobs: &[Vec<u8>], compression: bool) -> Result<(), Error> {
let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error);
let (_temp_dir, pathbuf, offsets) =
write_maybe_compressed::<true>(blobs, compression, &ctx).await?;
write_maybe_compressed(blobs, compression, &ctx).await?;
let file = VirtualFile::open(&pathbuf, &ctx).await?;
let file = VirtualFile::open_v2(&pathbuf, &ctx).await?;
let file_len = std::fs::metadata(&pathbuf)?.len();
// Multiply by two (compressed data might need more space), and add a few bytes for the header

View File

@@ -975,7 +975,10 @@ impl VirtualFileInner {
) -> (FullSlice<B>, Result<usize, Error>) {
let file_guard = match self.lock_file().await {
Ok(file_guard) => file_guard,
Err(e) => return (buf, Err(e)),
Err(e) => {
println!("ERRORED :(");
return (buf, Err(e));
}
};
observe_duration!(StorageIoOperation::Write, {
let ((_file_guard, buf), result) =
@@ -1328,8 +1331,16 @@ impl OwnedAsyncWriter for VirtualFile {
offset: u64,
ctx: &RequestContext,
) -> std::io::Result<FullSlice<Buf>> {
println!(
"offset={offset}, buf={:?}, buflen={}",
buf.as_ptr(),
buf.len()
);
assert_eq!(offset % 512, 0);
assert_eq!(buf.as_ptr().align_offset(512), 0);
let (buf, res) = VirtualFile::write_all_at(self, buf, offset, ctx).await;
res.map(|_| buf)
let x = res.map(|_| buf).unwrap();
Ok(x)
}
}
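These assertions pair with the IoMode::preferred() change at the top of this commit: with direct IO preferred, every write_all_at issued by the buffered writer is expected to land at a 512-byte-aligned offset with a 512-byte-aligned buffer. A small sketch of the invariant being checked (the helper name is illustrative; 512 matches the constant asserted here, while actual O_DIRECT requirements depend on the filesystem and device):

    // Illustrative helper mirroring the two asserts above; not part of this commit.
    fn dio_aligned(offset: u64, buf: &[u8]) -> bool {
        offset % 512 == 0 && buf.as_ptr().align_offset(512) == 0
    }

The offset half of this invariant is why BufferedWriter::shutdown (in the write.rs changes below) zero-pads the final partial buffer to a 512-byte multiple.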

View File

@@ -1,6 +1,7 @@
mod flush;
use std::sync::Arc;
use bytes::BufMut;
use flush::FlushHandle;
use tokio_epoll_uring::IoBuf;
@@ -54,8 +55,8 @@ pub struct BufferedWriter<B: Buffer, W> {
mutable: Option<B>,
/// A handle to the background flush task for writing data to disk.
flush_handle: FlushHandle<B::IoBuf, W>,
/// The number of bytes submitted to the background task.
bytes_submitted: u64,
/// The next offset to be submitted to the background task.
submit_offset: u64,
}
impl<B, Buf, W> BufferedWriter<B, W>
@@ -69,6 +70,7 @@ where
/// The `buf_new` function provides a way to initialize the owned buffers used by this writer.
pub fn new(
writer: Arc<W>,
start_offset: u64,
buf_new: impl Fn() -> B,
gate_guard: utils::sync::gate::GateGuard,
ctx: &RequestContext,
@@ -82,7 +84,7 @@ where
gate_guard,
ctx.attached_child(),
),
bytes_submitted: 0,
submit_offset: start_offset,
}
}
@@ -91,8 +93,8 @@ where
}
/// Returns the offset up to which data has been submitted to the background flush task.
pub fn bytes_submitted(&self) -> u64 {
self.bytes_submitted
pub fn submit_offset(&self) -> u64 {
self.submit_offset
}
/// Panics if used after any of the write paths returned an error
@@ -107,24 +109,42 @@ where
}
#[cfg_attr(target_os = "macos", allow(dead_code))]
pub async fn flush_and_into_inner(
mut self,
ctx: &RequestContext,
) -> std::io::Result<(u64, Arc<W>)> {
self.flush(ctx).await?;
pub async fn shutdown(mut self, ctx: &RequestContext) -> std::io::Result<(u64, W)> {
let buf = self.mutable_mut();
if buf.pending() < buf.cap() {
let count = buf.pending().next_multiple_of(512) - buf.pending();
buf.extend_with(0, count);
}
if let Some(control) = self.flush(ctx).await? {
control.release().await;
}
let Self {
mutable: buf,
writer,
mut flush_handle,
bytes_submitted: bytes_amount,
submit_offset: bytes_amount,
} = self;
flush_handle.shutdown().await?;
assert!(buf.is_some());
let writer = Arc::into_inner(writer).expect("writer is the only strong reference");
Ok((bytes_amount, writer))
}
/// Gets a reference to the mutable in-memory buffer.
#[cfg_attr(target_os = "macos", allow(dead_code))]
pub fn shutdown_no_flush(self) -> W {
let Self {
mutable: _,
writer,
flush_handle,
submit_offset: _,
} = self;
flush_handle.shutdown_no_flush();
let writer = Arc::into_inner(writer).expect("writer is the only strong reference");
writer
}
/// Gets an immutable reference to the mutable in-memory buffer.
#[inline(always)]
fn mutable(&self) -> &B {
self.mutable
@@ -132,6 +152,14 @@ where
.expect("must not use after we returned an error")
}
/// Gets a mutable reference to the mutable in-memory buffer.
#[inline(always)]
fn mutable_mut(&mut self) -> &mut B {
self.mutable
.as_mut()
.expect("must not use after we returned an error")
}
pub async fn write_buffered_borrowed(
&mut self,
chunk: &[u8],
@@ -153,7 +181,7 @@ where
let chunk_len = chunk.len();
let mut control: Option<FlushControl> = None;
while !chunk.is_empty() {
let buf = self.mutable.as_mut().expect("must not use after an error");
let buf = self.mutable_mut();
let need = buf.cap() - buf.pending();
let have = chunk.len();
let n = std::cmp::min(need, have);
@@ -178,8 +206,8 @@ where
self.mutable = Some(buf);
return Ok(None);
}
let (recycled, flush_control) = self.flush_handle.flush(buf, self.bytes_submitted).await?;
self.bytes_submitted += u64::try_from(buf_len).unwrap();
let (recycled, flush_control) = self.flush_handle.flush(buf, self.submit_offset).await?;
self.submit_offset += u64::try_from(buf_len).unwrap();
self.mutable = Some(recycled);
Ok(Some(flush_control))
}
@@ -197,6 +225,10 @@ pub trait Buffer {
/// panics if `other.len() > self.cap() - self.pending()`.
fn extend_from_slice(&mut self, other: &[u8]);
/// Adds `count` copies of byte `val` to `self`.
/// Panics if `count > self.cap() - self.pending()`.
fn extend_with(&mut self, val: u8, count: usize);
/// Number of bytes in the buffer.
fn pending(&self) -> usize;
@@ -224,6 +256,14 @@ impl Buffer for IoBufferMut {
IoBufferMut::extend_from_slice(self, other);
}
fn extend_with(&mut self, val: u8, count: usize) {
if self.len() + count > self.cap() {
panic!("Buffer capacity exceeded");
}
IoBufferMut::put_bytes(self, val, count);
}
fn pending(&self) -> usize {
self.len()
}
@@ -295,6 +335,7 @@ mod tests {
let gate = utils::sync::gate::Gate::default();
let mut writer = BufferedWriter::<_, RecorderWriter>::new(
recorder,
0,
|| IoBufferMut::with_capacity(2),
gate.enter()?,
ctx,
@@ -309,7 +350,7 @@ mod tests {
writer.write_buffered_borrowed(b"j", ctx).await?;
writer.write_buffered_borrowed(b"klmno", ctx).await?;
let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
let (_, recorder) = writer.shutdown(ctx).await?;
assert_eq!(
recorder.get_writes(),
{
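shutdown() pads the final partial buffer with zeros up to the next 512-byte multiple before the last flush, so the tail write still satisfies the alignment asserts added in virtual_file.rs. A worked example of the padding arithmetic (hypothetical helper and test names; 512 matches the constant used above):

    // Sketch of the padding computed in BufferedWriter::shutdown; not part of this commit.
    fn zero_padding_for(pending: usize) -> usize {
        pending.next_multiple_of(512) - pending
    }

    #[test]
    fn padding_examples() {
        assert_eq!(zero_padding_for(0), 0);         // nothing buffered
        assert_eq!(zero_padding_for(512), 0);       // already aligned
        assert_eq!(zero_padding_for(7000), 168);    // rounds up to 7168 = 14 * 512
        assert_eq!(zero_padding_for(64 * 1024), 0); // a full 64 KiB buffer needs no padding
    }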

View File

@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{marker::PhantomData, sync::Arc};
use utils::sync::duplex;
@@ -22,7 +22,9 @@ pub struct FlushHandleInner<Buf, W> {
/// and receives recycled buffers.
channel: duplex::mpsc::Duplex<FlushRequest<Buf>, FullSlice<Buf>>,
/// Join handle for the background flush task.
join_handle: tokio::task::JoinHandle<std::io::Result<Arc<W>>>,
join_handle: tokio::task::JoinHandle<std::io::Result<()>>,
_phantom: PhantomData<W>,
}
struct FlushRequest<Buf> {
@@ -137,6 +139,7 @@ where
inner: Some(FlushHandleInner {
channel: front,
join_handle,
_phantom: PhantomData,
}),
maybe_flushed: None,
}
@@ -176,7 +179,7 @@ where
Ok((recycled, flush_control))
}
async fn handle_error<T>(&mut self) -> std::io::Result<T> {
pub(super) async fn handle_error<T>(&mut self) -> std::io::Result<T> {
Err(self
.shutdown()
.await
@@ -184,7 +187,7 @@ where
}
/// Cleans up the channel and joins the flush task.
pub async fn shutdown(&mut self) -> std::io::Result<Arc<W>> {
pub async fn shutdown(&mut self) -> std::io::Result<()> {
let handle = self
.inner
.take()
@@ -193,6 +196,14 @@ where
handle.join_handle.await.unwrap()
}
pub fn shutdown_no_flush(mut self) {
let handle = self
.inner
.take()
.expect("must not use after we returned an error");
handle.join_handle.abort();
}
/// Gets a mutable reference to the inner handle. Panics if [`Self::inner`] is `None`.
/// This only happens if the handle is used after an error.
fn inner_mut(&mut self) -> &mut FlushHandleInner<Buf, W> {
@@ -236,7 +247,7 @@ where
/// Runs the background flush task.
/// The passed in slice is immediately sent back to the flush handle through the duplex channel.
async fn run(mut self, slice: FullSlice<Buf>) -> std::io::Result<Arc<W>> {
async fn run(mut self, slice: FullSlice<Buf>) -> std::io::Result<()> {
// Sends the extra buffer back to the handle.
self.channel.send(slice).await.map_err(|_| {
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "flush handle closed early")
@@ -272,8 +283,8 @@ where
continue;
}
}
Ok(self.writer)
drop(self);
Ok(())
}
}
@@ -308,7 +319,7 @@ impl FlushNotStarted {
impl FlushInProgress {
/// Waits until background flush is done.
pub async fn wait_until_flush_is_done(self) -> FlushDone {
self.done_flush_rx.await.unwrap();
let _ = self.done_flush_rx.await;
FlushDone
}
}