reuse the same tracker token for websockets and http

This commit is contained in:
Conrad Ludgate
2025-05-29 16:04:14 +01:00
parent eefac5d78b
commit 0cdb0c5704

View File

@@ -41,7 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, info, warn};
use crate::cancellation::CancellationHandler;
@@ -149,10 +149,11 @@ pub async fn task_main(
let conn_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
connections.spawn(
let tracker = connections.token();
tokio::spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
@@ -179,7 +180,7 @@ pub async fn task_main(
Box::pin(connection_handler(
config,
backend,
connections2,
tracker,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
@@ -302,7 +303,7 @@ async fn connection_startup(
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
tracker: TaskTrackerToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
@@ -343,12 +344,12 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handler = connections.spawn(
let handler = tokio::spawn(
request_handler(
req,
config,
backend.clone(),
connections.clone(),
tracker.clone(),
cancellation_handler.clone(),
session_id,
conn_info2.clone(),
@@ -394,7 +395,7 @@ async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
tracker: TaskTrackerToken,
cancellation_handler: Arc<CancellationHandler>,
session_id: uuid::Uuid,
conn_info: ConnectionInfo,
@@ -434,7 +435,6 @@ async fn request_handler(
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
let tracker = ws_connections.token();
tokio::spawn(
async move {
if let Err(e) = websocket::serve_websocket(