From 8173dc600ad68872f4e488c753f59b8a1e2093aa Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Thu, 28 Nov 2024 08:32:22 +0200 Subject: [PATCH] proxy: spawn cancellation checks in the background (#9918) ## Problem When handling a cancellation request, the client connection stays open until all of the cancellation checks have finished. ## Summary of changes Spawn the cancellation checks in a background task and close the connection immediately. Track the spawned checks with a dedicated task_tracker so that they are still drained on shutdown. --- proxy/src/cancellation.rs | 15 ++++++++----- proxy/src/console_redirect_proxy.rs | 35 +++++++++++++++++++++-------- proxy/src/proxy/mod.rs | 35 +++++++++++++++++++++-------- proxy/src/redis/notifications.rs | 2 +- proxy/src/serverless/mod.rs | 9 ++++++++ proxy/src/serverless/websocket.rs | 3 +++ 6 files changed, 75 insertions(+), 24 deletions(-) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 74415f1ffe..91e198bf88 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -99,16 +99,17 @@ impl CancellationHandler

{ /// Try to cancel a running query for the corresponding connection. /// If the cancellation key is not found, it will be published to Redis. /// check_allowed - if true, check if the IP is allowed to cancel the query + /// return Result primarily for tests pub(crate) async fn cancel_session( &self, key: CancelKeyData, session_id: Uuid, - peer_addr: &IpAddr, + peer_addr: IpAddr, check_allowed: bool, ) -> Result<(), CancelError> { // TODO: check for unspecified address is only for backward compatibility, should be removed if !peer_addr.is_unspecified() { - let subnet_key = match *peer_addr { + let subnet_key = match peer_addr { IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), }; @@ -141,9 +142,11 @@ impl CancellationHandler

{ return Ok(()); } - match self.client.try_publish(key, session_id, *peer_addr).await { + match self.client.try_publish(key, session_id, peer_addr).await { Ok(()) => {} // do nothing Err(e) => { + // log it here since cancel_session could be spawned in a task + tracing::error!("failed to publish cancellation key: {key}, error: {e}"); return Err(CancelError::IO(std::io::Error::new( std::io::ErrorKind::Other, e.to_string(), @@ -154,8 +157,10 @@ impl CancellationHandler

{ }; if check_allowed - && !check_peer_addr_is_in_list(peer_addr, cancel_closure.ip_allowlist.as_slice()) + && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice()) { + // log it here since cancel_session could be spawned in a task + tracing::warn!("IP is not allowed to cancel the query: {key}"); return Err(CancelError::IpNotAllowed); } @@ -306,7 +311,7 @@ mod tests { cancel_key: 0, }, Uuid::new_v4(), - &("127.0.0.1".parse().unwrap()), + "127.0.0.1".parse().unwrap(), true, ) .await diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index b910b524b1..8f78df1964 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -35,6 +35,7 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -48,6 +49,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); + let cancellations = cancellations.clone(); debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); @@ -96,6 +98,7 @@ pub async fn task_main( cancellation_handler, socket, conn_gauge, + cancellations, ) .instrument(ctx.span()) .boxed() @@ -127,10 +130,12 @@ pub async fn task_main( } connections.close(); + cancellations.close(); drop(listener); // Drain connections connections.wait().await; + cancellations.wait().await; Ok(()) } @@ -142,6 +147,7 @@ pub(crate) async fn handle_client( cancellation_handler: Arc, stream: S, conn_gauge: NumClientConnectionsGuard<'static>, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), @@ -161,15 +167,26 @@ pub(crate) async fn handle_client( match 
tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancellation_handler - .cancel_session( - cancel_key_data, - ctx.session_id(), - &ctx.peer_addr(), - config.authentication_config.ip_allowlist_check_enabled, - ) - .await - .map(|()| None)?) + // spawn a task to cancel the session, but don't wait for it + cancellations.spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session_id = ctx.session_id(); + let peer_ip = ctx.peer_addr(); + async move { + drop( + cancellation_handler_clone + .cancel_session( + cancel_key_data, + session_id, + peer_ip, + config.authentication_config.ip_allowlist_check_enabled, + ) + .await, + ); + } + }); + + return Ok(None); } }; drop(pause); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 7fe67e43de..956036d29d 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -69,6 +69,7 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -82,6 +83,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); + let cancellations = cancellations.clone(); debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); @@ -133,6 +135,7 @@ pub async fn task_main( ClientMode::Tcp, endpoint_rate_limiter2, conn_gauge, + cancellations, ) .instrument(ctx.span()) .boxed() @@ -164,10 +167,12 @@ pub async fn task_main( } connections.close(); + cancellations.close(); drop(listener); // Drain connections connections.wait().await; + cancellations.wait().await; 
Ok(()) } @@ -250,6 +255,7 @@ pub(crate) async fn handle_client( mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), @@ -270,15 +276,26 @@ pub(crate) async fn handle_client( match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancellation_handler - .cancel_session( - cancel_key_data, - ctx.session_id(), - &ctx.peer_addr(), - config.authentication_config.ip_allowlist_check_enabled, - ) - .await - .map(|()| None)?) + // spawn a task to cancel the session, but don't wait for it + cancellations.spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session_id = ctx.session_id(); + let peer_ip = ctx.peer_addr(); + async move { + drop( + cancellation_handler_clone + .cancel_session( + cancel_key_data, + session_id, + peer_ip, + config.authentication_config.ip_allowlist_check_enabled, + ) + .await, + ); + } + }); + + return Ok(None); } }; drop(pause); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 65008ae943..9ac07b7e90 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -149,7 +149,7 @@ impl MessageHandler { .cancel_session( cancel_session.cancel_key_data, uuid::Uuid::nil(), - &peer_addr, + peer_addr, cancel_session.peer_addr.is_some(), ) .await diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 77025f419d..80b42f9e55 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -132,6 +132,7 @@ pub async fn task_main( let connections = tokio_util::task::task_tracker::TaskTracker::new(); connections.close(); // allows `connections.wait to complete` + let cancellations = 
tokio_util::task::task_tracker::TaskTracker::new(); while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await { let (conn, peer_addr) = res.context("could not accept TCP stream")?; if let Err(e) = conn.set_nodelay(true) { @@ -160,6 +161,7 @@ pub async fn task_main( let connections2 = connections.clone(); let cancellation_handler = cancellation_handler.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellations = cancellations.clone(); connections.spawn( async move { let conn_token2 = conn_token.clone(); @@ -188,6 +190,7 @@ pub async fn task_main( config, backend, connections2, + cancellations, cancellation_handler, endpoint_rate_limiter, conn_token, @@ -313,6 +316,7 @@ async fn connection_handler( config: &'static ProxyConfig, backend: Arc, connections: TaskTracker, + cancellations: TaskTracker, cancellation_handler: Arc, endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, @@ -353,6 +357,7 @@ async fn connection_handler( // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. // By spawning the future, we ensure it never gets cancelled until it decides to. + let cancellations = cancellations.clone(); let handler = connections.spawn( request_handler( req, @@ -364,6 +369,7 @@ async fn connection_handler( conn_info2.clone(), http_request_token, endpoint_rate_limiter.clone(), + cancellations, ) .in_current_span() .map_ok_or_else(api_error_into_response, |r| r), @@ -411,6 +417,7 @@ async fn request_handler( // used to cancel in-flight HTTP requests. 
not used to cancel websockets http_cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellations: TaskTracker, ) -> Result>, ApiError> { let host = request .headers() @@ -436,6 +443,7 @@ async fn request_handler( let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) .map_err(|e| ApiError::BadRequest(e.into()))?; + let cancellations = cancellations.clone(); ws_connections.spawn( async move { if let Err(e) = websocket::serve_websocket( @@ -446,6 +454,7 @@ async fn request_handler( cancellation_handler, endpoint_rate_limiter, host, + cancellations, ) .await { diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 4088fea835..bdb83fe6be 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -123,6 +123,7 @@ impl AsyncBufRead for WebSocketRw { } } +#[allow(clippy::too_many_arguments)] pub(crate) async fn serve_websocket( config: &'static ProxyConfig, auth_backend: &'static crate::auth::Backend<'static, ()>, @@ -131,6 +132,7 @@ pub(crate) async fn serve_websocket( cancellation_handler: Arc, endpoint_rate_limiter: Arc, hostname: Option, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> anyhow::Result<()> { let websocket = websocket.await?; let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket)); @@ -149,6 +151,7 @@ pub(crate) async fn serve_websocket( ClientMode::Websockets { hostname }, endpoint_rate_limiter, conn_gauge, + cancellations, )) .await;