diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 5de4709c7f..e040e37b9d 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -14,6 +14,36 @@ enum TransferState { Done, } +impl TransferState { + #[inline(always)] + fn shutdown(&mut self) { + let Self::Running(_) = self else { return }; + + // we go via this cold function to actually drop the buffer and write the log. + // this is quite a bit more efficient as this is not a hot function for the passthrough. + self.shutdown_cold(); + } + + /// Drop the running state, and write a log. + #[cold] + #[inline(never)] + fn shutdown_cold(&mut self) { + let Self::Running(buf) = self else { return }; + match buf.dir { + Direction::ComputeToClient => info!("Client is done, terminate compute"), + Direction::ClientToCompute => info!("Compute is done, terminate client"), + } + *self = Self::ShuttingDown(buf.dir); + } +} + +/// Mark a value as being unlikely. +#[cold] +#[inline(always)] +fn cold(i: I) -> I { + i +} + #[derive(Debug)] pub enum ErrorSource { Client(io::Error), @@ -51,15 +81,17 @@ where let mut w = Pin::new(w); loop { match state { - TransferState::Running(buf) => { - ready!(buf.poll_copy(cx, f, r.as_mut(), w.as_mut()))?; - *state = TransferState::ShuttingDown(buf.dir); - } - TransferState::ShuttingDown(dir) => { - ready!(w.as_mut().poll_shutdown(cx)).map_err(|e| ErrorSource::write(*dir, e))?; - *state = TransferState::Done; - } - TransferState::Done => return Poll::Ready(Ok(())), + TransferState::Running(buf) => match buf.poll_copy(cx, f, r.as_mut(), w.as_mut()) { + Poll::Pending => break Poll::Pending, + Poll::Ready(Err(e)) => break Poll::Ready(Err(cold(e))), + Poll::Ready(Ok(())) => *state = TransferState::ShuttingDown(buf.dir), + }, + TransferState::ShuttingDown(dir) => match w.as_mut().poll_shutdown(cx) { + Poll::Pending => break Poll::Pending, + Poll::Ready(Err(e)) => break Poll::Ready(Err(ErrorSource::write(*dir, cold(e)))), + Poll::Ready(Ok(())) => *state = TransferState::Done, + }, + TransferState::Done => break Poll::Ready(Ok(())), } } } @@ -73,43 +105,30 @@ where Client: AsyncRead + AsyncWrite + Unpin + ?Sized, Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, { - let mut client_to_compute = TransferState::Running(CopyBuffer::new(Direction::ClientToCompute)); - let mut compute_to_client = TransferState::Running(CopyBuffer::new(Direction::ComputeToClient)); + let f = &mut f; + let client_to_compute = + &mut TransferState::Running(CopyBuffer::new(Direction::ClientToCompute)); + let compute_to_client = + &mut TransferState::Running(CopyBuffer::new(Direction::ComputeToClient)); poll_fn(|cx| { - let mut client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)?; - let mut compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)?; - - // TODO: 1 info log, with a enum label for close direction. - - // Early termination checks from compute to client. - if let TransferState::Done = compute_to_client { - if let TransferState::Running(buf) = &client_to_compute { - info!("Compute is done, terminate client"); - // Initiate shutdown - client_to_compute = TransferState::ShuttingDown(buf.dir); - client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)?; + match transfer_one_direction(cx, client_to_compute, f, client, compute) { + Poll::Ready(Err(e)) => return Poll::Ready(Err(cold(e))), + Poll::Ready(Ok(())) => { + compute_to_client.shutdown(); + return transfer_one_direction(cx, compute_to_client, f, compute, client); } + Poll::Pending => {} } - // Early termination checks from client to compute. - if let TransferState::Done = client_to_compute { - if let TransferState::Running(buf) = &compute_to_client { - info!("Client is done, terminate compute"); - // Initiate shutdown - compute_to_client = TransferState::ShuttingDown(buf.dir); - compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)?; - } + match transfer_one_direction(cx, compute_to_client, f, compute, client) { + Poll::Ready(Err(e)) => return Poll::Ready(Err(cold(e))), + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => {} } - ready!(client_to_compute_result); - ready!(compute_to_client_result); - - Poll::Ready(Ok(())) + client_to_compute.shutdown(); + transfer_one_direction(cx, client_to_compute, f, client, compute) }) .await } @@ -291,7 +310,7 @@ mod tests { compute_client.read_buf(&mut compute_recv).await.unwrap(); assert_eq!(compute_recv, b"hello"); - assert_eq!(client_recv, b"Neon"); + assert_eq!(client_recv, b""); } #[tokio::test]