proxy/conntrack: Global connection tracking table and debug logging

This commit is contained in:
Folke Behrens
2025-04-21 20:02:10 +02:00
parent fd07ecf58f
commit 431a12acba
7 changed files with 107 additions and 5 deletions

View File

@@ -24,6 +24,7 @@ use crate::config::{
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
use crate::proxy::conntrack::ConnectionTracking;
use crate::rate_limiter::{
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
};
@@ -418,6 +419,8 @@ pub async fn run() -> anyhow::Result<()> {
64,
));
let conntracking = Arc::new(ConnectionTracking::default());
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -431,6 +434,7 @@ pub async fn run() -> anyhow::Result<()> {
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
conntracking.clone(),
));
}
@@ -453,6 +457,7 @@ pub async fn run() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
conntracking.clone(),
));
}
}

View File

@@ -13,6 +13,7 @@ use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
@@ -25,6 +26,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
conntracking: Arc<ConnectionTracking>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -50,6 +52,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
let conntracking = Arc::clone(&conntracking);
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
@@ -111,6 +114,7 @@ pub async fn task_main(
socket,
conn_gauge,
cancellations,
conntracking,
)
.instrument(ctx.span())
.boxed()
@@ -167,6 +171,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
conntracking: Arc<ConnectionTracking>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
@@ -264,6 +269,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
compute: node,
session_id: ctx.session_id(),
cancel: session,
conntracking,
_req: request_gauge,
_conn: conn_gauge,
}))

View File

@@ -1,13 +1,72 @@
#![allow(dead_code, reason = "TODO: work in progress")]
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
use std::task::{Context, Poll};
use std::time::SystemTime;
use std::{fmt, io};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnId(usize);
#[derive(Default)]
pub struct ConnectionTracking {
conns: clashmap::ClashMap<ConnId, (ConnectionState, SystemTime)>,
}
impl ConnectionTracking {
pub fn new_tracker(self: &Arc<Self>) -> ConnectionTracker<Arc<Self>> {
let conn_id = self.new_conn_id();
ConnectionTracker::new(conn_id, Arc::clone(self))
}
fn new_conn_id(&self) -> ConnId {
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
let id = ConnId(NEXT_ID.fetch_add(1, Ordering::Relaxed));
self.conns
.insert(id, (ConnectionState::Idle, SystemTime::now()));
id
}
fn update(&self, conn_id: ConnId, new_state: ConnectionState) {
let new_timestamp = SystemTime::now();
let old_state = self.conns.insert(conn_id, (new_state, new_timestamp));
if let Some((old_state, _old_timestamp)) = old_state {
tracing::debug!(?conn_id, %old_state, %new_state, "conntrack: update");
} else {
tracing::debug!(?conn_id, %new_state, "conntrack: update");
}
}
fn remove(&self, conn_id: ConnId) {
if let Some((_, (old_state, _old_timestamp))) = self.conns.remove(&conn_id) {
tracing::debug!(?conn_id, %old_state, "conntrack: remove");
}
}
}
impl StateChangeObserver for Arc<ConnectionTracking> {
type ConnId = ConnId;
fn change(
&self,
conn_id: Self::ConnId,
_old_state: ConnectionState,
new_state: ConnectionState,
) {
match new_state {
ConnectionState::Init
| ConnectionState::Idle
| ConnectionState::Transaction
| ConnectionState::Busy
| ConnectionState::Unknown => self.update(conn_id, new_state),
ConnectionState::Closed => self.remove(conn_id),
}
}
}
/// Called by `ConnectionTracker` whenever the `ConnectionState` changed.
pub trait StateChangeObserver {
/// Identifier of the connection passed back on state change.

View File

@@ -31,6 +31,7 @@ use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -61,6 +62,7 @@ pub async fn task_main(
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conntracking: Arc<ConnectionTracking>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -86,6 +88,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
let conntracking = Arc::clone(&conntracking);
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
@@ -150,6 +153,7 @@ pub async fn task_main(
endpoint_rate_limiter2,
conn_gauge,
cancellations,
conntracking,
)
.instrument(ctx.span())
.boxed()
@@ -269,6 +273,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
conntracking: Arc<ConnectionTracking>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
@@ -410,6 +415,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
compute: node,
session_id: ctx.session_id(),
cancel: session,
conntracking,
_req: request_gauge,
_conn: conn_gauge,
}))

View File

@@ -1,3 +1,5 @@
use std::sync::Arc;
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
@@ -9,6 +11,7 @@ use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::proxy::conntrack::{ConnectionTracking, TrackedStream};
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
@@ -19,6 +22,7 @@ pub(crate) async fn proxy_pass(
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
private_link_id: Option<SmolStr>,
conntracking: &Arc<ConnectionTracking>,
) -> Result<(), ErrorSource> {
// we will report ingress at a later date
let usage_tx = USAGE_METRICS.register(Ids {
@@ -27,9 +31,11 @@ pub(crate) async fn proxy_pass(
private_link_id,
});
let conn_tracker = conntracking.new_tracker();
let metrics = &Metrics::get().proxy.io_bytes;
let m_sent = metrics.with_labels(Direction::Tx);
let mut client = MeasuredStream::new(
let client = MeasuredStream::new(
client,
|_| {},
|cnt| {
@@ -38,9 +44,10 @@ pub(crate) async fn proxy_pass(
usage_tx.record_egress(cnt as u64);
},
);
let mut client = TrackedStream::new(client, true, |tag| conn_tracker.frontend_message_tag(tag));
let m_recv = metrics.with_labels(Direction::Rx);
let mut compute = MeasuredStream::new(
let compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
@@ -49,6 +56,8 @@ pub(crate) async fn proxy_pass(
usage_tx.record_ingress(cnt as u64);
},
);
let mut compute =
TrackedStream::new(compute, true, |tag| conn_tracker.backend_message_tag(tag));
// Starting from here we only proxy the client's traffic.
debug!("performing the proxy pass...");
@@ -68,6 +77,7 @@ pub(crate) struct ProxyPassthrough<S> {
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
pub(crate) conntracking: Arc<ConnectionTracking>,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
@@ -83,6 +93,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
self.compute.stream,
self.aux,
self.private_link_id,
&self.conntracking,
)
.await;
if let Err(err) = self

View File

@@ -50,6 +50,7 @@ use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::protocol2::{ChainRW, ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
@@ -124,6 +125,9 @@ pub async fn task_main(
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
let conntracking = Arc::new(ConnectionTracking::default());
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -153,6 +157,8 @@ pub async fn task_main(
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellations = cancellations.clone();
let conntracking = Arc::clone(&conntracking);
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
@@ -185,6 +191,7 @@ pub async fn task_main(
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conntracking,
conn,
conn_info,
session_id,
@@ -309,6 +316,7 @@ async fn connection_handler(
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conntracking: Arc<ConnectionTracking>,
conn: AsyncRW,
conn_info: ConnectionInfo,
session_id: uuid::Uuid,
@@ -347,6 +355,7 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let cancellations = cancellations.clone();
let conntracking = Arc::clone(&conntracking);
let handler = connections.spawn(
request_handler(
req,
@@ -359,6 +368,7 @@ async fn connection_handler(
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
conntracking,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -407,6 +417,7 @@ async fn request_handler(
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
conntracking: Arc<ConnectionTracking>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -452,6 +463,7 @@ async fn request_handler(
endpoint_rate_limiter,
host,
cancellations,
conntracking,
)
.await
{

View File

@@ -17,6 +17,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::conntrack::ConnectionTracking;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::rate_limiter::EndpointRateLimiter;
@@ -133,6 +134,7 @@ pub(crate) async fn serve_websocket(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
conntracking: Arc<ConnectionTracking>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
@@ -152,6 +154,7 @@ pub(crate) async fn serve_websocket(
endpoint_rate_limiter,
conn_gauge,
cancellations,
conntracking,
))
.await;