over-optimise copy_bidirectional

This commit is contained in:
Conrad Ludgate
2025-05-19 23:16:13 +01:00
parent f1e3f2259e
commit e5fd708495

View File

@@ -14,6 +14,36 @@ enum TransferState {
Done,
}
impl TransferState {
#[inline(always)]
fn shutdown(&mut self) {
let Self::Running(_) = self else { return };
// we go via this cold function to actually drop the buffer and write the log.
// this is quite a bit more efficient as this is not a hot function for the passthrough.
self.shutdown_cold();
}
/// Drop the running state, and write a log.
#[cold]
#[inline(never)]
fn shutdown_cold(&mut self) {
let Self::Running(buf) = self else { return };
match buf.dir {
Direction::ComputeToClient => info!("Client is done, terminate compute"),
Direction::ClientToCompute => info!("Compute is done, terminate client"),
}
*self = Self::ShuttingDown(buf.dir);
}
}
/// Mark a value as being unlikely.
#[cold]
#[inline(always)]
fn cold<I>(i: I) -> I {
i
}
#[derive(Debug)]
pub enum ErrorSource {
Client(io::Error),
@@ -51,15 +81,17 @@ where
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf) => {
ready!(buf.poll_copy(cx, f, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(buf.dir);
}
TransferState::ShuttingDown(dir) => {
ready!(w.as_mut().poll_shutdown(cx)).map_err(|e| ErrorSource::write(*dir, e))?;
*state = TransferState::Done;
}
TransferState::Done => return Poll::Ready(Ok(())),
TransferState::Running(buf) => match buf.poll_copy(cx, f, r.as_mut(), w.as_mut()) {
Poll::Pending => break Poll::Pending,
Poll::Ready(Err(e)) => break Poll::Ready(Err(cold(e))),
Poll::Ready(Ok(())) => *state = TransferState::ShuttingDown(buf.dir),
},
TransferState::ShuttingDown(dir) => match w.as_mut().poll_shutdown(cx) {
Poll::Pending => break Poll::Pending,
Poll::Ready(Err(e)) => break Poll::Ready(Err(ErrorSource::write(*dir, cold(e)))),
Poll::Ready(Ok(())) => *state = TransferState::Done,
},
TransferState::Done => break Poll::Ready(Ok(())),
}
}
}
@@ -73,43 +105,30 @@ where
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut client_to_compute = TransferState::Running(CopyBuffer::new(Direction::ClientToCompute));
let mut compute_to_client = TransferState::Running(CopyBuffer::new(Direction::ComputeToClient));
let f = &mut f;
let client_to_compute =
&mut TransferState::Running(CopyBuffer::new(Direction::ClientToCompute));
let compute_to_client =
&mut TransferState::Running(CopyBuffer::new(Direction::ComputeToClient));
poll_fn(|cx| {
let mut client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)?;
let mut compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)?;
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
if let TransferState::Done = compute_to_client {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.dir);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)?;
match transfer_one_direction(cx, client_to_compute, f, client, compute) {
Poll::Ready(Err(e)) => return Poll::Ready(Err(cold(e))),
Poll::Ready(Ok(())) => {
compute_to_client.shutdown();
return transfer_one_direction(cx, compute_to_client, f, compute, client);
}
Poll::Pending => {}
}
// Early termination checks from client to compute.
if let TransferState::Done = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.dir);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)?;
}
match transfer_one_direction(cx, compute_to_client, f, compute, client) {
Poll::Ready(Err(e)) => return Poll::Ready(Err(cold(e))),
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => {}
}
ready!(client_to_compute_result);
ready!(compute_to_client_result);
Poll::Ready(Ok(()))
client_to_compute.shutdown();
transfer_one_direction(cx, client_to_compute, f, client, compute)
})
.await
}
@@ -291,7 +310,7 @@ mod tests {
compute_client.read_buf(&mut compute_recv).await.unwrap();
assert_eq!(compute_recv, b"hello");
assert_eq!(client_recv, b"Neon");
assert_eq!(client_recv, b"");
}
#[tokio::test]