proxy: spawn cancellation checks in the background (#9918)

## Problem
When handling a cancellation request, the client connection stays open until all of the cancel checks have completed.
## Summary of changes
Spawn the cancellation checks in the background and close the connection
immediately.
Use a `TaskTracker` to track the spawned cancellation checks.
This commit is contained in:
Ivan Efremov
2024-11-28 08:32:22 +02:00
committed by GitHub
parent da1daa2426
commit 8173dc600a
6 changed files with 75 additions and 24 deletions

View File

@@ -99,16 +99,17 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
/// check_allowed - if true, check if the IP is allowed to cancel the query
/// return Result primarily for tests
pub(crate) async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
peer_addr: &IpAddr,
peer_addr: IpAddr,
check_allowed: bool,
) -> Result<(), CancelError> {
// TODO: check for unspecified address is only for backward compatibility, should be removed
if !peer_addr.is_unspecified() {
let subnet_key = match *peer_addr {
let subnet_key = match peer_addr {
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use default mask here
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
};
@@ -141,9 +142,11 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
return Ok(());
}
match self.client.try_publish(key, session_id, *peer_addr).await {
match self.client.try_publish(key, session_id, peer_addr).await {
Ok(()) => {} // do nothing
Err(e) => {
// log it here since cancel_session could be spawned in a task
tracing::error!("failed to publish cancellation key: {key}, error: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
@@ -154,8 +157,10 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
};
if check_allowed
&& !check_peer_addr_is_in_list(peer_addr, cancel_closure.ip_allowlist.as_slice())
&& !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice())
{
// log it here since cancel_session could be spawned in a task
tracing::warn!("IP is not allowed to cancel the query: {key}");
return Err(CancelError::IpNotAllowed);
}
@@ -306,7 +311,7 @@ mod tests {
cancel_key: 0,
},
Uuid::new_v4(),
&("127.0.0.1".parse().unwrap()),
"127.0.0.1".parse().unwrap(),
true,
)
.await

View File

@@ -35,6 +35,7 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -48,6 +49,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
@@ -96,6 +98,7 @@ pub async fn task_main(
cancellation_handler,
socket,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
@@ -127,10 +130,12 @@ pub async fn task_main(
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
@@ -142,6 +147,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
@@ -161,15 +167,26 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(
cancel_key_data,
ctx.session_id(),
&ctx.peer_addr(),
config.authentication_config.ip_allowlist_check_enabled,
)
.await
.map(|()| None)?)
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session_id = ctx.session_id();
let peer_ip = ctx.peer_addr();
async move {
drop(
cancellation_handler_clone
.cancel_session(
cancel_key_data,
session_id,
peer_ip,
config.authentication_config.ip_allowlist_check_enabled,
)
.await,
);
}
});
return Ok(None);
}
};
drop(pause);

View File

@@ -69,6 +69,7 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -82,6 +83,7 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
@@ -133,6 +135,7 @@ pub async fn task_main(
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
@@ -164,10 +167,12 @@ pub async fn task_main(
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
@@ -250,6 +255,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
@@ -270,15 +276,26 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(
cancel_key_data,
ctx.session_id(),
&ctx.peer_addr(),
config.authentication_config.ip_allowlist_check_enabled,
)
.await
.map(|()| None)?)
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session_id = ctx.session_id();
let peer_ip = ctx.peer_addr();
async move {
drop(
cancellation_handler_clone
.cancel_session(
cancel_key_data,
session_id,
peer_ip,
config.authentication_config.ip_allowlist_check_enabled,
)
.await,
);
}
});
return Ok(None);
}
};
drop(pause);

View File

@@ -149,7 +149,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.cancel_session(
cancel_session.cancel_key_data,
uuid::Uuid::nil(),
&peer_addr,
peer_addr,
cancel_session.peer_addr.is_some(),
)
.await

View File

@@ -132,6 +132,7 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -160,6 +161,7 @@ pub async fn task_main(
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellations = cancellations.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
@@ -188,6 +190,7 @@ pub async fn task_main(
config,
backend,
connections2,
cancellations,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
@@ -313,6 +316,7 @@ async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellations: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
@@ -353,6 +357,7 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let cancellations = cancellations.clone();
let handler = connections.spawn(
request_handler(
req,
@@ -364,6 +369,7 @@ async fn connection_handler(
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -411,6 +417,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -436,6 +443,7 @@ async fn request_handler(
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
let cancellations = cancellations.clone();
ws_connections.spawn(
async move {
if let Err(e) = websocket::serve_websocket(
@@ -446,6 +454,7 @@ async fn request_handler(
cancellation_handler,
endpoint_rate_limiter,
host,
cancellations,
)
.await
{

View File

@@ -123,6 +123,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
@@ -131,6 +132,7 @@ pub(crate) async fn serve_websocket(
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
@@ -149,6 +151,7 @@ pub(crate) async fn serve_websocket(
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
cancellations,
))
.await;