update error handling

This commit is contained in:
Conrad Ludgate
2025-05-19 22:45:28 +01:00
parent 727c333831
commit 29c4db1658

View File

@@ -9,45 +9,40 @@ use tracing::info;
use crate::metrics::Direction;
enum TransferState {
Running(CopyBuffer, Direction),
ShuttingDown,
Running(CopyBuffer),
ShuttingDown(Direction),
Done,
}
#[derive(Debug)]
pub(crate) enum ErrorDirection {
Read(io::Error),
Write(io::Error),
}
impl ErrorSource {
fn from_client(err: ErrorDirection) -> ErrorSource {
match err {
ErrorDirection::Read(client) => Self::Client(client),
ErrorDirection::Write(compute) => Self::Compute(compute),
}
}
fn from_compute(err: ErrorDirection) -> ErrorSource {
match err {
ErrorDirection::Write(client) => Self::Client(client),
ErrorDirection::Read(compute) => Self::Compute(compute),
}
}
}
#[derive(Debug)]
pub enum ErrorSource {
Client(io::Error),
Compute(io::Error),
}
impl ErrorSource {
fn read(dir: Direction, err: io::Error) -> Self {
match dir {
Direction::ComputeToClient => ErrorSource::Compute(err),
Direction::ClientToCompute => ErrorSource::Client(err),
}
}
fn write(dir: Direction, err: io::Error) -> Self {
match dir {
Direction::ComputeToClient => ErrorSource::Client(err),
Direction::ClientToCompute => ErrorSource::Compute(err),
}
}
}
fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
r: &mut A,
w: &mut B,
) -> Poll<Result<(), ErrorDirection>>
) -> Poll<Result<(), ErrorSource>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
@@ -56,12 +51,12 @@ where
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf, dir) => {
ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown;
TransferState::Running(buf) => {
ready!(buf.poll_copy(cx, f, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(buf.dir);
}
TransferState::ShuttingDown => {
ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?;
TransferState::ShuttingDown(dir) => {
ready!(w.as_mut().poll_shutdown(cx)).map_err(|e| ErrorSource::write(*dir, e))?;
*state = TransferState::Done;
}
TransferState::Done => return Poll::Ready(Ok(())),
@@ -78,42 +73,36 @@ where
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut client_to_compute =
TransferState::Running(CopyBuffer::new(), Direction::ClientToCompute);
let mut compute_to_client =
TransferState::Running(CopyBuffer::new(), Direction::ComputeToClient);
let mut client_to_compute = TransferState::Running(CopyBuffer::new(Direction::ClientToCompute));
let mut compute_to_client = TransferState::Running(CopyBuffer::new(Direction::ComputeToClient));
poll_fn(|cx| {
let mut client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)
.map_err(ErrorSource::from_client)?;
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)?;
let mut compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)
.map_err(ErrorSource::from_compute)?;
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)?;
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
if let TransferState::Done = compute_to_client {
if let TransferState::Running(..) = &client_to_compute {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown;
client_to_compute = TransferState::ShuttingDown(buf.dir);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)
.map_err(ErrorSource::from_client)?;
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)?;
}
}
// Early termination checks from client to compute.
if let TransferState::Done = client_to_compute {
if let TransferState::Running(..) = &compute_to_client {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown;
compute_to_client = TransferState::ShuttingDown(buf.dir);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)
.map_err(ErrorSource::from_compute)?;
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)?;
}
}
@@ -125,8 +114,8 @@ where
.await
}
#[derive(Debug)]
pub(super) struct CopyBuffer {
dir: Direction,
read_done: bool,
need_flush: bool,
pos: usize,
@@ -136,8 +125,9 @@ pub(super) struct CopyBuffer {
const DEFAULT_BUF_SIZE: usize = 1024;
impl CopyBuffer {
pub(super) fn new() -> Self {
pub(super) fn new(dir: Direction) -> Self {
Self {
dir,
read_done: false,
need_flush: false,
pos: 0,
@@ -149,9 +139,9 @@ impl CopyBuffer {
fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
f: &mut impl for<'a> FnMut(&'a [u8]),
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
) -> Poll<Result<(), ErrorSource>>
where
R: AsyncRead + ?Sized,
{
@@ -160,23 +150,23 @@ impl CopyBuffer {
buf.set_filled(me.cap);
let res = reader.poll_read(cx, &mut buf);
f(&buf.filled()[me.cap..]);
f(me.dir, &buf.filled()[me.cap..]);
if let Poll::Ready(Ok(())) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
res.map_err(|e| ErrorSource::read(me.dir, e))
}
fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
f: &mut impl for<'a> FnMut(&'a [u8]),
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<usize, ErrorDirection>>
) -> Poll<Result<usize, ErrorSource>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
@@ -187,22 +177,21 @@ impl CopyBuffer {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, f, reader.as_mut()))
.map_err(ErrorDirection::Read)?;
ready!(me.poll_fill_buf(cx, f, reader.as_mut()))?;
}
Poll::Pending
}
res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write),
res @ Poll::Ready(_) => res.map_err(|e| ErrorSource::write(me.dir, e)),
}
}
pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut f: impl for<'a> FnMut(&'a [u8]),
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<(), ErrorDirection>>
) -> Poll<Result<(), ErrorSource>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
@@ -211,9 +200,9 @@ impl CopyBuffer {
// If there is some space left in our buffer, then we try to read some
// data to continue, thus maximizing the chances of a large write.
if self.cap < self.buf.len() && !self.read_done {
match self.poll_fill_buf(cx, &mut f, reader.as_mut()) {
match self.poll_fill_buf(cx, f, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
// Ignore pending reads when our buffer is not empty, because
// we can try to write data immediately.
@@ -222,7 +211,7 @@ impl CopyBuffer {
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))
.map_err(ErrorDirection::Write)?;
.map_err(|e| ErrorSource::write(self.dir, e))?;
self.need_flush = false;
}
@@ -234,12 +223,11 @@ impl CopyBuffer {
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, &mut f, reader.as_mut(), writer.as_mut()))?;
let i = ready!(self.poll_write_buf(cx, f, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
))));
let err =
io::Error::new(io::ErrorKind::WriteZero, "write zero byte into writer");
return Poll::Ready(Err(ErrorSource::write(self.dir, err)));
}
self.pos += i;
self.need_flush = true;
@@ -260,7 +248,8 @@ impl CopyBuffer {
// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.read_done {
ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?;
ready!(writer.as_mut().poll_flush(cx))
.map_err(|e| ErrorSource::write(self.dir, e))?;
return Poll::Ready(Ok(()));
}
}