move more work inside handshake

This commit is contained in:
Conrad Ludgate
2025-05-29 15:50:10 +01:00
parent 8b1ffa1718
commit 034bdb1552

View File

@@ -289,10 +289,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let handshake_result: Result<_, ClientRequestError> = async {
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(
@@ -304,34 +304,44 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
);
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => {
HandshakeData::Startup(mut stream, params) => {
ctx.set_db_options(params.clone());
let host = mode.hostname(stream.get_ref());
let cn = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, host, cn);
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let session = cancellation_handler.get_key();
Ok(Some((stream, params, session)))
},
Ok(Some((stream, params, session, user_info)))
}
HandshakeData::Cancel(cancel_key_data, tracker) => {
let ctx = ctx.clone();
let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
// spawn a task to cancel the session, but don't wait for it
tokio::spawn(
async move {
// ensure the proxy doesn't shutdown until we complete this task.
let _tracker = tracker;
tokio::spawn(async move {
// ensure the proxy doesn't shutdown until we complete this task.
let _tracker = tracker;
cancellation_handler
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
}
.instrument(cancel_span),
);
cancellation_handler
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.instrument(cancel_span)
.await
.unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
});
Ok(None)
}
@@ -340,23 +350,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.boxed()
.await;
let Some((mut stream, params, session)) = handshake_result? else {
let Some((mut stream, params, session, user_info)) = handshake_result? else {
return Ok(None);
};
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names);
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let user = user_info.user.clone();
let compute_creds = match cplane
.authenticate(