diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs
index 246b27bb9c..5edd3fb1d9 100644
--- a/proxy/src/bin/proxy.rs
+++ b/proxy/src/bin/proxy.rs
@@ -364,12 +364,16 @@ async fn main2() -> anyhow::Result<()> {
     // client facing tasks. these will exit on error or on cancellation
     // cancellation returns Ok(())
     let mut client_tasks = JoinSet::new();
-    client_tasks.spawn(proxy::proxy::task_main(
-        config,
-        proxy_listener,
-        cancellation_token.clone(),
-        cancellation_handler.clone(),
-    ));
+    client_tasks
+        .build_task()
+        .name("tcp main")
+        .spawn(proxy::proxy::task_main(
+            config,
+            proxy_listener,
+            cancellation_token.clone(),
+            cancellation_handler.clone(),
+        ))
+        .unwrap();
 
     // TODO: rename the argument to something like serverless.
     // It now covers more than just websockets, it also covers SQL over HTTP.
@@ -378,58 +382,98 @@ async fn main2() -> anyhow::Result<()> {
         info!("Starting wss on {serverless_address}");
         let serverless_listener = TcpListener::bind(serverless_address).await?;
 
-        client_tasks.spawn(serverless::task_main(
-            config,
-            serverless_listener,
-            cancellation_token.clone(),
-            cancellation_handler.clone(),
-        ));
+        client_tasks
+            .build_task()
+            .name("serverless main")
+            .spawn(serverless::task_main(
+                config,
+                serverless_listener,
+                cancellation_token.clone(),
+                cancellation_handler.clone(),
+            ))
+            .unwrap();
     }
 
-    client_tasks.spawn(proxy::context::parquet::worker(
-        cancellation_token.clone(),
-        args.parquet_upload,
-    ));
+    client_tasks
+        .build_task()
+        .name("parquet worker")
+        .spawn(proxy::context::parquet::worker(
+            cancellation_token.clone(),
+            args.parquet_upload,
+        ))
+        .unwrap();
 
     // maintenance tasks. these never return unless there's an error
     let mut maintenance_tasks = JoinSet::new();
-    maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone()));
-    maintenance_tasks.spawn(http::health_server::task_main(
-        http_listener,
-        AppMetrics {
-            jemalloc,
-            neon_metrics,
-            proxy: proxy::metrics::Metrics::get(),
-        },
-    ));
-    maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
+    maintenance_tasks
+        .build_task()
+        .name("signal handler")
+        .spawn(proxy::handle_signals(cancellation_token.clone()))
+        .unwrap();
+    maintenance_tasks
+        .build_task()
+        .name("health server")
+        .spawn(http::health_server::task_main(
+            http_listener,
+            AppMetrics {
+                jemalloc,
+                neon_metrics,
+                proxy: proxy::metrics::Metrics::get(),
+            },
+        ))
+        .unwrap();
+    maintenance_tasks
+        .build_task()
+        .name("mangement main")
+        .spawn(console::mgmt::task_main(mgmt_listener))
+        .unwrap();
 
     if let Some(metrics_config) = &config.metric_collection {
         // TODO: Add gc regardles of the metric collection being enabled.
-        maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
-        client_tasks.spawn(usage_metrics::task_backup(
-            &metrics_config.backup_metric_collection_config,
-            cancellation_token,
-        ));
+        maintenance_tasks
+            .build_task()
+            .name("")
+            .spawn(usage_metrics::task_main(metrics_config))
+            .unwrap();
+        client_tasks
+            .build_task()
+            .name("")
+            .spawn(usage_metrics::task_backup(
+                &metrics_config.backup_metric_collection_config,
+                cancellation_token,
+            ))
+            .unwrap();
     }
 
     if let auth::BackendType::Console(api, _) = &config.auth_backend {
         if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
             if let Some(redis_notifications_client) = redis_notifications_client {
                 let cache = api.caches.project_info.clone();
-                maintenance_tasks.spawn(notifications::task_main(
-                    redis_notifications_client,
-                    cache.clone(),
-                    cancel_map.clone(),
-                    args.region.clone(),
-                ));
-                maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
+                maintenance_tasks
+                    .build_task()
+                    .name("redis notifications")
+                    .spawn(notifications::task_main(
+                        redis_notifications_client,
+                        cache.clone(),
+                        cancel_map.clone(),
+                        args.region.clone(),
+                    ))
+                    .unwrap();
+                maintenance_tasks
+                    .build_task()
+                    .name("proj info cache gc")
+                    .spawn(async move { cache.clone().gc_worker().await })
+                    .unwrap();
             }
             if let Some(regional_redis_client) = regional_redis_client {
                 let cache = api.caches.endpoints_cache.clone();
                 let con = regional_redis_client;
                 let span = tracing::info_span!("endpoints_cache");
-                maintenance_tasks.spawn(async move { cache.do_read(con).await }.instrument(span));
+                maintenance_tasks
+                    .build_task()
+                    .name("redis endpoints cache read")
+                    .spawn(async move { cache.do_read(con).await }.instrument(span))
+                    .unwrap();
             }
         }
     }
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index 9db7649003..7705896ec8 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -87,7 +87,7 @@ pub async fn task_main(
 
         tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
 
-        connections.spawn(async move {
+        tokio::task::Builder::new().name("tcp client connection").spawn(connections.track_future(async move {
             let mut socket = WithClientIp::new(socket);
             let mut peer_addr = peer_addr.ip();
             match socket.wait_for_addr().await {
@@ -152,7 +152,7 @@ pub async fn task_main(
                     }
                 }
             }
-        });
+        })).unwrap();
     }
 
     connections.close();
diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs
index 4a56cff887..9be0f4d6c9 100644
--- a/proxy/src/serverless.rs
+++ b/proxy/src/serverless.rs
@@ -115,20 +115,25 @@ pub async fn task_main(
         let conn_id = uuid::Uuid::new_v4();
         let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
 
-        connections.spawn(
-            connection_handler(
-                config,
-                backend.clone(),
-                connections.clone(),
-                cancellation_handler.clone(),
-                cancellation_token.clone(),
-                server.clone(),
-                tls_acceptor.clone(),
-                conn,
-                peer_addr,
+        tokio::task::Builder::new()
+            .name("serverless conn handler")
+            .spawn(
+                connections.track_future(
+                    connection_handler(
+                        config,
+                        backend.clone(),
+                        connections.clone(),
+                        cancellation_handler.clone(),
+                        cancellation_token.clone(),
+                        server.clone(),
+                        tls_acceptor.clone(),
+                        conn,
+                        peer_addr,
+                    )
+                    .instrument(http_conn_span),
+                ),
             )
-            .instrument(http_conn_span),
-        );
+            .unwrap();
     }
 
     connections.wait().await;
@@ -224,20 +229,25 @@ async fn connection_handler(
 
             // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
             // By spawning the future, we ensure it never gets cancelled until it decides to.
-            let handler = connections.spawn(
-                request_handler(
-                    req,
-                    config,
-                    backend.clone(),
-                    connections.clone(),
-                    cancellation_handler.clone(),
-                    session_id,
-                    peer_addr,
-                    http_request_token,
+            let handler = tokio::task::Builder::new()
+                .name("serverless request handler")
+                .spawn(
+                    connections.track_future(
+                        request_handler(
+                            req,
+                            config,
+                            backend.clone(),
+                            connections.clone(),
+                            cancellation_handler.clone(),
+                            session_id,
+                            peer_addr,
+                            http_request_token,
+                        )
+                        .in_current_span()
+                        .map_ok_or_else(api_error_into_response, |r| r),
+                    ),
                 )
-                .in_current_span()
-                .map_ok_or_else(api_error_into_response, |r| r),
-            );
+                .unwrap();
 
             async move {
                 let res = handler.await;
@@ -296,17 +306,27 @@ async fn request_handler(
         let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
             .map_err(|e| ApiError::BadRequest(e.into()))?;
 
-        ws_connections.spawn(
-            async move {
-                if let Err(e) =
-                    websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host)
+        tokio::task::Builder::new()
+            .name("websocket client conn")
+            .spawn(
+                ws_connections.track_future(
+                    async move {
+                        if let Err(e) = websocket::serve_websocket(
+                            config,
+                            ctx,
+                            websocket,
+                            cancellation_handler,
+                            host,
+                        )
                         .await
-                {
-                    error!("error in websocket connection: {e:#}");
-                }
-            }
-            .instrument(span),
-        );
+                        {
+                            error!("error in websocket connection: {e:#}");
+                        }
+                    }
+                    .instrument(span),
+                ),
+            )
+            .unwrap();
 
         // Return the response so the spawned future can continue.
         Ok(response)
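Note: the pattern applied throughout this diff is to wrap each future with the task tracker's track_future and spawn it through tokio::task::Builder (or JoinSet::build_task) so the task carries a human-readable name, e.g. in tokio-console or task dumps. Both builder APIs are unstable tokio features, so the following is only a minimal sketch of the idea, assuming tokio's "tracing" feature and a --cfg tokio_unstable build; the standalone TaskTracker and the "example connection" name are illustrative placeholders, not code from this PR.

use tokio_util::task::TaskTracker;

#[tokio::main]
async fn main() {
    let connections = TaskTracker::new();

    // Before: `connections.spawn(fut)` spawns an anonymous task.
    // After: track the future explicitly, then spawn it through the builder
    // so it shows up under a readable name in tokio-console / task dumps.
    tokio::task::Builder::new()
        .name("example connection")
        .spawn(connections.track_future(async {
            // ... handle one connection ...
        }))
        .unwrap();

    connections.close();
    connections.wait().await;
}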