diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 3e87538ae7..29c3a45c8f 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -383,8 +383,12 @@ async fn handle_client( info!("performing the proxy pass..."); let res = match client { - Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await, - Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await, + Connection::Raw(mut c) => { + copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await + } + Connection::Tls(mut c) => { + copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await + } }; match res { diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 4b22c912eb..eb1a4fddf9 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -200,8 +200,10 @@ pub enum HttpDirection { #[derive(FixedCardinalityLabel, Copy, Clone)] #[label(singleton = "direction")] pub enum Direction { - Tx, - Rx, + #[label(rename = "tx")] + ComputeToClient, + #[label(rename = "rx")] + ClientToCompute, } #[derive(FixedCardinalityLabel, Clone, Copy, Debug)] diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs index 41ac5b2880..d89397508f 100644 --- a/proxy/src/proxy/copy_bidirectional.rs +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -6,9 +6,10 @@ use std::task::{Context, Poll, ready}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::info; -#[derive(Debug)] +use crate::metrics::Direction; + enum TransferState { - Running(CopyBuffer), + Running(CopyBuffer, Direction), ShuttingDown(u64), Done(u64), } @@ -43,6 +44,7 @@ pub enum ErrorSource { fn transfer_one_direction( cx: &mut Context<'_>, state: &mut TransferState, + f: &mut impl for<'a> FnMut(Direction, &'a [u8]), r: &mut A, w: &mut B, ) -> Poll> @@ -54,8 +56,8 @@ where let mut w = Pin::new(w); loop { match state { - TransferState::Running(buf) => { - let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + TransferState::Running(buf, dir) => { + let count = ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?; *state = TransferState::ShuttingDown(count); } TransferState::ShuttingDown(count) => { @@ -70,44 +72,47 @@ where pub async fn copy_bidirectional_client_compute( client: &mut Client, compute: &mut Compute, + mut f: impl for<'a> FnMut(Direction, &'a [u8]), ) -> Result<(u64, u64), ErrorSource> where Client: AsyncRead + AsyncWrite + Unpin + ?Sized, Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, { - let mut client_to_compute = TransferState::Running(CopyBuffer::new()); - let mut compute_to_client = TransferState::Running(CopyBuffer::new()); + let mut client_to_compute = + TransferState::Running(CopyBuffer::new(), Direction::ClientToCompute); + let mut compute_to_client = + TransferState::Running(CopyBuffer::new(), Direction::ComputeToClient); poll_fn(|cx| { let mut client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, client, compute) + transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute) .map_err(ErrorSource::from_client)?; let mut compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, compute, client) + transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client) .map_err(ErrorSource::from_compute)?; // TODO: 1 info log, with a enum label for close direction. // Early termination checks from compute to client. if let TransferState::Done(_) = compute_to_client { - if let TransferState::Running(buf) = &client_to_compute { + if let TransferState::Running(buf, _) = &client_to_compute { info!("Compute is done, terminate client"); // Initiate shutdown client_to_compute = TransferState::ShuttingDown(buf.amt); client_to_compute_result = - transfer_one_direction(cx, &mut client_to_compute, client, compute) + transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute) .map_err(ErrorSource::from_client)?; } } // Early termination checks from client to compute. if let TransferState::Done(_) = client_to_compute { - if let TransferState::Running(buf) = &compute_to_client { + if let TransferState::Running(buf, _) = &compute_to_client { info!("Client is done, terminate compute"); // Initiate shutdown compute_to_client = TransferState::ShuttingDown(buf.amt); compute_to_client_result = - transfer_one_direction(cx, &mut compute_to_client, compute, client) + transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client) .map_err(ErrorSource::from_compute)?; } } @@ -147,6 +152,7 @@ impl CopyBuffer { fn poll_fill_buf( &mut self, cx: &mut Context<'_>, + f: &mut impl for<'a> FnMut(&'a [u8]), reader: Pin<&mut R>, ) -> Poll> where @@ -157,6 +163,8 @@ impl CopyBuffer { buf.set_filled(me.cap); let res = reader.poll_read(cx, &mut buf); + f(&buf.filled()[me.cap..]); + if let Poll::Ready(Ok(())) = res { let filled_len = buf.filled().len(); me.read_done = me.cap == filled_len; @@ -168,6 +176,7 @@ impl CopyBuffer { fn poll_write_buf( &mut self, cx: &mut Context<'_>, + f: &mut impl for<'a> FnMut(&'a [u8]), mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> @@ -181,7 +190,8 @@ impl CopyBuffer { // Top up the buffer towards full if we can read a bit more // data - this should improve the chances of a large write if !me.read_done && me.cap < me.buf.len() { - ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?; + ready!(me.poll_fill_buf(cx, f, reader.as_mut())) + .map_err(ErrorDirection::Read)?; } Poll::Pending } @@ -192,6 +202,7 @@ impl CopyBuffer { pub(super) fn poll_copy( &mut self, cx: &mut Context<'_>, + mut f: impl for<'a> FnMut(&'a [u8]), mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll> @@ -203,7 +214,7 @@ impl CopyBuffer { // If there is some space left in our buffer, then we try to read some // data to continue, thus maximizing the chances of a large write. if self.cap < self.buf.len() && !self.read_done { - match self.poll_fill_buf(cx, reader.as_mut()) { + match self.poll_fill_buf(cx, &mut f, reader.as_mut()) { Poll::Ready(Ok(())) => (), Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))), Poll::Pending => { @@ -226,7 +237,7 @@ impl CopyBuffer { // If our buffer has some data, let's write it out! while self.pos < self.cap { - let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + let i = ready!(self.poll_write_buf(cx, &mut f, reader.as_mut(), writer.as_mut()))?; if i == 0 { return Poll::Ready(Err(ErrorDirection::Write(io::Error::new( io::ErrorKind::WriteZero, @@ -277,9 +288,10 @@ mod tests { compute_client.write_all(b"Neon").await.unwrap(); compute_client.shutdown().await.unwrap(); - let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy) - .await - .unwrap(); + let result = + copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {}) + .await + .unwrap(); // Assert correct transferred amounts let (client_to_compute_count, compute_to_client_count) = result; @@ -300,9 +312,10 @@ mod tests { .await .unwrap(); - let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy) - .await - .unwrap(); + let result = + copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {}) + .await + .unwrap(); // Assert correct transferred amounts let (client_to_compute_count, compute_to_client_count) = result; diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index c5eee1a5b5..036a3e76e3 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,7 +1,6 @@ use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; -use utils::measured_stream::MeasuredStream; use super::copy_bidirectional::ErrorSource; use crate::cancellation; @@ -15,7 +14,7 @@ use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(level = "debug", skip_all)] pub(crate) async fn proxy_pass( - client: impl AsyncRead + AsyncWrite + Unpin, + mut client: impl AsyncRead + AsyncWrite + Unpin, mut compute: impl AsyncRead + AsyncWrite + Unpin, aux: MetricsAuxInfo, private_link_id: Option, @@ -28,25 +27,26 @@ pub(crate) async fn proxy_pass( }); let metrics = &Metrics::get().proxy.io_bytes; - let m_sent = metrics.with_labels(Direction::Tx); - let m_recv = metrics.with_labels(Direction::Rx); - let mut client = MeasuredStream::new( - client, - |bytes_read| { - metrics.get_metric(m_recv).inc_by(bytes_read as u64); - usage_tx.record_ingress(bytes_read as u64); - }, - |bytes_flushed| { - metrics.get_metric(m_sent).inc_by(bytes_flushed as u64); - usage_tx.record_egress(bytes_flushed as u64); - }, - ); + let m_sent = metrics.with_labels(Direction::ComputeToClient); + let m_recv = metrics.with_labels(Direction::ClientToCompute); + + let inspect = |direction, bytes: &[u8]| match direction { + Direction::ComputeToClient => { + metrics.get_metric(m_sent).inc_by(bytes.len() as u64); + usage_tx.record_egress(bytes.len() as u64); + } + Direction::ClientToCompute => { + metrics.get_metric(m_recv).inc_by(bytes.len() as u64); + usage_tx.record_ingress(bytes.len() as u64); + } + }; // Starting from here we only proxy the client's traffic. debug!("performing the proxy pass..."); let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute( &mut client, &mut compute, + inspect, ) .await?;