mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
proxy: spawn cancellation checks in the background (#9918)
## Problem For cancellation, a connection is open during all the cancel checks. ## Summary of changes Spawn cancellation checks in the background, and close connection immediately. Use task_tracker for cancellation checks.
This commit is contained in:
@@ -99,16 +99,17 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
/// Try to cancel a running query for the corresponding connection.
|
||||
/// If the cancellation key is not found, it will be published to Redis.
|
||||
/// check_allowed - if true, check if the IP is allowed to cancel the query
|
||||
/// return Result primarily for tests
|
||||
pub(crate) async fn cancel_session(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
peer_addr: &IpAddr,
|
||||
peer_addr: IpAddr,
|
||||
check_allowed: bool,
|
||||
) -> Result<(), CancelError> {
|
||||
// TODO: check for unspecified address is only for backward compatibility, should be removed
|
||||
if !peer_addr.is_unspecified() {
|
||||
let subnet_key = match *peer_addr {
|
||||
let subnet_key = match peer_addr {
|
||||
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
|
||||
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
|
||||
};
|
||||
@@ -141,9 +142,11 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self.client.try_publish(key, session_id, *peer_addr).await {
|
||||
match self.client.try_publish(key, session_id, peer_addr).await {
|
||||
Ok(()) => {} // do nothing
|
||||
Err(e) => {
|
||||
// log it here since cancel_session could be spawned in a task
|
||||
tracing::error!("failed to publish cancellation key: {key}, error: {e}");
|
||||
return Err(CancelError::IO(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
@@ -154,8 +157,10 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
};
|
||||
|
||||
if check_allowed
|
||||
&& !check_peer_addr_is_in_list(peer_addr, cancel_closure.ip_allowlist.as_slice())
|
||||
&& !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice())
|
||||
{
|
||||
// log it here since cancel_session could be spawned in a task
|
||||
tracing::warn!("IP is not allowed to cancel the query: {key}");
|
||||
return Err(CancelError::IpNotAllowed);
|
||||
}
|
||||
|
||||
@@ -306,7 +311,7 @@ mod tests {
|
||||
cancel_key: 0,
|
||||
},
|
||||
Uuid::new_v4(),
|
||||
&("127.0.0.1".parse().unwrap()),
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
true,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -35,6 +35,7 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -48,6 +49,7 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
|
||||
@@ -96,6 +98,7 @@ pub async fn task_main(
|
||||
cancellation_handler,
|
||||
socket,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
@@ -127,10 +130,12 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -142,6 +147,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
stream: S,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
@@ -161,15 +167,26 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
return Ok(cancellation_handler
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx.session_id(),
|
||||
&ctx.peer_addr(),
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
)
|
||||
.await
|
||||
.map(|()| None)?)
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session_id = ctx.session_id();
|
||||
let peer_ip = ctx.peer_addr();
|
||||
async move {
|
||||
drop(
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
session_id,
|
||||
peer_ip,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
)
|
||||
.await,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
@@ -69,6 +69,7 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -82,6 +83,7 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
@@ -133,6 +135,7 @@ pub async fn task_main(
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
@@ -164,10 +167,12 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -250,6 +255,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
@@ -270,15 +276,26 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
return Ok(cancellation_handler
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx.session_id(),
|
||||
&ctx.peer_addr(),
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
)
|
||||
.await
|
||||
.map(|()| None)?)
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session_id = ctx.session_id();
|
||||
let peer_ip = ctx.peer_addr();
|
||||
async move {
|
||||
drop(
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
session_id,
|
||||
peer_ip,
|
||||
config.authentication_config.ip_allowlist_check_enabled,
|
||||
)
|
||||
.await,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
@@ -149,7 +149,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
.cancel_session(
|
||||
cancel_session.cancel_key_data,
|
||||
uuid::Uuid::nil(),
|
||||
&peer_addr,
|
||||
peer_addr,
|
||||
cancel_session.peer_addr.is_some(),
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -132,6 +132,7 @@ pub async fn task_main(
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
if let Err(e) = conn.set_nodelay(true) {
|
||||
@@ -160,6 +161,7 @@ pub async fn task_main(
|
||||
let connections2 = connections.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellations = cancellations.clone();
|
||||
connections.spawn(
|
||||
async move {
|
||||
let conn_token2 = conn_token.clone();
|
||||
@@ -188,6 +190,7 @@ pub async fn task_main(
|
||||
config,
|
||||
backend,
|
||||
connections2,
|
||||
cancellations,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
@@ -313,6 +316,7 @@ async fn connection_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
connections: TaskTracker,
|
||||
cancellations: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -353,6 +357,7 @@ async fn connection_handler(
|
||||
|
||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||
let cancellations = cancellations.clone();
|
||||
let handler = connections.spawn(
|
||||
request_handler(
|
||||
req,
|
||||
@@ -364,6 +369,7 @@ async fn connection_handler(
|
||||
conn_info2.clone(),
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellations,
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
@@ -411,6 +417,7 @@ async fn request_handler(
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellations: TaskTracker,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -436,6 +443,7 @@ async fn request_handler(
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
let cancellations = cancellations.clone();
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
@@ -446,6 +454,7 @@ async fn request_handler(
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
host,
|
||||
cancellations,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -123,6 +123,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
@@ -131,6 +132,7 @@ pub(crate) async fn serve_websocket(
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
@@ -149,6 +151,7 @@ pub(crate) async fn serve_websocket(
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
conn_gauge,
|
||||
cancellations,
|
||||
))
|
||||
.await;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user