Compare commits

...

13 Commits

Author SHA1 Message Date
Conrad Ludgate
19cbffc796 fix close order 2025-05-20 18:28:11 +01:00
Conrad Ludgate
c05f105035 clippy 2025-05-20 18:16:46 +01:00
Conrad Ludgate
a17a882895 some changes 2025-05-20 18:00:29 +01:00
Conrad Ludgate
49b6ee6c57 fix and add timeout test 2025-05-20 17:17:53 +01:00
Conrad Ludgate
69d21d85bc complete rewrite? 2025-05-20 16:54:08 +01:00
Conrad Ludgate
78f37e6f11 over-optimise copy_buffer 2025-05-19 23:26:17 +01:00
Conrad Ludgate
e5fd708495 over-optimise copy_bidirectional 2025-05-19 23:16:13 +01:00
Conrad Ludgate
f1e3f2259e refactor writes 2025-05-19 22:50:12 +01:00
Conrad Ludgate
29c4db1658 update error handling 2025-05-19 22:45:28 +01:00
Conrad Ludgate
727c333831 remove total byte copy amount 2025-05-19 22:38:09 +01:00
Conrad Ludgate
14312f1a9a replace measured stream with direct copy_bidirectional measurement integration 2025-05-19 16:44:36 +01:00
Conrad Ludgate
008cd84e7b remove one measuredstream layer of indirection 2025-05-19 16:21:32 +01:00
Conrad Ludgate
68d561664b proxy(passthrough): only instrument on debug 2025-05-19 16:17:24 +01:00
7 changed files with 366 additions and 268 deletions

View File

@@ -383,12 +383,19 @@ async fn handle_client(
info!("performing the proxy pass...");
let res = match client {
Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
Connection::Raw(mut c) => {
copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await
}
Connection::Tls(mut c) => {
copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await
}
};
match res {
Ok(_) => Ok(()),
Ok(()) => Ok(()),
Err(ErrorSource::Timeout(_)) => Err(anyhow!(
"timed out while gracefully shutting down the connection"
)),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}

View File

@@ -129,6 +129,12 @@ pub async fn task_main(
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Timeout(_)) => {
info!(
?session_id,
"per-client task timed out while gracefully shutting down the connection"
);
}
Err(ErrorSource::Client(e)) => {
error!(
?session_id,

View File

@@ -200,8 +200,10 @@ pub enum HttpDirection {
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "direction")]
pub enum Direction {
Tx,
Rx,
#[label(rename = "tx")]
ComputeToClient,
#[label(rename = "rx")]
ClientToCompute,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]

View File

@@ -1,313 +1,394 @@
use std::future::poll_fn;
use std::io;
use std::ops::Range;
use std::pin::Pin;
use std::task::{Context, Poll, ready};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::info;
#[derive(Debug)]
enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}
use crate::metrics::Direction;
#[derive(Debug)]
pub(crate) enum ErrorDirection {
Read(io::Error),
Write(io::Error),
}
const DISCONNECT_TIMEOUT: Duration = Duration::from_secs(10);
impl ErrorSource {
fn from_client(err: ErrorDirection) -> ErrorSource {
match err {
ErrorDirection::Read(client) => Self::Client(client),
ErrorDirection::Write(compute) => Self::Compute(compute),
}
}
fn from_compute(err: ErrorDirection) -> ErrorSource {
match err {
ErrorDirection::Write(client) => Self::Client(client),
ErrorDirection::Read(compute) => Self::Compute(compute),
}
}
/// Mark a value as being unlikely.
#[cold]
#[inline(always)]
fn cold<I>(i: I) -> I {
i
}
#[derive(Debug)]
pub enum ErrorSource {
Client(io::Error),
Compute(io::Error),
Timeout(tokio::time::error::Elapsed),
}
fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<Result<u64, ErrorDirection>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?;
*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
impl ErrorSource {
fn read(dir: Direction, err: io::Error) -> Self {
match dir {
Direction::ComputeToClient => ErrorSource::Compute(err),
Direction::ClientToCompute => ErrorSource::Client(err),
}
}
fn write(dir: Direction, err: io::Error) -> Self {
match dir {
Direction::ComputeToClient => ErrorSource::Client(err),
Direction::ClientToCompute => ErrorSource::Compute(err),
}
}
}
#[tracing::instrument(skip_all)]
pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,
) -> Result<(u64, u64), ErrorSource>
mut f: impl for<'a> FnMut(Direction, &'a [u8]),
) -> Result<(), ErrorSource>
where
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut client_to_compute = TransferState::Running(CopyBuffer::new());
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
let f = &mut f;
let mut client_to_compute = CopyBuffer::new(Direction::ClientToCompute);
let mut compute_to_client = CopyBuffer::new(Direction::ComputeToClient);
poll_fn(|cx| {
let mut client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
let mut compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
let mut client = Pin::new(client);
let mut compute = Pin::new(compute);
// TODO: 1 info log, with a enum label for close direction.
// Initial copy hot path
let close_dir = poll_fn(|cx| -> Poll<Result<_, ErrorSource>> {
let copy1 = client_to_compute.poll_copy(cx, f, client.as_mut(), compute.as_mut())?;
let copy2 = compute_to_client.poll_copy(cx, f, compute.as_mut(), client.as_mut())?;
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
match (copy1, copy2) {
(Poll::Pending, Poll::Pending) => Poll::Pending,
(Poll::Ready(()), _) => Poll::Ready(Ok(client_to_compute.dir)),
(_, Poll::Ready(())) => Poll::Ready(Ok(compute_to_client.dir)),
}
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
}
// It is not a problem if ready! returns early ... (comment remains the same)
let client_to_compute = ready!(client_to_compute_result);
let compute_to_client = ready!(compute_to_client_result);
Poll::Ready(Ok((client_to_compute, compute_to_client)))
})
.await
.await?;
// initiate shutdown.
match close_dir {
Direction::ClientToCompute => {
info!("Client is done, terminate compute");
// we will never write anymore data to the client.
compute_to_client.filled = 0..0;
// make sure to shutdown the client conn.
compute_to_client.need_flush = true;
}
Direction::ComputeToClient => {
info!("Compute is done, terminate client");
// we will never write anymore data to the compute.
client_to_compute.filled = 0..0;
// make sure to shutdown the compute conn.
client_to_compute.need_flush = true;
}
}
// Finish sending the rest of the data to client/compute before shutting it down.
//
// Edge case:
// * peer has filled the TCP buffers and is blocking on a `write()`,
// * proxy has filled the TCP buffers and is waiting on a `write()`.
// Since no side is reading from the buffers, no progress will be made.
let shutdown = poll_fn(|cx| {
let res1 = client_to_compute.poll_empty(cx, compute.as_mut())?;
let res2 = compute_to_client.poll_empty(cx, client.as_mut())?;
if res1.is_ready() && res2.is_ready() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
});
// We assume most peers will have enough buffer space so this issue doesn't arise, but we apply
// a timeout just in case.
//
// We could also update `poll_empty` to try and read the data, but I think this is not an edge case
// worth overcomplicating.
let res = tokio::time::timeout(DISCONNECT_TIMEOUT, shutdown).await;
match res {
Ok(res) => res,
Err(timeout) => Err(ErrorSource::Timeout(timeout)),
}
}
#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}
const DEFAULT_BUF_SIZE: usize = 1024;
pub(super) struct CopyBuffer {
dir: Direction,
need_flush: bool,
filled: Range<usize>,
buf: [u8; DEFAULT_BUF_SIZE],
}
impl CopyBuffer {
pub(super) fn new() -> Self {
pub(super) const fn new(dir: Direction) -> Self {
Self {
read_done: false,
dir,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
filled: 0..0,
buf: [0; DEFAULT_BUF_SIZE],
}
}
fn poll_fill_buf<R>(
/// Returns Ready(Ok(())) when no more writes could progress, and the buffer has space to read.
#[inline(always)]
fn poll_write_loop<W>(
&mut self,
cx: &mut Context<'_>,
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
R: AsyncRead + ?Sized,
{
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
buf.set_filled(me.cap);
let res = reader.poll_read(cx, &mut buf);
if let Poll::Ready(Ok(())) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
}
fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<usize, ErrorDirection>>
) -> Poll<Result<(), ErrorSource>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
let me = &mut *self;
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
Poll::Pending => {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
}
Poll::Pending
}
res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write),
}
}
debug_assert!(!self.filled.is_empty());
pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<u64, ErrorDirection>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If there is some space left in our buffer, then we try to read some
// data to continue, thus maximizing the chances of a large write.
if self.cap < self.buf.len() && !self.read_done {
match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(())) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
Poll::Pending => {
// Ignore pending reads when our buffer is not empty, because
// we can try to write data immediately.
if self.pos == self.cap {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))
.map_err(ErrorDirection::Write)?;
self.need_flush = false;
}
let filled_buf = &self.buf[self.filled.clone()];
match writer.as_mut().poll_write(cx, filled_buf) {
Poll::Ready(Err(err)) => {
return Poll::Ready(Err(ErrorSource::write(self.dir, cold(err))));
}
Poll::Ready(Ok(0)) => {
let err =
io::Error::new(io::ErrorKind::WriteZero, "write zero byte into writer");
return Poll::Ready(Err(ErrorSource::write(self.dir, cold(err))));
}
Poll::Ready(Ok(i)) => {
// update the write head.
self.filled.start += i;
self.need_flush = true;
return Poll::Pending;
}
// we wrote some data, but the filled buffer might not be fully empty yet.
if !self.filled.is_empty() {
continue;
}
// the buffer is definitely empty. reset positions.
self.filled = 0..0;
break;
}
}
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
))));
}
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
// If pos larger than cap, this loop will never stop.
// In particular, user's wrong poll_write implementation returning
// incorrect written length may lead to thread blocking.
debug_assert!(
self.pos <= self.cap,
"writer returned length larger than input slice"
);
// All data has been written, the buffer can be considered empty again
self.pos = 0;
self.cap = 0;
// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.read_done {
ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?;
return Poll::Ready(Ok(self.amt));
// While we couldn't write, we might be able to read.
Poll::Pending if self.filled.end < self.buf.len() => break,
// We couldn't write, and have no space to read. Just exit.
Poll::Pending => return Poll::Pending,
}
}
Poll::Ready(Ok(()))
}
/// Returns Ready(Ok((true))) when read returns EOF.
/// Returns Ready(Ok((false))) when read returns data.
#[inline(always)]
fn poll_read_once<R>(
&mut self,
cx: &mut Context<'_>,
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
reader: Pin<&mut R>,
) -> Poll<Result<bool, ErrorSource>>
where
R: AsyncRead + ?Sized,
{
debug_assert!(self.filled.end < self.buf.len());
let mut buf = ReadBuf::new(&mut self.buf[self.filled.end..]);
match reader.poll_read(cx, &mut buf) {
Poll::Ready(Ok(())) => {
let filled = buf.filled();
// no more data to read, switch to shutdown mode.
if filled.is_empty() {
self.need_flush = true;
return Poll::Ready(Ok(true));
}
// run our inspection callback.
f(self.dir, filled);
// update the read head.
self.filled.end += filled.len();
// read more data
Poll::Ready(Ok(false))
}
// cannot continue on error.
Poll::Ready(Err(e)) => Poll::Ready(Err(ErrorSource::read(self.dir, cold(e)))),
// No more data to read, and no more data to write.
Poll::Pending => Poll::Pending,
}
}
/// Returns Ready(Ok(())) when read returns EOF.
fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<Result<(), ErrorSource>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
// this register eliminates a branch in the hot loop.
let mut empty = self.filled.is_empty();
// write then read hot loop
loop {
if !empty {
ready!(self.poll_write_loop(cx, writer.as_mut())?);
}
// If empty is true, there is guaranteed space to read.
// If empty is false, and the write loop returned ready, then we know there's space for more reads.
match self.poll_read_once(cx, f, reader.as_mut())? {
// EOF
Poll::Ready(true) => return Poll::Ready(Ok(())),
// Needs write.
Poll::Ready(false) => empty = false,
// Cannot read. The peer might not send us anything until
// they receive data from us, so let's switch to flushing.
Poll::Pending => break,
}
}
if self.need_flush {
let flush = writer.as_mut().poll_flush(cx);
ready!(flush.map_err(|e| ErrorSource::write(self.dir, e))?);
self.need_flush = false;
}
// there might be more data still to read.
Poll::Pending
}
/// Returns Ready(Ok(())) when the conn is fully shutdown.
pub(super) fn poll_empty<W>(
&mut self,
cx: &mut Context<'_>,
mut writer: Pin<&mut W>,
) -> Poll<Result<(), ErrorSource>>
where
W: AsyncWrite + ?Sized,
{
if !self.filled.is_empty() {
ready!(self.poll_write_loop(cx, writer.as_mut())?);
if !self.filled.is_empty() {
// still some data to write
return Poll::Pending;
}
}
if self.need_flush {
let res = writer.poll_shutdown(cx);
ready!(res.map_err(|e| ErrorSource::write(self.dir, e))?);
self.need_flush = false;
}
// no data to read, no data to write.
Poll::Ready(Ok(()))
}
}
#[cfg(test)]
mod tests {
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::*;
#[tokio::test]
async fn test_client_to_compute() {
let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
let (mut compute_proxy, mut compute_client) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut proxy_compute, mut compute) = tokio::io::duplex(16); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
client_client.write_all(b"hello").await.unwrap();
client_client.shutdown().await.unwrap();
compute_client.write_all(b"Neon").await.unwrap();
compute_client.shutdown().await.unwrap();
client.write_all(b"Neon Serverless Postgres").await.unwrap();
compute.write_all(b"is amazing").await.unwrap();
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
.await
.unwrap();
client.shutdown().await.unwrap();
let copy = tokio::spawn(async move {
copy_bidirectional_client_compute(&mut client_proxy, &mut proxy_compute, |_, _| {})
.await
.unwrap();
});
// Assert correct transferred amounts
let (client_to_compute_count, compute_to_client_count) = result;
assert_eq!(client_to_compute_count, 5); // 'hello' was transferred
assert_eq!(compute_to_client_count, 4); // response only partially transferred or not at all
let mut client_recv = String::new();
let mut compute_recv = String::new();
client.read_to_string(&mut client_recv).await.unwrap();
compute.read_to_string(&mut compute_recv).await.unwrap();
assert_eq!(compute_recv, "Neon Serverless Postgres");
assert_eq!(client_recv, "is amazing");
copy.await.unwrap();
}
#[tokio::test]
async fn test_compute_to_client() {
let (mut client_client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut compute_proxy, mut compute_client) = tokio::io::duplex(8); // Create a mock duplex stream
let (mut client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut proxy_compute, mut compute) = tokio::io::duplex(16); // Create a mock duplex stream
// Simulate 'a' finishing while there's still data for 'b'
compute_client.write_all(b"hello").await.unwrap();
compute_client.shutdown().await.unwrap();
client_client
.write_all(b"Neon Serverless Postgres")
.await
.unwrap();
client.write_all(b"Neon Serverless Postgres").await.unwrap();
compute.write_all(b"is amazing").await.unwrap();
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
.await
.unwrap();
compute.shutdown().await.unwrap();
let copy = tokio::spawn(async move {
copy_bidirectional_client_compute(&mut client_proxy, &mut proxy_compute, |_, _| {})
.await
.unwrap();
});
// Assert correct transferred amounts
let (client_to_compute_count, compute_to_client_count) = result;
assert_eq!(compute_to_client_count, 5); // 'hello' was transferred
assert!(client_to_compute_count <= 8); // response only partially transferred or not at all
let mut client_recv = String::new();
let mut compute_recv = String::new();
client.read_to_string(&mut client_recv).await.unwrap();
compute.read_to_string(&mut compute_recv).await.unwrap();
assert_eq!(compute_recv, "Neon Serverless ");
assert_eq!(client_recv, "is amazing");
copy.await.unwrap();
}
#[tokio::test(start_paused = true)]
async fn test_timeout() {
let (mut client, mut client_proxy) = tokio::io::duplex(32); // Create a mock duplex stream
let (mut proxy_compute, mut compute) = tokio::io::duplex(16); // Create a mock duplex stream
// Try to send 24 bytes to compute, but compute only has space for 16 bytes.
// Writes will not succeed.
client.write_all(b"Neon Serverless Postgres").await.unwrap();
client.shutdown().await.unwrap();
let copy = tokio::spawn(async move {
copy_bidirectional_client_compute(&mut client_proxy, &mut proxy_compute, |_, _| {})
.await
.unwrap_err()
});
tokio::time::advance(DISCONNECT_TIMEOUT).await;
let res = copy.await.unwrap();
assert!(matches!(res, ErrorSource::Timeout(_)));
// Assert correct transferred amounts
let mut compute_recv = String::new();
compute.read_to_string(&mut compute_recv).await.unwrap();
assert_eq!(compute_recv, "Neon Serverless ");
}
}

View File

@@ -167,6 +167,12 @@ pub async fn task_main(
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Timeout(_)) => {
info!(
?session_id,
"per-client task timed out while gracefully shutting down the connection"
);
}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,

View File

@@ -1,7 +1,6 @@
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
@@ -9,14 +8,15 @@ use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::proxy::copy_bidirectional_client_compute;
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
#[tracing::instrument(level = "debug", skip_all)]
pub(crate) async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
mut client: Stream<impl AsyncRead + AsyncWrite + Unpin>,
mut compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
private_link_id: Option<SmolStr>,
) -> Result<(), ErrorSource> {
@@ -28,37 +28,30 @@ pub(crate) async fn proxy_pass(
});
let metrics = &Metrics::get().proxy.io_bytes;
let m_sent = metrics.with_labels(Direction::Tx);
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
metrics.get_metric(m_sent).inc_by(cnt as u64);
usage_tx.record_egress(cnt as u64);
},
);
let m_sent = metrics.with_labels(Direction::ComputeToClient);
let m_recv = metrics.with_labels(Direction::ClientToCompute);
let m_recv = metrics.with_labels(Direction::Rx);
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
metrics.get_metric(m_recv).inc_by(cnt as u64);
usage_tx.record_ingress(cnt as u64);
},
);
let inspect = |direction, bytes: &[u8]| match direction {
Direction::ComputeToClient => {
metrics.get_metric(m_sent).inc_by(bytes.len() as u64);
usage_tx.record_egress(bytes.len() as u64);
}
Direction::ClientToCompute => {
metrics.get_metric(m_recv).inc_by(bytes.len() as u64);
usage_tx.record_ingress(bytes.len() as u64);
}
};
// Starting from here we only proxy the client's traffic.
debug!("performing the proxy pass...");
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
&mut client,
&mut compute,
)
.await?;
Ok(())
// reduce branching internal to the hot path.
match &mut client {
Stream::Raw { raw } => copy_bidirectional_client_compute(raw, &mut compute, inspect).await,
Stream::Tls { tls, .. } => {
copy_bidirectional_client_compute(&mut *tls, &mut compute, inspect).await
}
}
}
pub(crate) struct ProxyPassthrough<S> {

View File

@@ -2,7 +2,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use anyhow::Context as _;
use anyhow::{Context as _, anyhow};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
@@ -169,6 +169,9 @@ pub(crate) async fn serve_websocket(
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Timeout(_)) => Err(anyhow!(
"timed out while gracefully shutting down the connection"
)),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}