From cf07c5b5f940ef388aa546d4416292cd9c235fd0 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 29 May 2025 18:20:29 +0100 Subject: [PATCH] dont box handle_client anymore and move spawning passthrough into handle_client so we don't need to move a heavy object in return position anymore --- proxy/src/proxy/mod.rs | 79 ++++++++++++++++++------------- proxy/src/serverless/websocket.rs | 18 ++++--- 2 files changed, 53 insertions(+), 44 deletions(-) diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 2843823cd9..50d443f53d 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -138,11 +138,13 @@ pub async fn task_main( crate::metrics::Protocol::Tcp, &config.region, ); + let span = ctx.span(); + let mut ctx = Some(ctx); let res = handle_client( config, auth_backend, - &ctx, + &mut ctx, cancellation_handler, socket, ClientMode::Tcp, @@ -150,22 +152,18 @@ pub async fn task_main( conn_gauge, tracker, ) - .instrument(ctx.span()) - .boxed() + .instrument(span) .await; - match res { - Err(e) => { + match (ctx, res) { + (None, _) => {} + (Some(ctx), Ok(())) => { + ctx.success(); + } + (Some(ctx), Err(e)) => { ctx.set_error_kind(e.get_error_kind()); warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); } - Ok(None) => { - ctx.set_success(); - } - Ok(Some(p)) => { - ctx.set_success(); - tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute)); - } } }); } @@ -241,46 +239,50 @@ impl ReportableError for ClientRequestError { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, - ctx: &RequestContext, + ctx_slot: &mut Option, cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, tracker: TaskTrackerToken, -) -> Result>, ClientRequestError> { +) -> Result<(), ClientRequestError> { let cplane = match auth_backend { auth::Backend::ControlPlane(cplane, ()) => &**cplane, auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"), }; - debug!( - protocol = %ctx.protocol(), - "handling interactive connection from client" - ); + let protocol = ctx_slot.as_ref().expect("context must be set").protocol(); + debug!(%protocol, "handling interactive connection from client"); let metrics = &Metrics::get().proxy; - let proto = ctx.protocol(); - let request_gauge = metrics.connection_requests.guard(proto); + let request_gauge = metrics.connection_requests.guard(protocol); let handshake_result: Result<_, ClientRequestError> = async { let tls = config.tls_config.load(); let tls = tls.as_deref(); + let ctx = ctx_slot.as_ref().expect("context must be set"); let record_handshake_error = !ctx.has_private_peer_addr(); - let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake( - ctx, - stream, - tracker, - mode.handshake_tls(tls), - record_handshake_error, - ); + let data = { + let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + tokio::time::timeout( + config.handshake_timeout, + handshake( + ctx, + stream, + tracker, + mode.handshake_tls(tls), + record_handshake_error, + ), + ) + .await?? + }; - match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { + match data { HandshakeData::Startup(mut stream, params) => { ctx.set_db_options(params.clone()); @@ -298,7 +300,9 @@ pub(crate) async fn handle_client( Ok(Some((stream, params, session, user_info))) } HandshakeData::Cancel(cancel_key_data, tracker) => { - let ctx = ctx.clone(); + let ctx = ctx_slot.take().expect("context must be set"); + ctx.set_success(); + let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id()); cancel_span.follows_from(tracing::Span::current()); @@ -328,8 +332,9 @@ pub(crate) async fn handle_client( .await; let Some((mut stream, params, session, user_info)) = handshake_result? else { - return Ok(None); + return Ok(()); }; + let ctx = ctx_slot.as_ref().expect("context must be set"); let auth_result: Result<_, ClientRequestError> = async { let user = user_info.user.clone(); @@ -404,13 +409,16 @@ pub(crate) async fn handle_client( let (node, stream, tracker) = connect_result?; + let ctx = ctx_slot.take().expect("context must be set"); + ctx.set_success(); + let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), None => None, }; - Ok(Some(ProxyPassthrough { + let p = ProxyPassthrough { client: stream, private_link_id, compute: node, @@ -419,7 +427,10 @@ pub(crate) async fn handle_client( _req: request_gauge, _conn: conn_gauge, _tracker: tracker, - })) + }; + tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute)); + + Ok(()) } /// Finish client connection initialization: confirm auth success, send params, etc. diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index eb1c90e7f9..342f8e3ab6 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -141,31 +141,29 @@ pub(crate) async fn serve_websocket( .client_connections .guard(crate::metrics::Protocol::Ws); - let res = Box::pin(handle_client( + let mut ctx_slot = Some(ctx); + let res = handle_client( config, auth_backend, - &ctx, + &mut ctx_slot, cancellation_handler, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, endpoint_rate_limiter, conn_gauge, tracker, - )) + ) .await; - match res { - Err(e) => { + match (ctx_slot, res) { + (None, _) => {} + (Some(ctx), Err(e)) => { ctx.set_error_kind(e.get_error_kind()); tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); } - Ok(None) => { + (Some(ctx), Ok(())) => { ctx.set_success(); } - Ok(Some(p)) => { - ctx.set_success(); - tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute)); - } } }