From a17a8828956d8d12300e78cc73bd3d4e18667023 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 20 May 2025 18:00:29 +0100 Subject: [PATCH] some changes --- proxy/src/proxy/copy_bidirectional.rs | 28 ++++++++++++++------------- proxy/src/proxy/passthrough.rs | 19 +++++++++++------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 227628ca75..200f24d50f 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -51,23 +51,29 @@ where Client: AsyncRead + AsyncWrite + Unpin + ?Sized, Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, { - let mut client = Pin::new(client); - let mut compute = Pin::new(compute); - let f = &mut f; let mut client_to_compute = CopyBuffer::new(Direction::ClientToCompute); let mut compute_to_client = CopyBuffer::new(Direction::ComputeToClient); + let mut client = Pin::new(client); + let mut compute = Pin::new(compute); + // Initial copy hot path - poll_fn(|cx| -> Poll> { + let close_dir = poll_fn(|cx| -> Poll> { let copy1 = client_to_compute.poll_copy(cx, f, client.as_mut(), compute.as_mut())?; let copy2 = compute_to_client.poll_copy(cx, f, compute.as_mut(), client.as_mut())?; - if copy1.is_pending() && copy2.is_pending() { - return Poll::Pending; + match (copy1, copy2) { + (Poll::Pending, Poll::Pending) => Poll::Pending, + (Poll::Ready(_), _) => Poll::Ready(Ok(client_to_compute.dir)), + (_, Poll::Ready(_)) => Poll::Ready(Ok(compute_to_client.dir)), } + }) + .await?; - if copy1.is_ready() { + // initiate shutdown. + match close_dir { + Direction::ComputeToClient => { info!("Client is done, terminate compute"); // we will never write anymore data to the client. @@ -76,8 +82,7 @@ where // make sure to shutdown the client conn. compute_to_client.need_flush = true; } - - if copy2.is_ready() { + Direction::ClientToCompute => { info!("Compute is done, terminate client"); // we will never write anymore data to the compute. @@ -86,10 +91,7 @@ where // make sure to shutdown the compute conn. client_to_compute.need_flush = true; } - - Poll::Ready(Ok(())) - }) - .await?; + } // Finish sending the rest of the data to client/compute before shutting it down. // diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 036a3e76e3..3078973908 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -8,13 +8,14 @@ use crate::compute::PostgresConnection; use crate::config::ComputeConfig; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; +use crate::proxy::copy_bidirectional_client_compute; use crate::stream::Stream; use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(level = "debug", skip_all)] pub(crate) async fn proxy_pass( - mut client: impl AsyncRead + AsyncWrite + Unpin, + mut client: Stream, mut compute: impl AsyncRead + AsyncWrite + Unpin, aux: MetricsAuxInfo, private_link_id: Option, @@ -43,12 +44,16 @@ pub(crate) async fn proxy_pass( // Starting from here we only proxy the client's traffic. debug!("performing the proxy pass..."); - let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute( - &mut client, - &mut compute, - inspect, - ) - .await?; + + // reduce branching internal to the hot path. + match &mut client { + Stream::Raw { raw } => { + copy_bidirectional_client_compute(raw, &mut compute, inspect).await? + } + Stream::Tls { tls, .. } => { + copy_bidirectional_client_compute(&mut *tls, &mut compute, inspect).await? + } + }; Ok(()) }