cancellation does not need tls

This commit is contained in:
Conrad Ludgate
2025-06-20 11:15:35 +01:00
parent 7c469b30aa
commit d8ddf5c850
11 changed files with 21 additions and 72 deletions

View File

@@ -1,13 +1,12 @@
use tokio::net::TcpStream;
use crate::client::SocketConfig;
use crate::config::{Host, SslMode};
use crate::config::Host;
use crate::tls::MakeTlsConnect;
use crate::{Error, cancel_query_raw, connect_socket};
use crate::{Error, cancel_query_raw, connect_socket, connect_tls};
pub(crate) async fn cancel_query<T>(
config: SocketConfig,
ssl_mode: SslMode,
tls: T,
process_id: i32,
secret_key: i32,
@@ -30,5 +29,6 @@ where
)
.await?;
cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await
let stream = connect_tls::connect_tls(socket, config.ssl_mode, tls).await?;
cancel_query_raw::cancel_query_raw(stream, process_id, secret_key).await
}

View File

@@ -2,23 +2,16 @@ use bytes::BytesMut;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use crate::config::SslMode;
use crate::tls::TlsConnect;
use crate::{Error, connect_tls};
use crate::Error;
pub async fn cancel_query_raw<S, T>(
stream: S,
mode: SslMode,
tls: T,
pub async fn cancel_query_raw<S>(
mut stream: S,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;
let mut buf = BytesMut::new();
frontend::cancel_request(process_id, secret_key, &mut buf);

View File

@@ -3,8 +3,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use crate::client::SocketConfig;
use crate::config::SslMode;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::tls::MakeTlsConnect;
use crate::{Error, cancel_query, cancel_query_raw};
/// The capability to request cancellation of in-progress queries on a
@@ -19,7 +18,6 @@ pub struct CancelToken {
/// connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RawCancelToken {
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
}
@@ -43,7 +41,6 @@ impl CancelToken {
{
cancel_query::cancel_query(
self.socket_config.clone(),
self.raw.ssl_mode,
tls,
self.raw.process_id,
self.raw.secret_key,
@@ -55,18 +52,10 @@ impl CancelToken {
impl RawCancelToken {
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
pub async fn cancel_query_raw<S>(&self, stream: S) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
cancel_query_raw::cancel_query_raw(
stream,
self.ssl_mode,
tls,
self.process_id,
self.secret_key,
)
.await
cancel_query_raw::cancel_query_raw(stream, self.process_id, self.secret_key).await
}
}

View File

@@ -167,6 +167,7 @@ pub struct SocketConfig {
pub host: Host,
pub port: u16,
pub connect_timeout: Option<Duration>,
pub ssl_mode: SslMode,
}
/// An asynchronous PostgreSQL client.
@@ -178,7 +179,6 @@ pub struct Client {
cached_typeinfo: CachedTypeInfo,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
}
@@ -188,7 +188,6 @@ impl Client {
sender: mpsc::UnboundedSender<FrontendMessage>,
receiver: mpsc::Receiver<BackendMessages>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
) -> Client {
@@ -206,7 +205,6 @@ impl Client {
cached_typeinfo: Default::default(),
socket_config,
ssl_mode,
process_id,
secret_key,
}
@@ -334,7 +332,6 @@ impl Client {
CancelToken {
socket_config: self.socket_config.clone(),
raw: RawCancelToken {
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
},

View File

@@ -57,6 +57,7 @@ where
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
ssl_mode: config.ssl_mode,
};
let (client_tx, conn_rx) = mpsc::unbounded_channel();
@@ -65,7 +66,6 @@ where
client_tx,
client_rx,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);

View File

@@ -201,7 +201,7 @@ pub async fn run() -> anyhow::Result<()> {
auth_backend,
http_listener,
shutdown.clone(),
Arc::new(CancellationHandler::new(&config.connect_to_compute)),
Arc::new(CancellationHandler::new()),
endpoint_rate_limiter,
);

View File

@@ -391,7 +391,7 @@ pub async fn run() -> anyhow::Result<()> {
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
let cancellation_handler = Arc::new(CancellationHandler::new());
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)

View File

@@ -7,7 +7,6 @@ use anyhow::anyhow;
use futures::FutureExt;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::RawCancelToken;
use postgres_client::tls::MakeTlsConnect;
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
@@ -18,7 +17,6 @@ use tracing::{debug, error, info};
use crate::auth::AuthError;
use crate::auth::backend::ComputeUserInfo;
use crate::batch::{BatchQueue, QueueProcessing};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
use crate::error::ReportableError;
@@ -144,7 +142,6 @@ impl QueueProcessing for CancellationProcessor {
///
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
compute_config: &'static ComputeConfig,
// rate limiter of cancellation requests
limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
@@ -187,9 +184,8 @@ impl ReportableError for CancelError {
}
impl CancellationHandler {
pub fn new(compute_config: &'static ComputeConfig) -> Self {
pub fn new() -> Self {
Self {
compute_config,
tx: OnceLock::new(),
limiter: Arc::new(std::sync::Mutex::new(
LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
@@ -332,7 +328,7 @@ impl CancellationHandler {
kind: crate::metrics::CancellationOutcome::Found,
});
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query(self.compute_config).await
cancel_closure.try_cancel_query().await
}
}
@@ -362,19 +358,9 @@ impl CancelClosure {
}
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(
&self,
compute_config: &ComputeConfig,
) -> Result<(), CancelError> {
pub(crate) async fn try_cancel_query(&self) -> Result<(), CancelError> {
let socket = TcpStream::connect(self.socket_addr).await?;
let tls = <_ as MakeTlsConnect<tokio::net::TcpStream>>::make_tls_connect(
compute_config,
&self.hostname,
)
.map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
self.cancel_token.cancel_query_raw(socket, tls).await?;
self.cancel_token.cancel_query_raw(socket).await?;
debug!("query was cancelled");
Ok(())
}
@@ -399,7 +385,6 @@ impl Session {
session_id: uuid::Uuid,
cancel: tokio::sync::oneshot::Receiver<Infallible>,
cancel_closure: &CancelClosure,
compute_config: &ComputeConfig,
) {
futures::future::select(
std::pin::pin!(self.maintain_redis_cancel_key(cancel_closure)),
@@ -407,11 +392,7 @@ impl Session {
)
.await;
if let Err(err) = cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
if let Err(err) = cancel_closure.try_cancel_query().boxed().await {
tracing::warn!(
?session_id,
?err,

View File

@@ -329,7 +329,6 @@ impl ConnectInfo {
let cancel_closure = CancelClosure::new(
socket_addr,
RawCancelToken {
ssl_mode: self.ssl_mode,
process_id,
secret_key,
},

View File

@@ -241,12 +241,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
session
.maintain_cancel_key(
session_id,
cancel,
&node.cancel_closure,
&config.connect_to_compute,
)
.maintain_cancel_key(session_id, cancel, &node.cancel_closure)
.await;
});

View File

@@ -381,12 +381,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
session
.maintain_cancel_key(
session_id,
cancel,
&node.cancel_closure,
&config.connect_to_compute,
)
.maintain_cancel_key(session_id, cancel, &node.cancel_closure)
.await;
});