move the cancel-on-shutdown handling to the cancel session maintenance task

Conrad Ludgate
2025-06-09 19:08:44 -07:00
parent 744011437a
commit f37a558280
6 changed files with 66 additions and 58 deletions

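The pattern this commit consolidates: the cancel-key maintenance task races its Redis refresh loop against a tokio::sync::oneshot::Receiver<Infallible>. Since Infallible has no values, that select! arm can only complete when the sender half is dropped, and the best-effort try_cancel_query cleanup now runs there instead of at the end of ProxyPassthrough::proxy_pass. A minimal sketch of the drop-to-signal shape, with hypothetical refresh_loop/cleanup stand-ins, assuming tokio with the rt, macros, sync and time features:

use std::convert::Infallible;
use std::time::Duration;

// Hypothetical stand-in for the Redis re-registration loop.
async fn refresh_loop() {
    loop {
        // re-register the cancel key, then sleep until the next refresh
        tokio::time::sleep(Duration::from_secs(10)).await;
    }
}

// Hypothetical stand-in for the try_cancel_query cleanup.
async fn cleanup() {
    println!("best-effort query cancellation");
}

async fn maintain(cancel: tokio::sync::oneshot::Receiver<Infallible>) {
    tokio::select! {
        // the refresh loop never completes on its own
        _ = refresh_loop() => {}
        // Receiver<Infallible> can never yield a value, so this arm
        // completes only when the sender half is dropped
        _ = cancel => {}
    }
    // cleanup runs in the maintenance task, whichever arm fired
    cleanup().await;
}

#[tokio::main]
async fn main() {
    let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel::<Infallible>();
    let task = tokio::spawn(maintain(cancel));
    drop(cancel_on_shutdown); // the session ends (or the proxy shuts down)
    task.await.unwrap();
}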
proxy/src/cancellation.rs

@@ -3,6 +3,7 @@ use std::net::{IpAddr, SocketAddr};
 use std::sync::{Arc, OnceLock};
 
 use anyhow::anyhow;
+use futures::FutureExt;
 use ipnet::{IpNet, Ipv4Net, Ipv6Net};
 use postgres_client::RawCancelToken;
 use postgres_client::tls::MakeTlsConnect;
@@ -392,29 +393,39 @@ impl Session {
     /// but stop when the channel is dropped.
     pub(crate) async fn maintain_cancel_key(
         self,
+        session_id: uuid::Uuid,
         cancel: tokio::sync::oneshot::Receiver<Infallible>,
-        cancel_closure: &CancelClosure,
-    ) -> Result<(), CancelError> {
+        cancel_closure: CancelClosure,
+        compute_config: &ComputeConfig,
+    ) {
         tokio::select! {
-            res = self.maintain_redis_cancel_key(cancel_closure) => match res? {},
-            _ = cancel => Ok(()),
+            _ = self.maintain_redis_cancel_key(&cancel_closure) => {}
+            _ = cancel => {}
         };
+
+        if let Err(err) = cancel_closure
+            .try_cancel_query(compute_config)
+            .boxed()
+            .await
+        {
+            tracing::warn!(
+                ?session_id,
+                ?err,
+                "could not cancel the query in the database"
+            );
+        }
     }
 
-    /// Ensure the cancel key is continously refreshed.
-    async fn maintain_redis_cancel_key(
-        &self,
-        cancel_closure: &CancelClosure,
-    ) -> Result<Infallible, CancelError> {
+    // Ensure the cancel key is continously refreshed.
+    async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! {
         let Some(tx) = self.cancellation_handler.tx.get() else {
             tracing::warn!("cancellation handler is not available");
-            return Err(CancelError::InternalError);
+            // don't exit, as we only want to exit if cancelled externally.
+            std::future::pending().await
         };
 
-        let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
-            tracing::warn!("failed to serialize cancel closure: {e}");
-            CancelError::InternalError
-        })?;
+        let closure_json = serde_json::to_string(&cancel_closure)
+            .expect("serialising to json string should not fail");
 
         loop {
             let guard = Metrics::get()

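A detail in the hunk above: maintain_redis_cancel_key now returns ! and parks on std::future::pending() when the cancellation handler is unavailable, instead of returning CancelError::InternalError. Because the caller races it inside tokio::select!, an early return would have torn the maintenance down; parking leaves the external cancel signal as the only exit. A sketch of that shape, with an invented lookup_handler standing in for self.cancellation_handler.tx.get():

use std::convert::Infallible;
use std::time::Duration;

// Invented stand-in for self.cancellation_handler.tx.get().
fn lookup_handler() -> Option<()> {
    None
}

async fn maintain_key() -> Infallible {
    let Some(_tx) = lookup_handler() else {
        eprintln!("cancellation handler is not available");
        // Returning an error would complete the caller's select! arm and
        // stop the session's maintenance early. Park forever instead; the
        // racing cancel arm is the only intended exit.
        return std::future::pending().await;
    };
    loop {
        // refresh the key, then sleep
        tokio::time::sleep(Duration::from_secs(10)).await;
    }
}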
proxy/src/compute.rs

@@ -265,7 +265,8 @@ impl ConnectInfo {
     }
 }
 
-type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
+pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
+pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
 
 pub(crate) struct PostgresConnection {
     /// Socket connected to a compute node.
@@ -279,7 +280,7 @@ pub(crate) struct PostgresConnection {
     /// Notices received from compute after authenticating
     pub(crate) delayed_notice: Vec<NoticeResponseBody>,
-    _guage: NumDbConnectionsGuard<'static>,
+    pub(crate) guage: NumDbConnectionsGuard<'static>,
 }
 
 impl ConnectInfo {
@@ -342,7 +343,7 @@ impl ConnectInfo {
             delayed_notice,
             cancel_closure,
             aux,
-            _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
+            guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
         };
 
         Ok(connection)

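Making RustlsStream pub and adding the MaybeRustlsStream alias is what lets ProxyPassthrough (below) hold the bare socket instead of a whole PostgresConnection: the associated-type projection is written once, and other modules can name the stream type directly. A reduced sketch of the pattern, with stand-in trait and types rather than the real postgres_client API:

// Stand-in for postgres_client::tls::MakeTlsConnect.
pub trait MakeTlsConnect<S> {
    type Stream;
}

// Stand-ins for the concrete config, socket and TLS stream types.
pub struct ComputeConfig;
pub struct TcpStream;
pub struct RustlsConnection(TcpStream);

pub enum MaybeTlsStream<Raw, Tls> {
    Raw(Raw),
    Tls(Tls),
}

impl MakeTlsConnect<TcpStream> for ComputeConfig {
    type Stream = RustlsConnection;
}

// The projection is spelled out once, publicly; downstream structs can
// then refer to MaybeRustlsStream by name.
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
pub type MaybeRustlsStream = MaybeTlsStream<TcpStream, RustlsStream>;

pub struct ProxyPassthrough {
    pub compute: MaybeRustlsStream,
}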
proxy/src/console_redirect_proxy.rs

@@ -120,7 +120,7 @@ pub async fn task_main(
             Ok(Some(p)) => {
                 ctx.set_success();
                 let _disconnect = ctx.log_connect();
-                match p.proxy_pass(&config.connect_to_compute).await {
+                match p.proxy_pass().await {
                     Ok(()) => {}
                     Err(ErrorSource::Client(e)) => {
                         error!(
@@ -237,18 +237,25 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
     prepare_client_connection(&node, *session.key(), &mut stream);
     let stream = stream.flush_and_into_inner().await?;
 
-    let cancel_closure = node.cancel_closure.clone();
     let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
-    tokio::spawn(async move { session.maintain_cancel_key(cancel, &cancel_closure).await });
+    tokio::spawn(session.maintain_cancel_key(
+        ctx.session_id(),
+        cancel,
+        node.cancel_closure,
+        &config.connect_to_compute,
+    ));
 
     Ok(Some(ProxyPassthrough {
         client: stream,
-        aux: node.aux.clone(),
-        compute: node,
+        aux: node.aux,
         private_link_id: None,
+        compute: node.stream,
         session_id: ctx.session_id(),
-        _cancel_on_shutdown: cancel_on_shutdown,
         _req: request_gauge,
         _conn: conn_gauge,
+        _db_conn: node.guage,
     }))
 }

proxy/src/proxy/passthrough.rs

@@ -1,16 +1,17 @@
 use std::convert::Infallible;
 
-use futures::FutureExt;
 use smol_str::SmolStr;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tracing::debug;
 use utils::measured_stream::MeasuredStream;
 
 use super::copy_bidirectional::ErrorSource;
-use crate::compute::PostgresConnection;
-use crate::config::ComputeConfig;
+use crate::compute::MaybeRustlsStream;
 use crate::control_plane::messages::MetricsAuxInfo;
-use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
+use crate::metrics::{
+    Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
+    NumDbConnectionsGuard,
+};
 use crate::stream::Stream;
 use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
@@ -65,39 +66,20 @@ pub(crate) async fn proxy_pass(
 pub(crate) struct ProxyPassthrough<S> {
     pub(crate) client: Stream<S>,
-    pub(crate) compute: PostgresConnection,
+    pub(crate) compute: MaybeRustlsStream,
     pub(crate) aux: MetricsAuxInfo,
     pub(crate) session_id: uuid::Uuid,
     pub(crate) private_link_id: Option<SmolStr>,
-    pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender<Infallible>,
     pub(crate) _req: NumConnectionRequestsGuard<'static>,
     pub(crate) _conn: NumClientConnectionsGuard<'static>,
+    pub(crate) _db_conn: NumDbConnectionsGuard<'static>,
 }
 
 impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
-    pub(crate) async fn proxy_pass(
-        self,
-        compute_config: &ComputeConfig,
-    ) -> Result<(), ErrorSource> {
-        let res = proxy_pass(
-            self.client,
-            self.compute.stream,
-            self.aux,
-            self.private_link_id,
-        )
-        .await;
-
-        if let Err(err) = self
-            .compute
-            .cancel_closure
-            .try_cancel_query(compute_config)
-            .boxed()
-            .await
-        {
-            tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
-        }
-
-        res
+    pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
+        proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await
     }
 }

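With the cancellation cleanup moved out, proxy_pass above shrinks to a plain forwarding call, and the db-connection gauge travels as the _db_conn field instead of hiding inside PostgresConnection. The underscore field is never read; it exists so the guard's Drop runs exactly when the passthrough ends. A self-contained sketch of that RAII-guard-as-field pattern, with an invented gauge rather than the proxy's real metrics types:

use std::sync::atomic::{AtomicU64, Ordering};

// Invented gauge standing in for Metrics::get().proxy.db_connections.
static DB_CONNECTIONS: AtomicU64 = AtomicU64::new(0);

struct NumDbConnectionsGuard;

impl NumDbConnectionsGuard {
    fn acquire() -> Self {
        DB_CONNECTIONS.fetch_add(1, Ordering::Relaxed);
        NumDbConnectionsGuard
    }
}

impl Drop for NumDbConnectionsGuard {
    fn drop(&mut self) {
        DB_CONNECTIONS.fetch_sub(1, Ordering::Relaxed);
    }
}

struct Passthrough {
    // never read: held so the gauge stays incremented for as long as
    // the passthrough session lives
    _db_conn: NumDbConnectionsGuard,
}

fn main() {
    let p = Passthrough {
        _db_conn: NumDbConnectionsGuard::acquire(),
    };
    assert_eq!(DB_CONNECTIONS.load(Ordering::Relaxed), 1);
    drop(p); // the guard drops with the struct
    assert_eq!(DB_CONNECTIONS.load(Ordering::Relaxed), 0);
}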
proxy/src/proxy/mod.rs

@@ -155,7 +155,7 @@ pub async fn task_main(
             Ok(Some(p)) => {
                 ctx.set_success();
                 let _disconnect = ctx.log_connect();
-                match p.proxy_pass(&config.connect_to_compute).await {
+                match p.proxy_pass().await {
                     Ok(()) => {}
                     Err(ErrorSource::Client(e)) => {
                         warn!(
@@ -377,9 +377,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
     prepare_client_connection(&node, *session.key(), &mut stream);
     let stream = stream.flush_and_into_inner().await?;
 
-    let cancel_closure = node.cancel_closure.clone();
     let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
-    tokio::spawn(async move { session.maintain_cancel_key(cancel, &cancel_closure).await });
+    tokio::spawn(session.maintain_cancel_key(
+        ctx.session_id(),
+        cancel,
+        node.cancel_closure,
+        &config.connect_to_compute,
+    ));
 
     let private_link_id = match ctx.extra() {
         Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
@@ -389,13 +393,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
     Ok(Some(ProxyPassthrough {
         client: stream,
-        aux: node.aux.clone(),
-        compute: node,
+        aux: node.aux,
         private_link_id,
+        compute: node.stream,
         session_id: ctx.session_id(),
-        _cancel_on_shutdown: cancel_on_shutdown,
         _req: request_gauge,
         _conn: conn_gauge,
+        _db_conn: node.guage,
     }))
 }

proxy/src/serverless/websocket.rs

@@ -167,7 +167,7 @@ pub(crate) async fn serve_websocket(
         Ok(Some(p)) => {
             ctx.set_success();
             ctx.log_connect();
-            match p.proxy_pass(&config.connect_to_compute).await {
+            match p.proxy_pass().await {
                 Ok(()) => Ok(()),
                 Err(ErrorSource::Client(err)) => Err(err).context("client"),
                 Err(ErrorSource::Compute(err)) => Err(err).context("compute"),