mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 08:52:56 +00:00
move the cancel-on-shutdown handling to the cancel session maintenance task
This commit is contained in:
@@ -3,6 +3,7 @@ use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use futures::FutureExt;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::RawCancelToken;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
@@ -392,29 +393,39 @@ impl Session {
|
||||
/// but stop when the channel is dropped.
|
||||
pub(crate) async fn maintain_cancel_key(
|
||||
self,
|
||||
session_id: uuid::Uuid,
|
||||
cancel: tokio::sync::oneshot::Receiver<Infallible>,
|
||||
cancel_closure: &CancelClosure,
|
||||
) -> Result<(), CancelError> {
|
||||
cancel_closure: CancelClosure,
|
||||
compute_config: &ComputeConfig,
|
||||
) {
|
||||
tokio::select! {
|
||||
res = self.maintain_redis_cancel_key(cancel_closure) => match res? {},
|
||||
_ = cancel => Ok(()),
|
||||
_ = self.maintain_redis_cancel_key(&cancel_closure) => {}
|
||||
_ = cancel => {}
|
||||
};
|
||||
|
||||
if let Err(err) = cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
?session_id,
|
||||
?err,
|
||||
"could not cancel the query in the database"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure the cancel key is continously refreshed.
|
||||
async fn maintain_redis_cancel_key(
|
||||
&self,
|
||||
cancel_closure: &CancelClosure,
|
||||
) -> Result<Infallible, CancelError> {
|
||||
// Ensure the cancel key is continously refreshed.
|
||||
async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! {
|
||||
let Some(tx) = self.cancellation_handler.tx.get() else {
|
||||
tracing::warn!("cancellation handler is not available");
|
||||
return Err(CancelError::InternalError);
|
||||
// don't exit, as we only want to exit if cancelled externally.
|
||||
std::future::pending().await
|
||||
};
|
||||
|
||||
let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| {
|
||||
tracing::warn!("failed to serialize cancel closure: {e}");
|
||||
CancelError::InternalError
|
||||
})?;
|
||||
let closure_json = serde_json::to_string(&cancel_closure)
|
||||
.expect("serialising to json string should not fail");
|
||||
|
||||
loop {
|
||||
let guard = Metrics::get()
|
||||
|
||||
@@ -265,7 +265,8 @@ impl ConnectInfo {
|
||||
}
|
||||
}
|
||||
|
||||
type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
|
||||
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
|
||||
|
||||
pub(crate) struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
@@ -279,7 +280,7 @@ pub(crate) struct PostgresConnection {
|
||||
/// Notices received from compute after authenticating
|
||||
pub(crate) delayed_notice: Vec<NoticeResponseBody>,
|
||||
|
||||
_guage: NumDbConnectionsGuard<'static>,
|
||||
pub(crate) guage: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl ConnectInfo {
|
||||
@@ -342,7 +343,7 @@ impl ConnectInfo {
|
||||
delayed_notice,
|
||||
cancel_closure,
|
||||
aux,
|
||||
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),
|
||||
};
|
||||
|
||||
Ok(connection)
|
||||
|
||||
@@ -120,7 +120,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(
|
||||
@@ -237,18 +237,25 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let cancel_closure = node.cancel_closure.clone();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move { session.maintain_cancel_key(cancel, &cancel_closure).await });
|
||||
tokio::spawn(session.maintain_cancel_key(
|
||||
ctx.session_id(),
|
||||
cancel,
|
||||
node.cancel_closure,
|
||||
&config.connect_to_compute,
|
||||
));
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node.stream,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id: None,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use futures::FutureExt;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::compute::MaybeRustlsStream;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::metrics::{
|
||||
Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
|
||||
NumDbConnectionsGuard,
|
||||
};
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
@@ -65,39 +66,20 @@ pub(crate) async fn proxy_pass(
|
||||
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: PostgresConnection,
|
||||
pub(crate) compute: MaybeRustlsStream,
|
||||
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
|
||||
pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender<Infallible>,
|
||||
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
pub(crate) _db_conn: NumDbConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(
|
||||
self.client,
|
||||
self.compute.stream,
|
||||
self.aux,
|
||||
self.private_link_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = self
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
res
|
||||
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
|
||||
proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
@@ -377,9 +377,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let cancel_closure = node.cancel_closure.clone();
|
||||
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move { session.maintain_cancel_key(cancel, &cancel_closure).await });
|
||||
tokio::spawn(session.maintain_cancel_key(
|
||||
ctx.session_id(),
|
||||
cancel,
|
||||
node.cancel_closure,
|
||||
&config.connect_to_compute,
|
||||
));
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
@@ -389,13 +393,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node.stream,
|
||||
|
||||
aux: node.aux,
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
|
||||
_cancel_on_shutdown: cancel_on_shutdown,
|
||||
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
_db_conn: node.guage,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ pub(crate) async fn serve_websocket(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
|
||||
Reference in New Issue
Block a user