proxy: spawn cancellation checks in the background (#9918)

## Problem
For cancellation, a connection is open during all the cancel checks.
## Summary of changes
Spawn cancellation checks in the background, and close connection
immediately.
Use task_tracker for cancellation checks.
This commit is contained in:
Ivan Efremov
2024-11-28 08:32:22 +02:00
committed by GitHub
parent da1daa2426
commit 8173dc600a
6 changed files with 75 additions and 24 deletions

View File

@@ -132,6 +132,7 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -160,6 +161,7 @@ pub async fn task_main(
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellations = cancellations.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
@@ -188,6 +190,7 @@ pub async fn task_main(
config,
backend,
connections2,
cancellations,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
@@ -313,6 +316,7 @@ async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellations: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
@@ -353,6 +357,7 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let cancellations = cancellations.clone();
let handler = connections.spawn(
request_handler(
req,
@@ -364,6 +369,7 @@ async fn connection_handler(
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -411,6 +417,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -436,6 +443,7 @@ async fn request_handler(
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
let cancellations = cancellations.clone();
ws_connections.spawn(
async move {
if let Err(e) = websocket::serve_websocket(
@@ -446,6 +454,7 @@ async fn request_handler(
cancellation_handler,
endpoint_rate_limiter,
host,
cancellations,
)
.await
{

View File

@@ -123,6 +123,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
@@ -131,6 +132,7 @@ pub(crate) async fn serve_websocket(
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
@@ -149,6 +151,7 @@ pub(crate) async fn serve_websocket(
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
cancellations,
))
.await;