mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-02 13:00:37 +00:00
abstract byte handling
This commit is contained in:
@@ -297,11 +297,9 @@ async fn handle_client(
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
|
||||
// match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
|
||||
// Ok(_) => Ok(()),
|
||||
// Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
// Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
// }
|
||||
|
||||
Ok(())
|
||||
match copy_bidirectional_client_compute(&mut tls_stream, &mut client, |_, _| {}).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,8 +200,10 @@ pub enum HttpDirection {
|
||||
#[derive(FixedCardinalityLabel, Copy, Clone)]
|
||||
#[label(singleton = "direction")]
|
||||
pub enum Direction {
|
||||
Tx,
|
||||
Rx,
|
||||
#[label(rename = "tx")]
|
||||
ComputeToClient,
|
||||
#[label(rename = "rx")]
|
||||
ClientToCompute,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
|
||||
@@ -6,11 +6,11 @@ use std::task::{Context, Poll, ready};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::info;
|
||||
|
||||
use super::conntrack::{ConnectionTracker, StateChangeObserver, StreamScannerState, TagObserver};
|
||||
use crate::metrics::Direction;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TransferState {
|
||||
Running(CopyBuffer, StreamScannerState),
|
||||
Running(CopyBuffer),
|
||||
ShuttingDown(u64),
|
||||
Done(u64),
|
||||
}
|
||||
@@ -45,7 +45,8 @@ pub enum ErrorSource {
|
||||
fn transfer_one_direction<A, B>(
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut TransferState,
|
||||
mut observer: impl TagObserver,
|
||||
direction: Direction,
|
||||
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
r: &mut A,
|
||||
w: &mut B,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
@@ -57,9 +58,9 @@ where
|
||||
let mut w = Pin::new(w);
|
||||
loop {
|
||||
match state {
|
||||
TransferState::Running(buf, stream_state) => {
|
||||
TransferState::Running(buf) => {
|
||||
let count =
|
||||
ready!(buf.poll_copy(cx, stream_state, &mut observer, r.as_mut(), w.as_mut()))?;
|
||||
ready!(buf.poll_copy(cx, direction, conn_tracker, r.as_mut(), w.as_mut()))?;
|
||||
*state = TransferState::ShuttingDown(count);
|
||||
}
|
||||
TransferState::ShuttingDown(count) => {
|
||||
@@ -75,20 +76,21 @@ where
|
||||
pub async fn copy_bidirectional_client_compute<Client, Compute>(
|
||||
client: &mut Client,
|
||||
compute: &mut Compute,
|
||||
conn_tracker: &mut ConnectionTracker<impl StateChangeObserver>,
|
||||
mut conn_tracker: impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
) -> Result<(u64, u64), ErrorSource>
|
||||
where
|
||||
Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
let mut client_to_compute = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag);
|
||||
let mut compute_to_client = TransferState::Running(CopyBuffer::new(), StreamScannerState::Tag);
|
||||
let mut client_to_compute = TransferState::Running(CopyBuffer::new());
|
||||
let mut compute_to_client = TransferState::Running(CopyBuffer::new());
|
||||
|
||||
poll_fn(|cx| {
|
||||
let mut client_to_compute_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut client_to_compute,
|
||||
|tag| conn_tracker.frontend_message_tag(tag),
|
||||
Direction::ClientToCompute,
|
||||
&mut conn_tracker,
|
||||
client,
|
||||
compute,
|
||||
)
|
||||
@@ -96,7 +98,8 @@ where
|
||||
let mut compute_to_client_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut compute_to_client,
|
||||
|tag| conn_tracker.backend_message_tag(tag),
|
||||
Direction::ComputeToClient,
|
||||
&mut conn_tracker,
|
||||
compute,
|
||||
client,
|
||||
)
|
||||
@@ -106,14 +109,15 @@ where
|
||||
|
||||
// Early termination checks from compute to client.
|
||||
if let TransferState::Done(_) = compute_to_client {
|
||||
if let TransferState::Running(buf, _) = &client_to_compute {
|
||||
if let TransferState::Running(buf) = &client_to_compute {
|
||||
info!("Compute is done, terminate client");
|
||||
// Initiate shutdown
|
||||
client_to_compute = TransferState::ShuttingDown(buf.amt);
|
||||
client_to_compute_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut client_to_compute,
|
||||
|tag| conn_tracker.frontend_message_tag(tag),
|
||||
Direction::ClientToCompute,
|
||||
&mut conn_tracker,
|
||||
client,
|
||||
compute,
|
||||
)
|
||||
@@ -123,14 +127,15 @@ where
|
||||
|
||||
// Early termination checks from client to compute.
|
||||
if let TransferState::Done(_) = client_to_compute {
|
||||
if let TransferState::Running(buf, _) = &compute_to_client {
|
||||
if let TransferState::Running(buf) = &compute_to_client {
|
||||
info!("Client is done, terminate compute");
|
||||
// Initiate shutdown
|
||||
compute_to_client = TransferState::ShuttingDown(buf.amt);
|
||||
compute_to_client_result = transfer_one_direction(
|
||||
cx,
|
||||
&mut compute_to_client,
|
||||
|tag| conn_tracker.backend_message_tag(tag),
|
||||
Direction::ComputeToClient,
|
||||
&mut conn_tracker,
|
||||
compute,
|
||||
client,
|
||||
)
|
||||
@@ -173,8 +178,8 @@ impl CopyBuffer {
|
||||
fn poll_fill_buf<R>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut StreamScannerState,
|
||||
observer: &mut impl TagObserver,
|
||||
direction: Direction,
|
||||
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
reader: Pin<&mut R>,
|
||||
) -> Poll<io::Result<()>>
|
||||
where
|
||||
@@ -185,7 +190,7 @@ impl CopyBuffer {
|
||||
buf.set_filled(me.cap);
|
||||
|
||||
let res = reader.poll_read(cx, &mut buf);
|
||||
state.scan_bytes(&buf.filled()[me.cap..], observer);
|
||||
conn_tracker(direction, &buf.filled()[me.cap..]);
|
||||
|
||||
if let Poll::Ready(Ok(())) = res {
|
||||
let filled_len = buf.filled().len();
|
||||
@@ -198,8 +203,8 @@ impl CopyBuffer {
|
||||
fn poll_write_buf<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut StreamScannerState,
|
||||
observer: &mut impl TagObserver,
|
||||
direction: Direction,
|
||||
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<usize, ErrorDirection>>
|
||||
@@ -213,7 +218,7 @@ impl CopyBuffer {
|
||||
// Top up the buffer towards full if we can read a bit more
|
||||
// data - this should improve the chances of a large write
|
||||
if !me.read_done && me.cap < me.buf.len() {
|
||||
ready!(me.poll_fill_buf(cx, state, observer, reader.as_mut()))
|
||||
ready!(me.poll_fill_buf(cx, direction, conn_tracker, reader.as_mut()))
|
||||
.map_err(ErrorDirection::Read)?;
|
||||
}
|
||||
Poll::Pending
|
||||
@@ -225,8 +230,8 @@ impl CopyBuffer {
|
||||
pub(super) fn poll_copy<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut StreamScannerState,
|
||||
observer: &mut impl TagObserver,
|
||||
direction: Direction,
|
||||
conn_tracker: &mut impl for<'a> FnMut(Direction, &'a [u8]),
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<Result<u64, ErrorDirection>>
|
||||
@@ -238,7 +243,7 @@ impl CopyBuffer {
|
||||
// If there is some space left in our buffer, then we try to read some
|
||||
// data to continue, thus maximizing the chances of a large write.
|
||||
if self.cap < self.buf.len() && !self.read_done {
|
||||
match self.poll_fill_buf(cx, state, observer, reader.as_mut()) {
|
||||
match self.poll_fill_buf(cx, direction, conn_tracker, reader.as_mut()) {
|
||||
Poll::Ready(Ok(())) => (),
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
|
||||
Poll::Pending => {
|
||||
@@ -263,8 +268,8 @@ impl CopyBuffer {
|
||||
while self.pos < self.cap {
|
||||
let i = ready!(self.poll_write_buf(
|
||||
cx,
|
||||
state,
|
||||
observer,
|
||||
direction,
|
||||
conn_tracker,
|
||||
reader.as_mut(),
|
||||
writer.as_mut()
|
||||
))?;
|
||||
@@ -305,19 +310,8 @@ impl CopyBuffer {
|
||||
mod tests {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use crate::proxy::conntrack::ConnectionState;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Default)]
|
||||
struct Observer(Vec<(ConnectionState, ConnectionState)>);
|
||||
|
||||
impl StateChangeObserver for Observer {
|
||||
fn change(&mut self, old_state: ConnectionState, new_state: ConnectionState) {
|
||||
self.0.push((old_state, new_state));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_to_compute() {
|
||||
let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream
|
||||
@@ -329,15 +323,10 @@ mod tests {
|
||||
compute_client.write_all(b"Neon").await.unwrap();
|
||||
compute_client.shutdown().await.unwrap();
|
||||
|
||||
let mut conn_tracker = ConnectionTracker::new(Observer::default());
|
||||
|
||||
let result = copy_bidirectional_client_compute(
|
||||
&mut client_proxy,
|
||||
&mut compute_proxy,
|
||||
&mut conn_tracker,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let result =
|
||||
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
@@ -358,15 +347,10 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut conn_tracker = ConnectionTracker::new(Observer::default());
|
||||
|
||||
let result = copy_bidirectional_client_compute(
|
||||
&mut client_proxy,
|
||||
&mut compute_proxy,
|
||||
&mut conn_tracker,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let result =
|
||||
copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy, |_, _| {})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (client_to_compute_count, compute_to_client_count) = result;
|
||||
|
||||
@@ -3,7 +3,6 @@ use std::sync::Arc;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
@@ -11,15 +10,16 @@ use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::proxy::conntrack::ConnectionTracking;
|
||||
use crate::proxy::conntrack::{ConnectionTracking, StreamScannerState};
|
||||
use crate::proxy::copy_bidirectional::copy_bidirectional_client_compute;
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
mut client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
mut compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: MetricsAuxInfo,
|
||||
private_link_id: Option<SmolStr>,
|
||||
conntracking: &Arc<ConnectionTracking>,
|
||||
@@ -34,35 +34,32 @@ pub(crate) async fn proxy_pass(
|
||||
let mut conn_tracker = conntracking.new_tracker();
|
||||
|
||||
let metrics = &Metrics::get().proxy.io_bytes;
|
||||
let m_sent = metrics.with_labels(Direction::Tx);
|
||||
let mut client = MeasuredStream::new(
|
||||
client,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
// Number of bytes we sent to the client (outbound).
|
||||
metrics.get_metric(m_sent).inc_by(cnt as u64);
|
||||
usage_tx.record_egress(cnt as u64);
|
||||
},
|
||||
);
|
||||
let m_sent = metrics.with_labels(Direction::ComputeToClient);
|
||||
let m_recv = metrics.with_labels(Direction::ClientToCompute);
|
||||
|
||||
let m_recv = metrics.with_labels(Direction::Rx);
|
||||
let mut compute = MeasuredStream::new(
|
||||
compute,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
// Number of bytes the client sent to the compute node (inbound).
|
||||
metrics.get_metric(m_recv).inc_by(cnt as u64);
|
||||
usage_tx.record_ingress(cnt as u64);
|
||||
},
|
||||
);
|
||||
let mut client_to_compute = StreamScannerState::Tag;
|
||||
let mut compute_to_client = StreamScannerState::Tag;
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
debug!("performing the proxy pass...");
|
||||
let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
|
||||
&mut client,
|
||||
&mut compute,
|
||||
&mut conn_tracker,
|
||||
)
|
||||
|
||||
let _ = copy_bidirectional_client_compute(&mut client, &mut compute, |direction, bytes| {
|
||||
match direction {
|
||||
Direction::ClientToCompute => {
|
||||
client_to_compute
|
||||
.scan_bytes(bytes, &mut |tag| conn_tracker.frontend_message_tag(tag));
|
||||
|
||||
metrics.get_metric(m_recv).inc_by(bytes.len() as u64);
|
||||
usage_tx.record_ingress(bytes.len() as u64);
|
||||
}
|
||||
Direction::ComputeToClient => {
|
||||
compute_to_client
|
||||
.scan_bytes(bytes, &mut |tag| conn_tracker.backend_message_tag(tag));
|
||||
|
||||
metrics.get_metric(m_sent).inc_by(bytes.len() as u64);
|
||||
usage_tx.record_egress(bytes.len() as u64);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user