remove total byte copy amount

This commit is contained in:
Conrad Ludgate
2025-05-19 22:38:09 +01:00
parent 14312f1a9a
commit 727c333831

View File

@@ -10,8 +10,8 @@ use crate::metrics::Direction;
enum TransferState {
Running(CopyBuffer, Direction),
ShuttingDown(u64),
Done(u64),
ShuttingDown,
Done,
}
#[derive(Debug)]
@@ -47,7 +47,7 @@ fn transfer_one_direction<A, B>(
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
r: &mut A,
w: &mut B,
) -> Poll<Result<u64, ErrorDirection>>
) -> Poll<Result<(), ErrorDirection>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
@@ -57,14 +57,14 @@ where
loop {
match state {
TransferState::Running(buf, dir) => {
let count = ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown;
}
TransferState::ShuttingDown(count) => {
TransferState::ShuttingDown => {
ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?;
*state = TransferState::Done(*count);
*state = TransferState::Done;
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
TransferState::Done => return Poll::Ready(Ok(())),
}
}
}
@@ -73,7 +73,7 @@ pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,
mut f: impl for<'a> FnMut(Direction, &'a [u8]),
) -> Result<(u64, u64), ErrorSource>
) -> Result<(), ErrorSource>
where
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
@@ -94,11 +94,11 @@ where
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf, _) = &client_to_compute {
if let TransferState::Done = compute_to_client {
if let TransferState::Running(..) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute = TransferState::ShuttingDown;
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)
.map_err(ErrorSource::from_client)?;
@@ -106,22 +106,21 @@ where
}
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf, _) = &compute_to_client {
if let TransferState::Done = client_to_compute {
if let TransferState::Running(..) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client = TransferState::ShuttingDown;
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)
.map_err(ErrorSource::from_compute)?;
}
}
// It is not a problem if ready! returns early ... (comment remains the same)
let client_to_compute = ready!(client_to_compute_result);
let compute_to_client = ready!(compute_to_client_result);
ready!(client_to_compute_result);
ready!(compute_to_client_result);
Poll::Ready(Ok((client_to_compute, compute_to_client)))
Poll::Ready(Ok(()))
})
.await
}
@@ -132,7 +131,6 @@ pub(super) struct CopyBuffer {
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}
const DEFAULT_BUF_SIZE: usize = 1024;
@@ -144,7 +142,6 @@ impl CopyBuffer {
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
}
}
@@ -205,7 +202,7 @@ impl CopyBuffer {
mut f: impl for<'a> FnMut(&'a [u8]),
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<u64, ErrorDirection>>
) -> Poll<Result<(), ErrorDirection>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
@@ -245,7 +242,6 @@ impl CopyBuffer {
))));
}
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
@@ -265,7 +261,7 @@ impl CopyBuffer {
// data and finish the transfer.
if self.read_done {
ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?;
return Poll::Ready(Ok(self.amt));
return Poll::Ready(Ok(()));
}
}
}
@@ -273,7 +269,7 @@ impl CopyBuffer {
#[cfg(test)]
mod tests {
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::*;
@@ -288,15 +284,22 @@ mod tests {
compute_client.write_all(b"Neon").await.unwrap();
compute_client.shutdown().await.unwrap();
let result =
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
.await
.unwrap();
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
.await
.unwrap();
drop(client_proxy);
drop(compute_proxy);
// Assert correct transferred amounts
let (client_to_compute_count, compute_to_client_count) = result;
assert_eq!(client_to_compute_count, 5); // 'hello' was transferred
assert_eq!(compute_to_client_count, 4); // response only partially transferred or not at all
let mut client_recv = vec![];
client_client.read_buf(&mut client_recv).await.unwrap();
let mut compute_recv = vec![];
compute_client.read_buf(&mut compute_recv).await.unwrap();
assert_eq!(compute_recv, b"hello");
assert_eq!(client_recv, b"Neon");
}
#[tokio::test]
@@ -312,14 +315,21 @@ mod tests {
.await
.unwrap();
let result =
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
.await
.unwrap();
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
.await
.unwrap();
drop(client_proxy);
drop(compute_proxy);
// Assert correct transferred amounts
let (client_to_compute_count, compute_to_client_count) = result;
assert_eq!(compute_to_client_count, 5); // 'hello' was transferred
assert!(client_to_compute_count <= 8); // response only partially transferred or not at all
let mut client_recv = vec![];
client_client.read_buf(&mut client_recv).await.unwrap();
let mut compute_recv = vec![];
compute_client.read_buf(&mut compute_recv).await.unwrap();
assert_eq!(client_recv, b"hello");
assert_eq!(compute_recv, b"Neon Ser");
}
}