change OwnedAsyncWriter trait to use write_all_at

Signed-off-by: Yuchen Liang <yuchen@neon.tech>
This commit is contained in:
Yuchen Liang
2024-11-08 15:09:33 +00:00
parent dd1c45e896
commit 224cbb4025
2 changed files with 50 additions and 38 deletions

View File

@@ -1295,14 +1295,14 @@ impl Drop for VirtualFileInner {
}
impl OwnedAsyncWriter for VirtualFile {
#[inline(always)]
async fn write_all<Buf: IoBuf + Send>(
&mut self,
async fn write_all_at<Buf: IoBuf + Send>(
&self,
buf: FullSlice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> std::io::Result<(usize, FullSlice<Buf>)> {
let (buf, res) = VirtualFile::write_all(self, buf, ctx).await;
res.map(move |v| (v, buf))
) -> std::io::Result<FullSlice<Buf>> {
let (buf, res) = VirtualFile::write_all_at(self, buf, offset, ctx).await;
res.map(|_| buf)
}
}
@@ -1560,6 +1560,7 @@ mod tests {
&ctx,
)
.await?;
file_a
.write_all(b"foobar".to_vec().slice_len(), &ctx)
.await?;

View File

@@ -8,11 +8,12 @@ use super::io_buf_ext::{FullSlice, IoBufExt};
/// A trait for doing owned-buffer write IO.
/// Think [`tokio::io::AsyncWrite`] but with owned buffers.
pub trait OwnedAsyncWriter {
async fn write_all<Buf: IoBuf + Send>(
&mut self,
async fn write_all_at<Buf: IoBuf + Send>(
&self,
buf: FullSlice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> std::io::Result<(usize, FullSlice<Buf>)>;
) -> std::io::Result<FullSlice<Buf>>;
}
/// A wrapper aorund an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
@@ -109,13 +110,12 @@ where
.pending(),
0
);
let (nwritten, chunk) = self
let chunk = self
.writer
.write_all(FullSlice::must_new(chunk), ctx)
.write_all_at(FullSlice::must_new(chunk), self.bytes_amount, ctx)
.await?;
self.bytes_amount += u64::try_from(nwritten).unwrap();
assert_eq!(nwritten, chunk_len);
return Ok((nwritten, chunk));
self.bytes_amount += u64::try_from(chunk_len).unwrap();
return Ok((chunk_len, chunk));
}
// in-memory copy the < BUFFER_SIZED tail of the chunk
assert!(chunk.len() < self.buf().cap());
@@ -170,9 +170,11 @@ where
return Ok(());
}
let slice = buf.flush();
let (nwritten, slice) = self.writer.write_all(slice, ctx).await?;
self.bytes_amount += u64::try_from(nwritten).unwrap();
assert_eq!(nwritten, buf_len);
let slice = self
.writer
.write_all_at(slice, self.bytes_amount, ctx)
.await?;
self.bytes_amount += u64::try_from(buf_len).unwrap();
self.buf = Some(Buffer::reuse_after_flush(
slice.into_raw_slice().into_inner(),
));
@@ -231,19 +233,10 @@ impl Buffer for BytesMut {
}
}
impl OwnedAsyncWriter for Vec<u8> {
async fn write_all<Buf: IoBuf + Send>(
&mut self,
buf: FullSlice<Buf>,
_: &RequestContext,
) -> std::io::Result<(usize, FullSlice<Buf>)> {
self.extend_from_slice(&buf[..]);
Ok((buf.len(), buf))
}
}
#[cfg(test)]
mod tests {
use std::sync::Mutex;
use bytes::BytesMut;
use super::*;
@@ -252,16 +245,34 @@ mod tests {
#[derive(Default)]
struct RecorderWriter {
writes: Vec<Vec<u8>>,
/// record bytes and write offsets.
writes: Mutex<Vec<(Vec<u8>, u64)>>,
}
impl RecorderWriter {
/// Gets recorded bytes and write offsets.
fn get_writes(&self) -> Vec<Vec<u8>> {
self.writes
.lock()
.unwrap()
.iter()
.map(|(buf, _)| buf.clone())
.collect()
}
}
impl OwnedAsyncWriter for RecorderWriter {
async fn write_all<Buf: IoBuf + Send>(
&mut self,
async fn write_all_at<Buf: IoBuf + Send>(
&self,
buf: FullSlice<Buf>,
offset: u64,
_: &RequestContext,
) -> std::io::Result<(usize, FullSlice<Buf>)> {
self.writes.push(Vec::from(&buf[..]));
Ok((buf.len(), buf))
) -> std::io::Result<FullSlice<Buf>> {
self.writes
.lock()
.unwrap()
.push((Vec::from(&buf[..]), offset));
Ok(buf)
}
}
@@ -288,7 +299,7 @@ mod tests {
write!(writer, b"e");
let (_, recorder) = writer.flush_and_into_inner(&test_ctx()).await?;
assert_eq!(
recorder.writes,
recorder.get_writes(),
vec![Vec::from(b"ab"), Vec::from(b"cd"), Vec::from(b"e")]
);
Ok(())
@@ -304,7 +315,7 @@ mod tests {
write!(writer, b"fghijk");
let (_, recorder) = writer.flush_and_into_inner(&test_ctx()).await?;
assert_eq!(
recorder.writes,
recorder.get_writes(),
vec![Vec::from(b"abc"), Vec::from(b"de"), Vec::from(b"fghijk")]
);
Ok(())
@@ -320,7 +331,7 @@ mod tests {
write!(writer, b"e");
let (_, recorder) = writer.flush_and_into_inner(&test_ctx()).await?;
assert_eq!(
recorder.writes,
recorder.get_writes(),
vec![Vec::from(b"a"), Vec::from(b"bc"), Vec::from(b"de")]
);
Ok(())
@@ -343,7 +354,7 @@ mod tests {
let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
assert_eq!(
recorder.writes,
recorder.get_writes(),
{
let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
expect