diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index de74f606b5..e07e20318f 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -297,11 +297,9 @@ async fn handle_client( // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - // match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await { - // Ok(_) => Ok(()), - // Err(ErrorSource::Client(err)) => Err(err).context("client"), - // Err(ErrorSource::Compute(err)) => Err(err).context("compute"), - // } - - Ok(()) + match copy_bidirectional_client_compute(&mut tls_stream, &mut client, |_, _| {}).await { + Ok(_) => Ok(()), + Err(ErrorSource::Client(err)) => Err(err).context("client"), + Err(ErrorSource::Compute(err)) => Err(err).context("compute"), + } } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index e5fc0b724b..a8ae783f19 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -200,8 +200,10 @@ pub enum HttpDirection { #[derive(FixedCardinalityLabel, Copy, Clone)] #[label(singleton = "direction")] pub enum Direction { - Tx, - Rx, + #[label(rename = "tx")] + ComputeToClient, + #[label(rename = "rx")] + ClientToCompute, } #[derive(FixedCardinalityLabel, Clone, Copy, Debug)] diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 8a0c53ef54..56f4946549 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -6,11 +6,11 @@ use std::task::{Context, Poll, ready}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::info; -use super::conntrack::{ConnectionTracker, StateChangeObserver, StreamScannerState, TagObserver}; +use crate::metrics::Direction; #[derive(Debug)] enum TransferState { - Running(CopyBuffer, StreamScannerState), + Running(CopyBuffer), ShuttingDown(u64), Done(u64), } @@ -45,7 +45,8 @@ pub enum ErrorSource { fn transfer_one_direction( cx: &mut Context<'_>, state: &mut TransferState, - mut observer: impl TagObserver, + direction: Direction, + conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]), r: &mut A, w: &mut B, ) -> Poll> @@ -57,9 +58,9 @@ where let mut w = Pin::new(w); loop { match state { - TransferState::Running(buf, stream_state) => { + TransferState::Running(buf) => { let count = - ready!(buf.poll_copy(cx, stream_state, &mut observer, r.as_mut(), w.as_mut()))?; + ready!(buf.poll_copy(cx, direction, conn_tracker, r.as_mut(), w.as_mut()))?; *state = TransferState::ShuttingDown(count); } TransferState::ShuttingDown(count) => { @@ -75,20 +76,21 @@ where pub async fn copy_bidirectional_client_compute( client: &mut Client, compute: &mut Compute, - conn_tracker: &mut ConnectionTracker, + mut conn_tracker: impl for<'a> FnMut(Direction, &'a [u8]), ) -> Result<(u64, u64), ErrorSource> where Client: AsyncRead + AsyncWrite + Unpin + ?Sized, Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, { - let mut client_to_compute = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag); - let mut compute_to_client = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag); + let mut client_to_compute = TransferState::Running(CopyBuffer::new()); + let mut compute_to_client = TransferState::Running(CopyBuffer::new()); poll_fn(|cx| { let mut client_to_compute_result = transfer_one_direction( cx, &mut client_to_compute, - |tag| conn_tracker.frontend_message_tag(tag), + Direction::ClientToCompute, + &mut conn_tracker, client, compute, ) @@ -96,7 +98,8 @@ where let mut compute_to_client_result = transfer_one_direction( cx, &mut compute_to_client, - |tag| conn_tracker.backend_message_tag(tag), + Direction::ComputeToClient, + &mut conn_tracker, compute, client, ) @@ -106,14 +109,15 @@ where // Early termination checks from compute to client. if let TransferState::Done(_) = compute_to_client { - if let TransferState::Running(buf, _) = &client_to_compute { + if let TransferState::Running(buf) = &client_to_compute { info!("Compute is done, terminate client"); // Initiate shutdown client_to_compute = TransferState::ShuttingDown(buf.amt); client_to_compute_result = transfer_one_direction( cx, &mut client_to_compute, - |tag| conn_tracker.frontend_message_tag(tag), + Direction::ClientToCompute, + &mut conn_tracker, client, compute, ) @@ -123,14 +127,15 @@ where // Early termination checks from client to compute. if let TransferState::Done(_) = client_to_compute { - if let TransferState::Running(buf, _) = &compute_to_client { + if let TransferState::Running(buf) = &compute_to_client { info!("Client is done, terminate compute"); // Initiate shutdown compute_to_client = TransferState::ShuttingDown(buf.amt); compute_to_client_result = transfer_one_direction( cx, &mut compute_to_client, - |tag| conn_tracker.backend_message_tag(tag), + Direction::ComputeToClient, + &mut conn_tracker, compute, client, ) @@ -173,8 +178,8 @@ impl CopyBuffer { fn poll_fill_buf( &mut self, cx: &mut Context<'_>, - state: &mut StreamScannerState, - observer: &mut impl TagObserver, + direction: Direction, + conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]), reader: Pin<&mut R>, ) -> Poll> where @@ -185,7 +190,7 @@ impl CopyBuffer { buf.set_filled(me.cap); let res = reader.poll_read(cx, &mut buf); - state.scan_bytes(&buf.filled()[me.cap..], observer); + conn_tracker(direction, &buf.filled()[me.cap..]); if let Poll::Ready(Ok(())) = res { let filled_len = buf.filled().len(); @@ -198,8 +203,8 @@ impl CopyBuffer { fn poll_write_buf( &mut self, cx: &mut Context<'_>, - state: &mut StreamScannerState, - observer: &mut impl TagObserver, + direction: Direction, + conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]), mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> @@ -213,7 +218,7 @@ impl CopyBuffer { // Top up the buffer towards full if we can read a bit more // data - this should improve the chances of a large write if !me.read_done && me.cap < me.buf.len() { - ready!(me.poll_fill_buf(cx, state, observer, reader.as_mut())) + ready!(me.poll_fill_buf(cx, direction, conn_tracker, reader.as_mut())) .map_err(ErrorDirection::Read)?; } Poll::Pending @@ -225,8 +230,8 @@ impl CopyBuffer { pub(super) fn poll_copy( &mut self, cx: &mut Context<'_>, - state: &mut StreamScannerState, - observer: &mut impl TagObserver, + direction: Direction, + conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]), mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> @@ -238,7 +243,7 @@ impl CopyBuffer { // If there is some space left in our buffer, then we try to read some // data to continue, thus maximizing the chances of a large write. if self.cap < self.buf.len() && !self.read_done { - match self.poll_fill_buf(cx, state, observer, reader.as_mut()) { + match self.poll_fill_buf(cx, direction, conn_tracker, reader.as_mut()) { Poll::Ready(Ok(())) => (), Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))), Poll::Pending => { @@ -263,8 +268,8 @@ impl CopyBuffer { while self.pos < self.cap { let i = ready!(self.poll_write_buf( cx, - state, - observer, + direction, + conn_tracker, reader.as_mut(), writer.as_mut() ))?; @@ -305,19 +310,8 @@ impl CopyBuffer { mod tests { use tokio::io::AsyncWriteExt; - use crate::proxy::conntrack::ConnectionState; - use super::*; - #[derive(Default)] - struct Observer(Vec<(ConnectionState, ConnectionState)>); - - impl StateChangeObserver for Observer { - fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) { - self.0.push((old_state, new_state)); - } - } - #[tokio::test] async fn test_client_to_compute() { let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream @@ -329,15 +323,10 @@ mod tests { compute_client.write_all(b"Neon").await.unwrap(); compute_client.shutdown().await.unwrap(); - let mut conn_tracker = ConnectionTracker::new(Observer::default()); - - let result = copy_bidirectional_client_compute( - &mut client_proxy, - &mut compute_proxy, - &mut conn_tracker, - ) - .await - .unwrap(); + let result = + copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {}) + .await + .unwrap(); // Assert correct transferred amounts let (client_to_compute_count, compute_to_client_count) = result; @@ -358,15 +347,10 @@ mod tests { .await .unwrap(); - let mut conn_tracker = ConnectionTracker::new(Observer::default()); - - let result = copy_bidirectional_client_compute( - &mut client_proxy, - &mut compute_proxy, - &mut conn_tracker, - ) - .await - .unwrap(); + let result = + copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {}) + .await + .unwrap(); // Assert correct transferred amounts let (client_to_compute_count, compute_to_client_count) = result; diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index b0e75c79bc..fde481f016 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; -use utils::measured_stream::MeasuredStream; use super::copy_bidirectional::ErrorSource; use crate::cancellation; @@ -11,15 +10,16 @@ use crate::compute::PostgresConnection; use crate::config::ComputeConfig; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; -use crate::proxy::conntrack::ConnectionTracking; +use crate::proxy::conntrack::{ConnectionTracking, StreamScannerState}; +use crate::proxy::copy_bidirectional::copy_bidirectional_client_compute; use crate::stream::Stream; use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(skip_all)] pub(crate) async fn proxy_pass( - client: impl AsyncRead + AsyncWrite + Unpin, - compute: impl AsyncRead + AsyncWrite + Unpin, + mut client: impl AsyncRead + AsyncWrite + Unpin, + mut compute: impl AsyncRead + AsyncWrite + Unpin, aux: MetricsAuxInfo, private_link_id: Option, conntracking: &Arc, @@ -34,35 +34,32 @@ pub(crate) async fn proxy_pass( let mut conn_tracker = conntracking.new_tracker(); let metrics = &Metrics::get().proxy.io_bytes; - let m_sent = metrics.with_labels(Direction::Tx); - let mut client = MeasuredStream::new( - client, - |_| {}, - |cnt| { - // Number of bytes we sent to the client (outbound). - metrics.get_metric(m_sent).inc_by(cnt as u64); - usage_tx.record_egress(cnt as u64); - }, - ); + let m_sent = metrics.with_labels(Direction::ComputeToClient); + let m_recv = metrics.with_labels(Direction::ClientToCompute); - let m_recv = metrics.with_labels(Direction::Rx); - let mut compute = MeasuredStream::new( - compute, - |_| {}, - |cnt| { - // Number of bytes the client sent to the compute node (inbound). - metrics.get_metric(m_recv).inc_by(cnt as u64); - usage_tx.record_ingress(cnt as u64); - }, - ); + let mut client_to_compute = StreamScannerState::Tag; + let mut compute_to_client = StreamScannerState::Tag; - // Starting from here we only proxy the client's traffic. debug!("performing the proxy pass..."); - let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute( - &mut client, - &mut compute, - &mut conn_tracker, - ) + + let _ = copy_bidirectional_client_compute(&mut client, &mut compute, |direction, bytes| { + match direction { + Direction::ClientToCompute => { + client_to_compute + .scan_bytes(bytes, &mut |tag| conn_tracker.frontend_message_tag(tag)); + + metrics.get_metric(m_recv).inc_by(bytes.len() as u64); + usage_tx.record_ingress(bytes.len() as u64); + } + Direction::ComputeToClient => { + compute_to_client + .scan_bytes(bytes, &mut |tag| conn_tracker.backend_message_tag(tag)); + + metrics.get_metric(m_sent).inc_by(bytes.len() as u64); + usage_tx.record_egress(bytes.len() as u64); + } + } + }) .await?; Ok(())