optimise passthrough calling convention to further reduce memory usage

This commit is contained in:
Conrad Ludgate
2025-05-29 18:35:24 +01:00
parent cf07c5b5f9
commit fd43058bd7
3 changed files with 103 additions and 114 deletions

View File

@@ -1,6 +1,6 @@
use std::sync::Arc;
use futures::{FutureExt, TryFutureExt};
use futures::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
@@ -15,7 +15,7 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::passthrough::passthrough;
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
pub async fn task_main(
@@ -101,30 +101,28 @@ pub async fn task_main(
&config.region,
);
let span = ctx.span();
let mut slot = Some(ctx);
let res = handle_client(
config,
backend,
&ctx,
&mut slot,
cancellation_handler,
socket,
conn_gauge,
tracker,
)
.instrument(ctx.span())
.boxed()
.instrument(span)
.await;
match res {
Err(e) => {
match (slot, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
}
});
@@ -140,40 +138,39 @@ pub async fn task_main(
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx: &RequestContext,
ctx_slot: &mut Option<RequestContext>,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
tracker: TaskTrackerToken,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
) -> Result<(), ClientRequestError> {
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let request_gauge = metrics.connection_requests.guard(protocol);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
let data = {
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
tokio::time::timeout(config.handshake_timeout, do_handshake).await??
};
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
let (mut stream, params) = match data {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data, tracker) => {
// spawn a task to cancel the session, but don't wait for it
tokio::spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let ctx = ctx_slot.take().expect("context must be set");
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
@@ -193,11 +190,11 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
.instrument(cancel_span)
});
return Ok(None);
return Ok(());
}
};
drop(pause);
let ctx = ctx_slot.as_ref().expect("context must be set");
ctx.set_db_options(params.clone());
let (node_info, user_info, _ip_allowlist) = match backend
@@ -239,14 +236,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let (stream, read_buf, tracker) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
Ok(Some(ProxyPassthrough {
client: stream,
private_link_id: None,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
_tracker: tracker,
}))
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
ctx,
&config.connect_to_compute,
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
Ok(())
}

View File

@@ -13,10 +13,11 @@ pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use passthrough::passthrough;
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use smol_str::{SmolStr, format_smolstr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
@@ -24,13 +25,12 @@ use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info, warn};
use self::connect_compute::{TcpMechanism, connect_to_compute};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -412,23 +412,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
let p = ProxyPassthrough {
client: stream,
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
_tracker: tracker,
};
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
tokio::spawn(passthrough(
ctx,
&config.connect_to_compute,
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
Ok(())
}

View File

@@ -1,4 +1,4 @@
use smol_str::SmolStr;
use smol_str::{SmolStr, ToSmolStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
@@ -11,6 +11,7 @@ use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::protocol2::ConnectionInfoExtra;
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
@@ -62,60 +63,53 @@ pub(crate) async fn proxy_pass(
Ok(())
}
pub(crate) struct ProxyPassthrough<S> {
pub(crate) client: Stream<S>,
pub(crate) compute: PostgresConnection,
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
#[allow(clippy::too_many_arguments)]
pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
ctx: RequestContext,
compute_config: &'static ComputeConfig,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
/// ensures proxy stays online while this is set.
pub(crate) _tracker: TaskTrackerToken,
}
client: Stream<S>,
compute: PostgresConnection,
cancel: cancellation::Session,
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> ProxyPassthrough<S> {
pub(crate) async fn proxy_pass(
self,
ctx: RequestContext,
compute_config: &'static ComputeConfig,
) {
let _disconnect = ctx.log_connect();
let res = proxy_pass(
self.client,
self.compute.stream,
self.compute.aux,
self.private_link_id,
)
.await;
_req: NumConnectionRequestsGuard<'static>,
_conn: NumClientConnectionsGuard<'static>,
_tracker: TaskTrackerToken,
) {
let session_id = ctx.session_id();
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
match res {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
tracing::warn!(
session_id = ?self.session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
tracing::error!(
session_id = ?self.session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
let _disconnect = ctx.log_connect();
let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
match res {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
tracing::warn!(
session_id = ?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
Err(ErrorSource::Compute(e)) => {
tracing::error!(
session_id = ?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
// we don't need a result. If the queue is full, we just log the error
drop(self.cancel.remove_cancel_key());
}
if let Err(err) = compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
}
// we don't need a result. If the queue is full, we just log the error
drop(cancel.remove_cancel_key());
}