dont box handle_client anymore and move spawning passthrough into handle_client so we don't need to move a heavy object in return position anymore

This commit is contained in:
Conrad Ludgate
2025-05-29 18:20:29 +01:00
parent 11bb84c38d
commit cf07c5b5f9
2 changed files with 53 additions and 44 deletions

View File

@@ -138,11 +138,13 @@ pub async fn task_main(
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span();
let mut ctx = Some(ctx);
let res = handle_client(
config,
auth_backend,
&ctx,
&mut ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
@@ -150,22 +152,18 @@ pub async fn task_main(
conn_gauge,
tracker,
)
.instrument(ctx.span())
.boxed()
.instrument(span)
.await;
match res {
Err(e) => {
match (ctx, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
}
}
});
}
@@ -241,46 +239,50 @@ impl ReportableError for ClientRequestError {
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
ctx_slot: &mut Option<RequestContext>,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
tracker: TaskTrackerToken,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
) -> Result<(), ClientRequestError> {
let cplane = match auth_backend {
auth::Backend::ControlPlane(cplane, ()) => &**cplane,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let request_gauge = metrics.connection_requests.guard(protocol);
let handshake_result: Result<_, ClientRequestError> = async {
let tls = config.tls_config.load();
let tls = tls.as_deref();
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(
ctx,
stream,
tracker,
mode.handshake_tls(tls),
record_handshake_error,
);
let data = {
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
tokio::time::timeout(
config.handshake_timeout,
handshake(
ctx,
stream,
tracker,
mode.handshake_tls(tls),
record_handshake_error,
),
)
.await??
};
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
match data {
HandshakeData::Startup(mut stream, params) => {
ctx.set_db_options(params.clone());
@@ -298,7 +300,9 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
Ok(Some((stream, params, session, user_info)))
}
HandshakeData::Cancel(cancel_key_data, tracker) => {
let ctx = ctx.clone();
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
@@ -328,8 +332,9 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await;
let Some((mut stream, params, session, user_info)) = handshake_result? else {
return Ok(None);
return Ok(());
};
let ctx = ctx_slot.as_ref().expect("context must be set");
let auth_result: Result<_, ClientRequestError> = async {
let user = user_info.user.clone();
@@ -404,13 +409,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let (node, stream, tracker) = connect_result?;
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
let p = ProxyPassthrough {
client: stream,
private_link_id,
compute: node,
@@ -419,7 +427,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
_req: request_gauge,
_conn: conn_gauge,
_tracker: tracker,
}))
};
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
Ok(())
}
/// Finish client connection initialization: confirm auth success, send params, etc.

View File

@@ -141,31 +141,29 @@ pub(crate) async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
let res = Box::pin(handle_client(
let mut ctx_slot = Some(ctx);
let res = handle_client(
config,
auth_backend,
&ctx,
&mut ctx_slot,
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
tracker,
))
)
.await;
match res {
Err(e) => {
match (ctx_slot, res) {
(None, _) => {}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
(Some(ctx), Ok(())) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute));
}
}
}