From f37a558280a3d44d706cdc61c0d39ddcd9e588fc Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 9 Jun 2025 19:08:44 -0700 Subject: [PATCH] move the cancel-on-shutdown handling to the cancel session maintenance task --- proxy/src/cancellation.rs | 39 ++++++++++++++++++----------- proxy/src/compute/mod.rs | 7 +++--- proxy/src/console_redirect_proxy.rs | 19 +++++++++----- proxy/src/pglb/passthrough.rs | 38 ++++++++-------------------- proxy/src/proxy/mod.rs | 19 +++++++++----- proxy/src/serverless/websocket.rs | 2 +- 6 files changed, 66 insertions(+), 58 deletions(-) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index a4243cfc29..17d33c9ccd 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -3,6 +3,7 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::{Arc, OnceLock}; use anyhow::anyhow; +use futures::FutureExt; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::RawCancelToken; use postgres_client::tls::MakeTlsConnect; @@ -392,29 +393,39 @@ impl Session { /// but stop when the channel is dropped. pub(crate) async fn maintain_cancel_key( self, + session_id: uuid::Uuid, cancel: tokio::sync::oneshot::Receiver, - cancel_closure: &CancelClosure, - ) -> Result<(), CancelError> { + cancel_closure: CancelClosure, + compute_config: &ComputeConfig, + ) { tokio::select! { - res = self.maintain_redis_cancel_key(cancel_closure) => match res? {}, - _ = cancel => Ok(()), + _ = self.maintain_redis_cancel_key(&cancel_closure) => {} + _ = cancel => {} + }; + + if let Err(err) = cancel_closure + .try_cancel_query(compute_config) + .boxed() + .await + { + tracing::warn!( + ?session_id, + ?err, + "could not cancel the query in the database" + ); } } - /// Ensure the cancel key is continously refreshed. - async fn maintain_redis_cancel_key( - &self, - cancel_closure: &CancelClosure, - ) -> Result { + // Ensure the cancel key is continously refreshed. + async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! { let Some(tx) = self.cancellation_handler.tx.get() else { tracing::warn!("cancellation handler is not available"); - return Err(CancelError::InternalError); + // don't exit, as we only want to exit if cancelled externally. + std::future::pending().await }; - let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| { - tracing::warn!("failed to serialize cancel closure: {e}"); - CancelError::InternalError - })?; + let closure_json = serde_json::to_string(&cancel_closure) + .expect("serialising to json string should not fail"); loop { let guard = Metrics::get() diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 727ba3fd60..5dd264b35e 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -265,7 +265,8 @@ impl ConnectInfo { } } -type RustlsStream = >::Stream; +pub type RustlsStream = >::Stream; +pub type MaybeRustlsStream = MaybeTlsStream; pub(crate) struct PostgresConnection { /// Socket connected to a compute node. @@ -279,7 +280,7 @@ pub(crate) struct PostgresConnection { /// Notices received from compute after authenticating pub(crate) delayed_notice: Vec, - _guage: NumDbConnectionsGuard<'static>, + pub(crate) guage: NumDbConnectionsGuard<'static>, } impl ConnectInfo { @@ -342,7 +343,7 @@ impl ConnectInfo { delayed_notice, cancel_closure, aux, - _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()), + guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()), }; Ok(connection) diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 49e0c673a9..96ce090e78 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -120,7 +120,7 @@ pub async fn task_main( Ok(Some(p)) => { ctx.set_success(); let _disconnect = ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { + match p.proxy_pass().await { Ok(()) => {} Err(ErrorSource::Client(e)) => { error!( @@ -237,18 +237,25 @@ pub(crate) async fn handle_client( prepare_client_connection(&node, *session.key(), &mut stream); let stream = stream.flush_and_into_inner().await?; - let cancel_closure = node.cancel_closure.clone(); let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); - tokio::spawn(async move { session.maintain_cancel_key(cancel, &cancel_closure).await }); + tokio::spawn(session.maintain_cancel_key( + ctx.session_id(), + cancel, + node.cancel_closure, + &config.connect_to_compute, + )); Ok(Some(ProxyPassthrough { client: stream, - aux: node.aux.clone(), + compute: node.stream, + + aux: node.aux, private_link_id: None, - compute: node, - session_id: ctx.session_id(), + _cancel_on_shutdown: cancel_on_shutdown, + _req: request_gauge, _conn: conn_gauge, + _db_conn: node.guage, })) } diff --git a/proxy/src/pglb/passthrough.rs b/proxy/src/pglb/passthrough.rs index fa3df288be..d4c029f6d9 100644 --- a/proxy/src/pglb/passthrough.rs +++ b/proxy/src/pglb/passthrough.rs @@ -1,16 +1,17 @@ use std::convert::Infallible; -use futures::FutureExt; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; use utils::measured_stream::MeasuredStream; use super::copy_bidirectional::ErrorSource; -use crate::compute::PostgresConnection; -use crate::config::ComputeConfig; +use crate::compute::MaybeRustlsStream; use crate::control_plane::messages::MetricsAuxInfo; -use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; +use crate::metrics::{ + Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard, + NumDbConnectionsGuard, +}; use crate::stream::Stream; use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; @@ -65,39 +66,20 @@ pub(crate) async fn proxy_pass( pub(crate) struct ProxyPassthrough { pub(crate) client: Stream, - pub(crate) compute: PostgresConnection, + pub(crate) compute: MaybeRustlsStream, + pub(crate) aux: MetricsAuxInfo, - pub(crate) session_id: uuid::Uuid, pub(crate) private_link_id: Option, pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender, pub(crate) _req: NumConnectionRequestsGuard<'static>, pub(crate) _conn: NumClientConnectionsGuard<'static>, + pub(crate) _db_conn: NumDbConnectionsGuard<'static>, } impl ProxyPassthrough { - pub(crate) async fn proxy_pass( - self, - compute_config: &ComputeConfig, - ) -> Result<(), ErrorSource> { - let res = proxy_pass( - self.client, - self.compute.stream, - self.aux, - self.private_link_id, - ) - .await; - if let Err(err) = self - .compute - .cancel_closure - .try_cancel_query(compute_config) - .boxed() - .await - { - tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); - } - - res + pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { + proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await } } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 980b82df36..988a66cb76 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -155,7 +155,7 @@ pub async fn task_main( Ok(Some(p)) => { ctx.set_success(); let _disconnect = ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { + match p.proxy_pass().await { Ok(()) => {} Err(ErrorSource::Client(e)) => { warn!( @@ -377,9 +377,13 @@ pub(crate) async fn handle_client( prepare_client_connection(&node, *session.key(), &mut stream); let stream = stream.flush_and_into_inner().await?; - let cancel_closure = node.cancel_closure.clone(); let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); - tokio::spawn(async move { session.maintain_cancel_key(cancel, &cancel_closure).await }); + tokio::spawn(session.maintain_cancel_key( + ctx.session_id(), + cancel, + node.cancel_closure, + &config.connect_to_compute, + )); let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), @@ -389,13 +393,16 @@ pub(crate) async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, - aux: node.aux.clone(), + compute: node.stream, + + aux: node.aux, private_link_id, - compute: node, - session_id: ctx.session_id(), + _cancel_on_shutdown: cancel_on_shutdown, + _req: request_gauge, _conn: conn_gauge, + _db_conn: node.guage, })) } diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 8648a94869..0d374e6df2 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -167,7 +167,7 @@ pub(crate) async fn serve_websocket( Ok(Some(p)) => { ctx.set_success(); ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { + match p.proxy_pass().await { Ok(()) => Ok(()), Err(ErrorSource::Client(err)) => Err(err).context("client"), Err(ErrorSource::Compute(err)) => Err(err).context("compute"),