From 2d3ea77953e9dfe49dbc73b5a6b20fbf39971f03 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 29 May 2025 15:39:33 +0100 Subject: [PATCH] box the handshake task --- proxy/src/cancellation.rs | 4 +- proxy/src/proxy/mod.rs | 87 ++++++++++++++++++++------------------- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a6e7bf85a0..51da277f6b 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -323,7 +323,7 @@ impl CancellationHandler { } } - pub(crate) fn get_key(self: &Arc) -> Session { + pub(crate) fn get_key(self: Arc) -> Session { // we intentionally generate a random "backend pid" and "secret key" here. // we use the corresponding u64 as an identifier for the // actual endpoint+pid+secret for postgres/pgbouncer. @@ -340,7 +340,7 @@ impl CancellationHandler { Session { key, redis_key, - cancellation_handler: Arc::clone(self), + cancellation_handler: self, } } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index fec1243b7d..d177a92802 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -264,7 +264,7 @@ impl ReportableError for ClientRequestError { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, @@ -287,50 +287,57 @@ pub(crate) async fn handle_client( let tls = config.tls_config.load(); let tls = tls.as_deref(); - let record_handshake_error = !ctx.has_private_peer_addr(); - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake( - ctx, - stream, - tracker, - mode.handshake_tls(tls), - record_handshake_error, - ); + let handshake_result: Result<_, ClientRequestError> = async { + let record_handshake_error = !ctx.has_private_peer_addr(); + let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + let do_handshake = handshake( + ctx, + stream, + tracker, + mode.handshake_tls(tls), + record_handshake_error, + ); - let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) - .await?? - { - HandshakeData::Startup(stream, params) => (stream, params), - HandshakeData::Cancel(cancel_key_data, tracker) => { - // spawn a task to cancel the session, but don't wait for it - tokio::spawn({ - let cancellation_handler_clone = Arc::clone(&cancellation_handler); + match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { + HandshakeData::Startup(stream, params) => { + let session = cancellation_handler.get_key(); + Ok(Some((stream, params, session))) + }, + HandshakeData::Cancel(cancel_key_data, tracker) => { let ctx = ctx.clone(); - let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); + let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id()); cancel_span.follows_from(tracing::Span::current()); - async move { - // ensure the proxy doesn't shutdown until we complete this task. - let _tracker = tracker; - cancellation_handler_clone - .cancel_session( - cancel_key_data, - ctx, - config.authentication_config.ip_allowlist_check_enabled, - config.authentication_config.is_vpc_acccess_proxy, - auth_backend.get_api(), - ) - .await - .inspect_err(|e| debug!(error = ?e, "cancel_session failed")) - .ok(); - } - .instrument(cancel_span) - }); + // spawn a task to cancel the session, but don't wait for it + tokio::spawn( + async move { + // ensure the proxy doesn't shutdown until we complete this task. + let _tracker = tracker; - return Ok(None); + cancellation_handler + .cancel_session( + cancel_key_data, + ctx, + config.authentication_config.ip_allowlist_check_enabled, + config.authentication_config.is_vpc_acccess_proxy, + auth_backend.get_api(), + ) + .await + .unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed")); + } + .instrument(cancel_span), + ); + + Ok(None) + } } + } + .boxed() + .await; + + let Some((mut stream, params, session)) = handshake_result? else { + return Ok(None); }; - drop(pause); ctx.set_db_options(params.clone()); @@ -397,11 +404,7 @@ pub(crate) async fn handle_client( .or_else(|e| stream.throw_error(e, Some(ctx))) .await?; - let cancellation_handler_clone = Arc::clone(&cancellation_handler); - let session = cancellation_handler_clone.get_key(); - session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the