From 727c333831c49d16ba7887f1ea70f932c91b28e1 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 19 May 2025 22:38:09 +0100 Subject: [PATCH] remove total byte copy amount --- proxy/src/proxy/copy_bidirectional.rs | 88 +++++++++++++++------------ 1 file changed, 49 insertions(+), 39 deletions(-) diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index d89397508f..801720e34d 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -10,8 +10,8 @@ use crate::metrics::Direction; enum TransferState { Running(CopyBuffer, Direction), - ShuttingDown(u64), - Done(u64), + ShuttingDown, + Done, } #[derive(Debug)] @@ -47,7 +47,7 @@ fn transfer_one_direction( f: &mut impl for<'a> FnMut(Direction, &'a [u8]), r: &mut A, w: &mut B, -) -> Poll> +) -> Poll> where A: AsyncRead + AsyncWrite + Unpin + ?Sized, B: AsyncRead + AsyncWrite + Unpin + ?Sized, @@ -57,14 +57,14 @@ where loop { match state { TransferState::Running(buf, dir) => { - let count = ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?; - *state = TransferState::ShuttingDown(count); + ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?; + *state = TransferState::ShuttingDown; } - TransferState::ShuttingDown(count) => { + TransferState::ShuttingDown => { ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?; - *state = TransferState::Done(*count); + *state = TransferState::Done; } - TransferState::Done(count) => return Poll::Ready(Ok(*count)), + TransferState::Done => return Poll::Ready(Ok(())), } } } @@ -73,7 +73,7 @@ pub async fn copy_bidirectional_client_compute( client: &mut Client, compute: &mut Compute, mut f: impl for<'a> FnMut(Direction, &'a [u8]), -) -> Result<(u64, u64), ErrorSource> +) -> Result<(), ErrorSource> where Client: AsyncRead + AsyncWrite + Unpin + ?Sized, Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, @@ -94,11 +94,11 @@ where // TODO: 1 info log, with a enum label for close direction. // Early termination checks from compute to client. - if let TransferState::Done(_) = compute_to_client { - if let TransferState::Running(buf, _) = &client_to_compute { + if let TransferState::Done = compute_to_client { + if let TransferState::Running(..) = &client_to_compute { info!("Compute is done, terminate client"); // Initiate shutdown - client_to_compute = TransferState::ShuttingDown(buf.amt); + client_to_compute = TransferState::ShuttingDown; client_to_compute_result = transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute) .map_err(ErrorSource::from_client)?; @@ -106,22 +106,21 @@ where } // Early termination checks from client to compute. - if let TransferState::Done(_) = client_to_compute { - if let TransferState::Running(buf, _) = &compute_to_client { + if let TransferState::Done = client_to_compute { + if let TransferState::Running(..) = &compute_to_client { info!("Client is done, terminate compute"); // Initiate shutdown - compute_to_client = TransferState::ShuttingDown(buf.amt); + compute_to_client = TransferState::ShuttingDown; compute_to_client_result = transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client) .map_err(ErrorSource::from_compute)?; } } - // It is not a problem if ready! returns early ... (comment remains the same) - let client_to_compute = ready!(client_to_compute_result); - let compute_to_client = ready!(compute_to_client_result); + ready!(client_to_compute_result); + ready!(compute_to_client_result); - Poll::Ready(Ok((client_to_compute, compute_to_client))) + Poll::Ready(Ok(())) }) .await } @@ -132,7 +131,6 @@ pub(super) struct CopyBuffer { need_flush: bool, pos: usize, cap: usize, - amt: u64, buf: Box<[u8]>, } const DEFAULT_BUF_SIZE: usize = 1024; @@ -144,7 +142,6 @@ impl CopyBuffer { need_flush: false, pos: 0, cap: 0, - amt: 0, buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(), } } @@ -205,7 +202,7 @@ impl CopyBuffer { mut f: impl for<'a> FnMut(&'a [u8]), mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, - ) -> Poll> + ) -> Poll> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized, @@ -245,7 +242,6 @@ impl CopyBuffer { )))); } self.pos += i; - self.amt += i as u64; self.need_flush = true; } @@ -265,7 +261,7 @@ impl CopyBuffer { // data and finish the transfer. if self.read_done { ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?; - return Poll::Ready(Ok(self.amt)); + return Poll::Ready(Ok(())); } } } @@ -273,7 +269,7 @@ impl CopyBuffer { #[cfg(test)] mod tests { - use tokio::io::AsyncWriteExt; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::*; @@ -288,15 +284,22 @@ mod tests { compute_client.write_all(b"Neon").await.unwrap(); compute_client.shutdown().await.unwrap(); - let result = - copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {}) - .await - .unwrap(); + copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {}) + .await + .unwrap(); + + drop(client_proxy); + drop(compute_proxy); // Assert correct transferred amounts - let (client_to_compute_count, compute_to_client_count) = result; - assert_eq!(client_to_compute_count, 5); // 'hello' was transferred - assert_eq!(compute_to_client_count, 4); // response only partially transferred or not at all + let mut client_recv = vec![]; + client_client.read_buf(&mut client_recv).await.unwrap(); + + let mut compute_recv = vec![]; + compute_client.read_buf(&mut compute_recv).await.unwrap(); + + assert_eq!(compute_recv, b"hello"); + assert_eq!(client_recv, b"Neon"); } #[tokio::test] @@ -312,14 +315,21 @@ mod tests { .await .unwrap(); - let result = - copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {}) - .await - .unwrap(); + copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {}) + .await + .unwrap(); + + drop(client_proxy); + drop(compute_proxy); // Assert correct transferred amounts - let (client_to_compute_count, compute_to_client_count) = result; - assert_eq!(compute_to_client_count, 5); // 'hello' was transferred - assert!(client_to_compute_count <= 8); // response only partially transferred or not at all + let mut client_recv = vec![]; + client_client.read_buf(&mut client_recv).await.unwrap(); + + let mut compute_recv = vec![]; + compute_client.read_buf(&mut compute_recv).await.unwrap(); + + assert_eq!(client_recv, b"hello"); + assert_eq!(compute_recv, b"Neon Ser"); } }