replace measured stream with direct copy_bidirectional measurement integration

This commit is contained in:
Conrad Ludgate
2025-05-19 16:44:36 +01:00
parent 008cd84e7b
commit 14312f1a9a
4 changed files with 59 additions and 40 deletions

View File

@@ -383,8 +383,12 @@ async fn handle_client(
info!("performing the proxy pass...");
let res = match client {
Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
Connection::Raw(mut c) => {
copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await
}
Connection::Tls(mut c) => {
copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await
}
};
match res {

View File

@@ -200,8 +200,10 @@ pub enum HttpDirection {
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "direction")]
pub enum Direction {
Tx,
Rx,
#[label(rename = "tx")]
ComputeToClient,
#[label(rename = "rx")]
ClientToCompute,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]

View File

@@ -6,9 +6,10 @@ use std::task::{Context, Poll, ready};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::info;
#[derive(Debug)]
use crate::metrics::Direction;
enum TransferState {
Running(CopyBuffer),
Running(CopyBuffer, Direction),
ShuttingDown(u64),
Done(u64),
}
@@ -43,6 +44,7 @@ pub enum ErrorSource {
fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
r: &mut A,
w: &mut B,
) -> Poll<Result<u64, ErrorDirection>>
@@ -54,8 +56,8 @@ where
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
TransferState::Running(buf, dir) => {
let count = ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
@@ -70,44 +72,47 @@ where
pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,
mut f: impl for<'a> FnMut(Direction, &'a [u8]),
) -> Result<(u64, u64), ErrorSource>
where
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut client_to_compute = TransferState::Running(CopyBuffer::new());
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
let mut client_to_compute =
TransferState::Running(CopyBuffer::new(), Direction::ClientToCompute);
let mut compute_to_client =
TransferState::Running(CopyBuffer::new(), Direction::ComputeToClient);
poll_fn(|cx| {
let mut client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)
.map_err(ErrorSource::from_client)?;
let mut compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)
.map_err(ErrorSource::from_compute)?;
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf) = &client_to_compute {
if let TransferState::Running(buf, _) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)
.map_err(ErrorSource::from_client)?;
}
}
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
if let TransferState::Running(buf, _) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)
.map_err(ErrorSource::from_compute)?;
}
}
@@ -147,6 +152,7 @@ impl CopyBuffer {
fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
f: &mut impl for<'a> FnMut(&'a [u8]),
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
@@ -157,6 +163,8 @@ impl CopyBuffer {
buf.set_filled(me.cap);
let res = reader.poll_read(cx, &mut buf);
f(&buf.filled()[me.cap..]);
if let Poll::Ready(Ok(())) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
@@ -168,6 +176,7 @@ impl CopyBuffer {
fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
f: &mut impl for<'a> FnMut(&'a [u8]),
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<usize, ErrorDirection>>
@@ -181,7 +190,8 @@ impl CopyBuffer {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
ready!(me.poll_fill_buf(cx, f, reader.as_mut()))
.map_err(ErrorDirection::Read)?;
}
Poll::Pending
}
@@ -192,6 +202,7 @@ impl CopyBuffer {
pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut f: impl for<'a> FnMut(&'a [u8]),
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<u64, ErrorDirection>>
@@ -203,7 +214,7 @@ impl CopyBuffer {
// If there is some space left in our buffer, then we try to read some
// data to continue, thus maximizing the chances of a large write.
if self.cap < self.buf.len() && !self.read_done {
match self.poll_fill_buf(cx, reader.as_mut()) {
match self.poll_fill_buf(cx, &mut f, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
Poll::Pending => {
@@ -226,7 +237,7 @@ impl CopyBuffer {
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
let i = ready!(self.poll_write_buf(cx, &mut f, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
io::ErrorKind::WriteZero,
@@ -277,9 +288,10 @@ mod tests {
compute_client.write_all(b"Neon").await.unwrap();
compute_client.shutdown().await.unwrap();
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
.await
.unwrap();
let result =
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
.await
.unwrap();
// Assert correct transferred amounts
let (client_to_compute_count, compute_to_client_count) = result;
@@ -300,9 +312,10 @@ mod tests {
.await
.unwrap();
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
.await
.unwrap();
let result =
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
.await
.unwrap();
// Assert correct transferred amounts
let (client_to_compute_count, compute_to_client_count) = result;

View File

@@ -1,7 +1,6 @@
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
@@ -15,7 +14,7 @@ use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(level = "debug", skip_all)]
pub(crate) async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
mut client: impl AsyncRead + AsyncWrite + Unpin,
mut compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
private_link_id: Option<SmolStr>,
@@ -28,25 +27,26 @@ pub(crate) async fn proxy_pass(
});
let metrics = &Metrics::get().proxy.io_bytes;
let m_sent = metrics.with_labels(Direction::Tx);
let m_recv = metrics.with_labels(Direction::Rx);
let mut client = MeasuredStream::new(
client,
|bytes_read| {
metrics.get_metric(m_recv).inc_by(bytes_read as u64);
usage_tx.record_ingress(bytes_read as u64);
},
|bytes_flushed| {
metrics.get_metric(m_sent).inc_by(bytes_flushed as u64);
usage_tx.record_egress(bytes_flushed as u64);
},
);
let m_sent = metrics.with_labels(Direction::ComputeToClient);
let m_recv = metrics.with_labels(Direction::ClientToCompute);
let inspect = |direction, bytes: &[u8]| match direction {
Direction::ComputeToClient => {
metrics.get_metric(m_sent).inc_by(bytes.len() as u64);
usage_tx.record_egress(bytes.len() as u64);
}
Direction::ClientToCompute => {
metrics.get_metric(m_recv).inc_by(bytes.len() as u64);
usage_tx.record_ingress(bytes.len() as u64);
}
};
// Starting from here we only proxy the client's traffic.
debug!("performing the proxy pass...");
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
&mut client,
&mut compute,
inspect,
)
.await?;