box the handshake task

Conrad Ludgate
2025-05-29 15:39:33 +01:00
parent 3124729f53
commit 2d3ea77953
2 changed files with 47 additions and 44 deletions
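
What the commit does, in short: the handshake phase of handle_client is moved into an async block that is .boxed() before being awaited, so the large handshake state machine is heap-allocated and type-erased instead of being inlined into handle_client's own future. The added Send bound on the stream type S is presumably what lets the block use futures' FutureExt::boxed (which requires a Send future) rather than boxed_local. Below is a minimal standalone sketch of the pattern, assuming the futures crate; expensive_handshake and handle_client_sketch are made-up names, not the proxy's code.

use futures::FutureExt; // `.boxed()` -> Pin<Box<dyn Future + Send>>
use futures::executor::block_on;

async fn expensive_handshake() -> Result<u32, std::io::Error> {
    // stand-in for the real TLS/startup handshake, whose future is large
    Ok(42)
}

async fn handle_client_sketch() -> Result<u32, std::io::Error> {
    // The async block is heap-allocated and type-erased by `.boxed()`, so the
    // enclosing future only stores a pointer across this await point instead
    // of embedding the whole handshake state machine.
    let params = async { expensive_handshake().await }.boxed().await?;
    Ok(params)
}

fn main() {
    assert_eq!(block_on(handle_client_sketch()).unwrap(), 42);
}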

@@ -323,7 +323,7 @@ impl CancellationHandler {
         }
     }
 
-    pub(crate) fn get_key(self: &Arc<Self>) -> Session {
+    pub(crate) fn get_key(self: Arc<Self>) -> Session {
        // we intentionally generate a random "backend pid" and "secret key" here.
        // we use the corresponding u64 as an identifier for the
        // actual endpoint+pid+secret for postgres/pgbouncer.
@@ -340,7 +340,7 @@ impl CancellationHandler {
         Session {
             key,
             redis_key,
-            cancellation_handler: Arc::clone(self),
+            cancellation_handler: self,
         }
     }
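
The hunks above switch get_key from a &Arc<Self> receiver to an Arc<Self> receiver, so the method moves its Arc into the Session it returns instead of bumping the reference count with Arc::clone(self); a caller that still needs its own handle clones at the call site. A stripped-down sketch of that receiver style follows (it reuses the CancellationHandler and Session names but none of the real fields):

use std::sync::Arc;

struct CancellationHandler;

struct Session {
    cancellation_handler: Arc<CancellationHandler>,
}

impl CancellationHandler {
    // An `Arc<Self>` receiver takes ownership of the Arc, so it can be moved
    // into the returned Session without an extra reference-count increment.
    fn get_key(self: Arc<Self>) -> Session {
        Session {
            cancellation_handler: self,
        }
    }
}

fn main() {
    let handler = Arc::new(CancellationHandler);
    // A caller that keeps using the handler clones explicitly:
    let _session = Arc::clone(&handler).get_key();
    // A caller that is done with the handler simply moves it in:
    let _last = handler.get_key();
}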

@@ -264,7 +264,7 @@ impl ReportableError for ClientRequestError {
 }
 
 #[allow(clippy::too_many_arguments)]
-pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
+pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
     config: &'static ProxyConfig,
     auth_backend: &'static auth::Backend<'static, ()>,
     ctx: &RequestContext,
@@ -287,50 +287,57 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
     let tls = config.tls_config.load();
     let tls = tls.as_deref();
 
-    let record_handshake_error = !ctx.has_private_peer_addr();
-    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
-    let do_handshake = handshake(
-        ctx,
-        stream,
-        tracker,
-        mode.handshake_tls(tls),
-        record_handshake_error,
-    );
-    let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
-        .await??
-    {
-        HandshakeData::Startup(stream, params) => (stream, params),
-        HandshakeData::Cancel(cancel_key_data, tracker) => {
-            // spawn a task to cancel the session, but don't wait for it
-            tokio::spawn({
-                let cancellation_handler_clone = Arc::clone(&cancellation_handler);
+    let handshake_result: Result<_, ClientRequestError> = async {
+        let record_handshake_error = !ctx.has_private_peer_addr();
+        let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
+        let do_handshake = handshake(
+            ctx,
+            stream,
+            tracker,
+            mode.handshake_tls(tls),
+            record_handshake_error,
+        );
+
+        match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
+            HandshakeData::Startup(stream, params) => {
+                let session = cancellation_handler.get_key();
+                Ok(Some((stream, params, session)))
+            },
+            HandshakeData::Cancel(cancel_key_data, tracker) => {
                 let ctx = ctx.clone();
-                let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
+                let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
                 cancel_span.follows_from(tracing::Span::current());
-                async move {
-                    // ensure the proxy doesn't shutdown until we complete this task.
-                    let _tracker = tracker;
-                    cancellation_handler_clone
-                        .cancel_session(
-                            cancel_key_data,
-                            ctx,
-                            config.authentication_config.ip_allowlist_check_enabled,
-                            config.authentication_config.is_vpc_acccess_proxy,
-                            auth_backend.get_api(),
-                        )
-                        .await
-                        .inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
-                        .ok();
-                }
-                .instrument(cancel_span)
-            });
-
-            return Ok(None);
-        }
-    };
-    drop(pause);
+
+                // spawn a task to cancel the session, but don't wait for it
+                tokio::spawn(
+                    async move {
+                        // ensure the proxy doesn't shutdown until we complete this task.
+                        let _tracker = tracker;
+                        cancellation_handler
+                            .cancel_session(
+                                cancel_key_data,
+                                ctx,
+                                config.authentication_config.ip_allowlist_check_enabled,
+                                config.authentication_config.is_vpc_acccess_proxy,
+                                auth_backend.get_api(),
+                            )
+                            .await
+                            .unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
+                    }
+                    .instrument(cancel_span),
+                );
+
+                Ok(None)
+            }
+        }
+    }
+    .boxed()
+    .await;
+
+    let Some((mut stream, params, session)) = handshake_result? else {
+        return Ok(None);
+    };
 
     ctx.set_db_options(params.clone());
@@ -397,11 +404,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
         .or_else(|e| stream.throw_error(e, Some(ctx)))
         .await?;
 
-    let cancellation_handler_clone = Arc::clone(&cancellation_handler);
-    let session = cancellation_handler_clone.get_key();
     session.write_cancel_key(node.cancel_closure.clone())?;
     prepare_client_connection(&node, *session.key(), &mut stream).await?;
 
     // Before proxy passing, forward to compute whatever data is left in the