abstract byte handling

This commit is contained in:
Conrad Ludgate
2025-05-13 14:45:21 +01:00
parent b2e0ab5dc6
commit 9cffb16463
4 changed files with 74 additions and 93 deletions

View File

@@ -297,11 +297,9 @@ async fn handle_client(
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
// match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
// Ok(_) => Ok(()),
// Err(ErrorSource::Client(err)) => Err(err).context("client"),
// Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
// }
Ok(())
match copy_bidirectional_client_compute(&mut tls_stream, &mut client, |_, _| {}).await {
Ok(_) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
}

View File

@@ -200,8 +200,10 @@ pub enum HttpDirection {
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "direction")]
pub enum Direction {
Tx,
Rx,
#[label(rename = "tx")]
ComputeToClient,
#[label(rename = "rx")]
ClientToCompute,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]

View File

@@ -6,11 +6,11 @@ use std::task::{Context, Poll, ready};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::info;
use super::conntrack::{ConnectionTracker, StateChangeObserver, StreamScannerState, TagObserver};
use crate::metrics::Direction;
#[derive(Debug)]
enum TransferState {
Running(CopyBuffer, StreamScannerState),
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}
@@ -45,7 +45,8 @@ pub enum ErrorSource {
fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
mut observer: impl TagObserver,
direction: Direction,
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
r: &mut A,
w: &mut B,
) -> Poll<Result<u64, ErrorDirection>>
@@ -57,9 +58,9 @@ where
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf, stream_state) => {
TransferState::Running(buf) => {
let count =
ready!(buf.poll_copy(cx, stream_state, &mut observer, r.as_mut(), w.as_mut()))?;
ready!(buf.poll_copy(cx, direction, conn_tracker, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
@@ -75,20 +76,21 @@ where
pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,
conn_tracker: &mut ConnectionTracker<impl StateChangeObserver>,
mut conn_tracker: impl for<'a> FnMut(Direction, &'a [u8]),
) -> Result<(u64, u64), ErrorSource>
where
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut client_to_compute = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag);
let mut compute_to_client = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag);
let mut client_to_compute = TransferState::Running(CopyBuffer::new());
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
poll_fn(|cx| {
let mut client_to_compute_result = transfer_one_direction(
cx,
&mut client_to_compute,
|tag| conn_tracker.frontend_message_tag(tag),
Direction::ClientToCompute,
&mut conn_tracker,
client,
compute,
)
@@ -96,7 +98,8 @@ where
let mut compute_to_client_result = transfer_one_direction(
cx,
&mut compute_to_client,
|tag| conn_tracker.backend_message_tag(tag),
Direction::ComputeToClient,
&mut conn_tracker,
compute,
client,
)
@@ -106,14 +109,15 @@ where
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf, _) = &client_to_compute {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result = transfer_one_direction(
cx,
&mut client_to_compute,
|tag| conn_tracker.frontend_message_tag(tag),
Direction::ClientToCompute,
&mut conn_tracker,
client,
compute,
)
@@ -123,14 +127,15 @@ where
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf, _) = &compute_to_client {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result = transfer_one_direction(
cx,
&mut compute_to_client,
|tag| conn_tracker.backend_message_tag(tag),
Direction::ComputeToClient,
&mut conn_tracker,
compute,
client,
)
@@ -173,8 +178,8 @@ impl CopyBuffer {
fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
state: &mut StreamScannerState,
observer: &mut impl TagObserver,
direction: Direction,
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
@@ -185,7 +190,7 @@ impl CopyBuffer {
buf.set_filled(me.cap);
let res = reader.poll_read(cx, &mut buf);
state.scan_bytes(&buf.filled()[me.cap..], observer);
conn_tracker(direction, &buf.filled()[me.cap..]);
if let Poll::Ready(Ok(())) = res {
let filled_len = buf.filled().len();
@@ -198,8 +203,8 @@ impl CopyBuffer {
fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
state: &mut StreamScannerState,
observer: &mut impl TagObserver,
direction: Direction,
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<usize, ErrorDirection>>
@@ -213,7 +218,7 @@ impl CopyBuffer {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, state, observer, reader.as_mut()))
ready!(me.poll_fill_buf(cx, direction, conn_tracker, reader.as_mut()))
.map_err(ErrorDirection::Read)?;
}
Poll::Pending
@@ -225,8 +230,8 @@ impl CopyBuffer {
pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
state: &mut StreamScannerState,
observer: &mut impl TagObserver,
direction: Direction,
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<u64, ErrorDirection>>
@@ -238,7 +243,7 @@ impl CopyBuffer {
// If there is some space left in our buffer, then we try to read some
// data to continue, thus maximizing the chances of a large write.
if self.cap < self.buf.len() && !self.read_done {
match self.poll_fill_buf(cx, state, observer, reader.as_mut()) {
match self.poll_fill_buf(cx, direction, conn_tracker, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
Poll::Pending => {
@@ -263,8 +268,8 @@ impl CopyBuffer {
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(
cx,
state,
observer,
direction,
conn_tracker,
reader.as_mut(),
writer.as_mut()
))?;
@@ -305,19 +310,8 @@ impl CopyBuffer {
mod tests {
use tokio::io::AsyncWriteExt;
use crate::proxy::conntrack::ConnectionState;
use super::*;
#[derive(Default)]
struct Observer(Vec<(ConnectionState, ConnectionState)>);
impl StateChangeObserver for Observer {
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) {
self.0.push((old_state, new_state));
}
}
#[tokio::test]
async fn test_client_to_compute() {
let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
@@ -329,15 +323,10 @@ mod tests {
compute_client.write_all(b"Neon").await.unwrap();
compute_client.shutdown().await.unwrap();
let mut conn_tracker = ConnectionTracker::new(Observer::default());
let result = copy_bidirectional_client_compute(
&mut client_proxy,
&mut compute_proxy,
&mut conn_tracker,
)
.await
.unwrap();
let result =
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
.await
.unwrap();
// Assert correct transferred amounts
let (client_to_compute_count, compute_to_client_count) = result;
@@ -358,15 +347,10 @@ mod tests {
.await
.unwrap();
let mut conn_tracker = ConnectionTracker::new(Observer::default());
let result = copy_bidirectional_client_compute(
&mut client_proxy,
&mut compute_proxy,
&mut conn_tracker,
)
.await
.unwrap();
let result =
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
.await
.unwrap();
// Assert correct transferred amounts
let (client_to_compute_count, compute_to_client_count) = result;

View File

@@ -3,7 +3,6 @@ use std::sync::Arc;
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
@@ -11,15 +10,16 @@ use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::conntrack::{ConnectionTracking, StreamScannerState};
use crate::proxy::copy_bidirectional::copy_bidirectional_client_compute;
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub(crate) async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
mut client: impl AsyncRead + AsyncWrite + Unpin,
mut compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
private_link_id: Option<SmolStr>,
conntracking: &Arc<ConnectionTracking>,
@@ -34,35 +34,32 @@ pub(crate) async fn proxy_pass(
let mut conn_tracker = conntracking.new_tracker();
let metrics = &Metrics::get().proxy.io_bytes;
let m_sent = metrics.with_labels(Direction::Tx);
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
metrics.get_metric(m_sent).inc_by(cnt as u64);
usage_tx.record_egress(cnt as u64);
},
);
let m_sent = metrics.with_labels(Direction::ComputeToClient);
let m_recv = metrics.with_labels(Direction::ClientToCompute);
let m_recv = metrics.with_labels(Direction::Rx);
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
metrics.get_metric(m_recv).inc_by(cnt as u64);
usage_tx.record_ingress(cnt as u64);
},
);
let mut client_to_compute = StreamScannerState::Tag;
let mut compute_to_client = StreamScannerState::Tag;
// Starting from here we only proxy the client's traffic.
debug!("performing the proxy pass...");
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
&mut client,
&mut compute,
&mut conn_tracker,
)
let _ = copy_bidirectional_client_compute(&mut client, &mut compute, |direction, bytes| {
match direction {
Direction::ClientToCompute => {
client_to_compute
.scan_bytes(bytes, &mut |tag| conn_tracker.frontend_message_tag(tag));
metrics.get_metric(m_recv).inc_by(bytes.len() as u64);
usage_tx.record_ingress(bytes.len() as u64);
}
Direction::ComputeToClient => {
compute_to_client
.scan_bytes(bytes, &mut |tag| conn_tracker.backend_message_tag(tag));
metrics.get_metric(m_sent).inc_by(bytes.len() as u64);
usage_tx.record_egress(bytes.len() as u64);
}
}
})
.await?;
Ok(())