mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 02:12:56 +00:00
replace measured stream with direct copy_bidirectional measurement integration
This commit is contained in:
@@ -383,8 +383,12 @@ async fn handle_client(
|
||||
info!("performing the proxy pass...");
|
||||
|
||||
let res = match client {
|
||||
Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
|
||||
Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
|
||||
Connection::Raw(mut c) => {
|
||||
copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await
|
||||
}
|
||||
Connection::Tls(mut c) => {
|
||||
copy_bidirectional_client_compute(&mut tls_stream, &mut c, |_, _| {}).await
|
||||
}
|
||||
};
|
||||
|
||||
match res {
|
||||
|
||||
@@ -200,8 +200,10 @@ pub enum HttpDirection {
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "direction")]
|
||||
pub enum Direction {
|
||||
Tx,
|
||||
Rx,
|
||||
#[label(rename = "tx")]
|
||||
ComputeToClient,
|
||||
#[label(rename = "rx")]
|
||||
ClientToCompute,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
|
||||
@@ -6,9 +6,10 @@ use std::task::{Context, Poll, ready};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug)]
|
||||
use crate::metrics::Direction;
|
||||
|
||||
enum TransferState {
|
||||
Running(CopyBuffer),
|
||||
Running(CopyBuffer, Direction),
|
||||
ShuttingDown(u64),
|
||||
Done(u64),
|
||||
}
|
||||
@@ -43,6 +44,7 @@ pub enum ErrorSource {
|
||||
fn transfer_one_direction<A, B>(
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut TransferState,
|
||||
f: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
r: &mut A,
|
||||
w: &mut B,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
@@ -54,8 +56,8 @@ where
|
||||
let mut w = Pin::new(w);
|
||||
loop {
|
||||
match state {
|
||||
TransferState::Running(buf) => {
|
||||
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
|
||||
TransferState::Running(buf, dir) => {
|
||||
let count = ready!(buf.poll_copy(cx, |b| f(*dir, b), r.as_mut(), w.as_mut()))?;
|
||||
*state = TransferState::ShuttingDown(count);
|
||||
}
|
||||
TransferState::ShuttingDown(count) => {
|
||||
@@ -70,44 +72,47 @@ where
|
||||
pub async fn copy_bidirectional_client_compute<Client, Compute>(
|
||||
client: &mut Client,
|
||||
compute: &mut Compute,
|
||||
mut f: impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
) -> Result<(u64, u64), ErrorSource>
|
||||
where
|
||||
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
let mut client_to_compute = TransferState::Running(CopyBuffer::new());
|
||||
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
|
||||
let mut client_to_compute =
|
||||
TransferState::Running(CopyBuffer::new(), Direction::ClientToCompute);
|
||||
let mut compute_to_client =
|
||||
TransferState::Running(CopyBuffer::new(), Direction::ComputeToClient);
|
||||
|
||||
poll_fn(|cx| {
|
||||
let mut client_to_compute_result =
|
||||
transfer_one_direction(cx, &mut client_to_compute, client, compute)
|
||||
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
let mut compute_to_client_result =
|
||||
transfer_one_direction(cx, &mut compute_to_client, compute, client)
|
||||
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
|
||||
// TODO: 1 info log, with a enum label for close direction.
|
||||
|
||||
// Early termination checks from compute to client.
|
||||
if let TransferState::Done(_) = compute_to_client {
|
||||
if let TransferState::Running(buf) = &client_to_compute {
|
||||
if let TransferState::Running(buf, _) = &client_to_compute {
|
||||
info!("Compute is done, terminate client");
|
||||
// Initiate shutdown
|
||||
client_to_compute = TransferState::ShuttingDown(buf.amt);
|
||||
client_to_compute_result =
|
||||
transfer_one_direction(cx, &mut client_to_compute, client, compute)
|
||||
transfer_one_direction(cx, &mut client_to_compute, &mut f, client, compute)
|
||||
.map_err(ErrorSource::from_client)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Early termination checks from client to compute.
|
||||
if let TransferState::Done(_) = client_to_compute {
|
||||
if let TransferState::Running(buf) = &compute_to_client {
|
||||
if let TransferState::Running(buf, _) = &compute_to_client {
|
||||
info!("Client is done, terminate compute");
|
||||
// Initiate shutdown
|
||||
compute_to_client = TransferState::ShuttingDown(buf.amt);
|
||||
compute_to_client_result =
|
||||
transfer_one_direction(cx, &mut compute_to_client, compute, client)
|
||||
transfer_one_direction(cx, &mut compute_to_client, &mut f, compute, client)
|
||||
.map_err(ErrorSource::from_compute)?;
|
||||
}
|
||||
}
|
||||
@@ -147,6 +152,7 @@ impl CopyBuffer {
|
||||
fn poll_fill_buf<R>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
f: &mut impl for<'a> FnMut(&'a [u8]),
|
||||
reader: Pin<&mut R>,
|
||||
) -> Poll<io::Result<()>>
|
||||
where
|
||||
@@ -157,6 +163,8 @@ impl CopyBuffer {
|
||||
buf.set_filled(me.cap);
|
||||
|
||||
let res = reader.poll_read(cx, &mut buf);
|
||||
f(&buf.filled()[me.cap..]);
|
||||
|
||||
if let Poll::Ready(Ok(())) = res {
|
||||
let filled_len = buf.filled().len();
|
||||
me.read_done = me.cap == filled_len;
|
||||
@@ -168,6 +176,7 @@ impl CopyBuffer {
|
||||
fn poll_write_buf<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
f: &mut impl for<'a> FnMut(&'a [u8]),
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<usize, ErrorDirection>>
|
||||
@@ -181,7 +190,8 @@ impl CopyBuffer {
|
||||
// Top up the buffer towards full if we can read a bit more
|
||||
// data - this should improve the chances of a large write
|
||||
if !me.read_done && me.cap < me.buf.len() {
|
||||
ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
|
||||
ready!(me.poll_fill_buf(cx, f, reader.as_mut()))
|
||||
.map_err(ErrorDirection::Read)?;
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
@@ -192,6 +202,7 @@ impl CopyBuffer {
|
||||
pub(super) fn poll_copy<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
mut f: impl for<'a> FnMut(&'a [u8]),
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
@@ -203,7 +214,7 @@ impl CopyBuffer {
|
||||
// If there is some space left in our buffer, then we try to read some
|
||||
// data to continue, thus maximizing the chances of a large write.
|
||||
if self.cap < self.buf.len() && !self.read_done {
|
||||
match self.poll_fill_buf(cx, reader.as_mut()) {
|
||||
match self.poll_fill_buf(cx, &mut f, reader.as_mut()) {
|
||||
Poll::Ready(Ok(())) => (),
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
|
||||
Poll::Pending => {
|
||||
@@ -226,7 +237,7 @@ impl CopyBuffer {
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
|
||||
let i = ready!(self.poll_write_buf(cx, &mut f, reader.as_mut(), writer.as_mut()))?;
|
||||
if i == 0 {
|
||||
return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
@@ -277,9 +288,10 @@ mod tests {
|
||||
compute_client.write_all(b"Neon").await.unwrap();
|
||||
compute_client.shutdown().await.unwrap();
|
||||
|
||||
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
|
||||
.await
|
||||
.unwrap();
|
||||
let result =
|
||||
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
@@ -300,9 +312,10 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
|
||||
.await
|
||||
.unwrap();
|
||||
let result =
|
||||
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
@@ -15,7 +14,7 @@ use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
pub(crate) async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
mut client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
mut compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: MetricsAuxInfo,
|
||||
private_link_id: Option<SmolStr>,
|
||||
@@ -28,25 +27,26 @@ pub(crate) async fn proxy_pass(
|
||||
});
|
||||
|
||||
let metrics = &Metrics::get().proxy.io_bytes;
|
||||
let m_sent = metrics.with_labels(Direction::Tx);
|
||||
let m_recv = metrics.with_labels(Direction::Rx);
|
||||
let mut client = MeasuredStream::new(
|
||||
client,
|
||||
|bytes_read| {
|
||||
metrics.get_metric(m_recv).inc_by(bytes_read as u64);
|
||||
usage_tx.record_ingress(bytes_read as u64);
|
||||
},
|
||||
|bytes_flushed| {
|
||||
metrics.get_metric(m_sent).inc_by(bytes_flushed as u64);
|
||||
usage_tx.record_egress(bytes_flushed as u64);
|
||||
},
|
||||
);
|
||||
let m_sent = metrics.with_labels(Direction::ComputeToClient);
|
||||
let m_recv = metrics.with_labels(Direction::ClientToCompute);
|
||||
|
||||
let inspect = |direction, bytes: &[u8]| match direction {
|
||||
Direction::ComputeToClient => {
|
||||
metrics.get_metric(m_sent).inc_by(bytes.len() as u64);
|
||||
usage_tx.record_egress(bytes.len() as u64);
|
||||
}
|
||||
Direction::ClientToCompute => {
|
||||
metrics.get_metric(m_recv).inc_by(bytes.len() as u64);
|
||||
usage_tx.record_ingress(bytes.len() as u64);
|
||||
}
|
||||
};
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
debug!("performing the proxy pass...");
|
||||
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
|
||||
&mut client,
|
||||
&mut compute,
|
||||
inspect,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user