From 29c4db16588dbc485310d44730fb784bef0c778a Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 19 May 2025 22:45:28 +0100 Subject: [PATCH] update error handling --- proxy/src/proxy/copy_bidirectional.rs | 123 ++++++++++++-------------- 1 file changed, 56 insertions(+), 67 deletions(-) diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 801720e34d..fa0d2d76ed 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -9,45 +9,40 @@ use tracing::info; use crate::metrics::Direction; enum TransferState { - Running(CopyBuffer, Direction), - ShuttingDown, + Running(CopyBuffer), + ShuttingDown(Direction), Done, } -#[derive(Debug)] -pub(crate) enum ErrorDirection { - Read(io::Error), - Write(io::Error), -} - -impl ErrorSource { - fn from_client(err: ErrorDirection) -> ErrorSource { - match err { - ErrorDirection::Read(client) => Self::Client(client), - ErrorDirection::Write(compute) => Self::Compute(compute), - } - } - fn from_compute(err: ErrorDirection) -> ErrorSource { - match err { - ErrorDirection::Write(client) => Self::Client(client), - ErrorDirection::Read(compute) => Self::Compute(compute), - } - } -} - #[derive(Debug)] pub enum ErrorSource { Client(io::Error), Compute(io::Error), } +impl ErrorSource { + fn read(dir: Direction, err: io::Error) -> Self { + match dir { + Direction::ComputeToClient => ErrorSource::Compute(err), + Direction::ClientToCompute => ErrorSource::Client(err), + } + } + + fn write(dir: Direction, err: io::Error) -> Self { + match dir { + Direction::ComputeToClient => ErrorSource::Client(err), + Direction::ClientToCompute => ErrorSource::Compute(err), + } + } +} + fn transfer_one_direction( cx: &mut Context<'_>, state: &mut TransferState, f: &mut impl for<'a> FnMut(Direction, &'a [u8]), r: &mut A, w: &mut B, -) -> Poll> +) -> Poll> where A: AsyncRead + AsyncWrite + Unpin + ?Sized, B: AsyncRead + AsyncWrite + Unpin + ?Sized, @@ -56,12 +51,12 @@ where let mut w = Pin::new(w); loop { match state { - TransferState::Running(buf, dir) => { - ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?; - *state = TransferState::ShuttingDown; + TransferState::Running(buf) => { + ready!(buf.poll_copy(cx, f, r.as_mut(), w.as_mut()))?; + *state = TransferState::ShuttingDown(buf.dir); } - TransferState::ShuttingDown => { - ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?; + TransferState::ShuttingDown(dir) => { + ready!(w.as_mut().poll_shutdown(cx)).map_err(|e| ErrorSource::write(*dir, e))?; *state = TransferState::Done; } TransferState::Done => return Poll::Ready(Ok(())), @@ -78,42 +73,36 @@ where Client: AsyncRead + AsyncWrite + Unpin + ?Sized, Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, { - let mut client_to_compute = - TransferState::Running(CopyBuffer::new(), Direction::ClientToCompute); - let mut compute_to_client = - TransferState::Running(CopyBuffer::new(), Direction::ComputeToClient); + let mut client_to_compute = TransferState::Running(CopyBuffer::new(Direction::ClientToCompute)); + let mut compute_to_client = TransferState::Running(CopyBuffer::new(Direction::ComputeToClient)); poll_fn(|cx| { let mut client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute) - .map_err(ErrorSource::from_client)?; + transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)?; let mut compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client) - .map_err(ErrorSource::from_compute)?; + transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)?; // TODO: 1 info log, with a enum label for close direction. // Early termination checks from compute to client. if let TransferState::Done = compute_to_client { - if let TransferState::Running(..) = &client_to_compute { + if let TransferState::Running(buf) = &client_to_compute { info!("Compute is done, terminate client"); // Initiate shutdown - client_to_compute = TransferState::ShuttingDown; + client_to_compute = TransferState::ShuttingDown(buf.dir); client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute) - .map_err(ErrorSource::from_client)?; + transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)?; } } // Early termination checks from client to compute. if let TransferState::Done = client_to_compute { - if let TransferState::Running(..) = &compute_to_client { + if let TransferState::Running(buf) = &compute_to_client { info!("Client is done, terminate compute"); // Initiate shutdown - compute_to_client = TransferState::ShuttingDown; + compute_to_client = TransferState::ShuttingDown(buf.dir); compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client) - .map_err(ErrorSource::from_compute)?; + transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)?; } } @@ -125,8 +114,8 @@ where .await } -#[derive(Debug)] pub(super) struct CopyBuffer { + dir: Direction, read_done: bool, need_flush: bool, pos: usize, @@ -136,8 +125,9 @@ pub(super) struct CopyBuffer { const DEFAULT_BUF_SIZE: usize = 1024; impl CopyBuffer { - pub(super) fn new() -> Self { + pub(super) fn new(dir: Direction) -> Self { Self { + dir, read_done: false, need_flush: false, pos: 0, @@ -149,9 +139,9 @@ impl CopyBuffer { fn poll_fill_buf( &mut self, cx: &mut Context<'_>, - f: &mut impl for<'a> FnMut(&'a [u8]), + f: &mut impl for<'a> FnMut(Direction, &'a [u8]), reader: Pin<&mut R>, - ) -> Poll> + ) -> Poll> where R: AsyncRead + ?Sized, { @@ -160,23 +150,23 @@ impl CopyBuffer { buf.set_filled(me.cap); let res = reader.poll_read(cx, &mut buf); - f(&buf.filled()[me.cap..]); + f(me.dir, &buf.filled()[me.cap..]); if let Poll::Ready(Ok(())) = res { let filled_len = buf.filled().len(); me.read_done = me.cap == filled_len; me.cap = filled_len; } - res + res.map_err(|e| ErrorSource::read(me.dir, e)) } fn poll_write_buf( &mut self, cx: &mut Context<'_>, - f: &mut impl for<'a> FnMut(&'a [u8]), + f: &mut impl for<'a> FnMut(Direction, &'a [u8]), mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, - ) -> Poll> + ) -> Poll> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized, @@ -187,22 +177,21 @@ impl CopyBuffer { // Top up the buffer towards full if we can read a bit more // data - this should improve the chances of a large write if !me.read_done && me.cap < me.buf.len() { - ready!(me.poll_fill_buf(cx, f, reader.as_mut())) - .map_err(ErrorDirection::Read)?; + ready!(me.poll_fill_buf(cx, f, reader.as_mut()))?; } Poll::Pending } - res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write), + res @ Poll::Ready(_) => res.map_err(|e| ErrorSource::write(me.dir, e)), } } pub(super) fn poll_copy( &mut self, cx: &mut Context<'_>, - mut f: impl for<'a> FnMut(&'a [u8]), + f: &mut impl for<'a> FnMut(Direction, &'a [u8]), mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, - ) -> Poll> + ) -> Poll> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized, @@ -211,9 +200,9 @@ impl CopyBuffer { // If there is some space left in our buffer, then we try to read some // data to continue, thus maximizing the chances of a large write. if self.cap < self.buf.len() && !self.read_done { - match self.poll_fill_buf(cx, &mut f, reader.as_mut()) { + match self.poll_fill_buf(cx, f, reader.as_mut()) { Poll::Ready(Ok(())) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => { // Ignore pending reads when our buffer is not empty, because // we can try to write data immediately. @@ -222,7 +211,7 @@ impl CopyBuffer { // when the reader depends on buffered writer. if self.need_flush { ready!(writer.as_mut().poll_flush(cx)) - .map_err(ErrorDirection::Write)?; + .map_err(|e| ErrorSource::write(self.dir, e))?; self.need_flush = false; } @@ -234,12 +223,11 @@ impl CopyBuffer { // If our buffer has some data, let's write it out! while self.pos < self.cap { - let i = ready!(self.poll_write_buf(cx, &mut f, reader.as_mut(), writer.as_mut()))?; + let i = ready!(self.poll_write_buf(cx, f, reader.as_mut(), writer.as_mut()))?; if i == 0 { - return Poll::Ready(Err(ErrorDirection::Write(io::Error::new( - io::ErrorKind::WriteZero, - "write zero byte into writer", - )))); + let err = + io::Error::new(io::ErrorKind::WriteZero, "write zero byte into writer"); + return Poll::Ready(Err(ErrorSource::write(self.dir, err))); } self.pos += i; self.need_flush = true; @@ -260,7 +248,8 @@ impl CopyBuffer { // If we've written all the data and we've seen EOF, flush out the // data and finish the transfer. if self.read_done { - ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?; + ready!(writer.as_mut().poll_flush(cx)) + .map_err(|e| ErrorSource::write(self.dir, e))?; return Poll::Ready(Ok(())); } }