diff --git a/Cargo.lock b/Cargo.lock index fcdc424636..d9489cdd97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -44,6 +44,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -2042,6 +2048,10 @@ name = "hashbrown" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "hashlink" @@ -5224,6 +5234,8 @@ dependencies = [ "futures-core", "futures-io", "futures-sink", + "futures-util", + "hashbrown 0.14.0", "pin-project-lite", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index ba8b49c0e0..ce590f3c7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,7 +149,7 @@ tokio-postgres-rustls = "0.10.0" tokio-rustls = "0.24" tokio-stream = "0.1" tokio-tar = "0.3" -tokio-util = { version = "0.7", features = ["io"] } +tokio-util = { version = "0.7.10", features = ["io", "rt"] } toml = "0.7" toml_edit = "0.19" tonic = {version = "0.9", features = ["tls", "tls-roots"]} diff --git a/libs/utils/src/completion.rs b/libs/utils/src/completion.rs index e2e84dd0ee..ca6827c9b8 100644 --- a/libs/utils/src/completion.rs +++ b/libs/utils/src/completion.rs @@ -1,16 +1,14 @@ -use std::sync::Arc; - -use tokio::sync::{mpsc, Mutex}; +use tokio_util::task::{task_tracker::TaskTrackerToken, TaskTracker}; /// While a reference is kept around, the associated [`Barrier::wait`] will wait. /// /// Can be cloned, moved and kept around in futures as "guard objects". #[derive(Clone)] -pub struct Completion(mpsc::Sender<()>); +pub struct Completion(TaskTrackerToken); /// Barrier will wait until all clones of [`Completion`] have been dropped. 
#[derive(Clone)] -pub struct Barrier(Arc>>); +pub struct Barrier(TaskTracker); impl Default for Barrier { fn default() -> Self { @@ -21,7 +19,7 @@ impl Default for Barrier { impl Barrier { pub async fn wait(self) { - self.0.lock().await.recv().await; + self.0.wait().await; } pub async fn maybe_wait(barrier: Option) { @@ -33,8 +31,7 @@ impl Barrier { impl PartialEq for Barrier { fn eq(&self, other: &Self) -> bool { - // we don't use dyn so this is good - Arc::ptr_eq(&self.0, &other.0) + TaskTracker::ptr_eq(&self.0, &other.0) } } @@ -42,8 +39,10 @@ impl Eq for Barrier {} /// Create new Guard and Barrier pair. pub fn channel() -> (Completion, Barrier) { - let (tx, rx) = mpsc::channel::<()>(1); - let rx = Mutex::new(rx); - let rx = Arc::new(rx); - (Completion(tx), Barrier(rx)) + let tracker = TaskTracker::new(); + // otherwise wait never exits + tracker.close(); + + let token = tracker.token(); + (Completion(token), Barrier(tracker)) } diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 43b35c6d08..7607119dda 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -425,7 +425,6 @@ fn start_pageserver( let tenant_manager = Arc::new(tenant_manager); BACKGROUND_RUNTIME.spawn({ - let init_done_rx = init_done_rx; let shutdown_pageserver = shutdown_pageserver.clone(); let drive_init = async move { // NOTE: unlike many futures in pageserver, this one is cancellation-safe @@ -560,7 +559,6 @@ fn start_pageserver( } if let Some(metric_collection_endpoint) = &conf.metric_collection_endpoint { - let background_jobs_barrier = background_jobs_barrier; let metrics_ctx = RequestContext::todo_child( TaskKind::MetricsCollection, // This task itself shouldn't download anything. 
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 438190261d..c94cd55417 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -61,6 +61,7 @@ thiserror.workspace = true tls-listener.workspace = true tokio-postgres.workspace = true tokio-rustls.workspace = true +tokio-util.workspace = true tokio = { workspace = true, features = ["signal"] } tracing-opentelemetry.workspace = true tracing-subscriber.workspace = true @@ -77,7 +78,6 @@ postgres-protocol.workspace = true smol_str.workspace = true workspace_hack.workspace = true -tokio-util.workspace = true [dev-dependencies] rcgen.workspace = true diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index bedbdbcc83..d48ba3a54e 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -8,6 +8,7 @@ use std::{net::SocketAddr, sync::Arc}; use futures::future::Either; use itertools::Itertools; use proxy::config::TlsServerEndPoint; +use proxy::proxy::run_until_cancelled; use tokio::net::TcpListener; use anyhow::{anyhow, bail, ensure, Context}; @@ -20,7 +21,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use utils::{project_git_version, sentry_init::init_sentry}; -use tracing::{error, info, warn, Instrument}; +use tracing::{error, info, Instrument}; project_git_version!(GIT_VERSION); @@ -151,63 +152,39 @@ async fn task_main( // will be inherited by all accepted client sockets. socket2::SockRef::from(&listener).set_keepalive(true)?; - let mut connections = tokio::task::JoinSet::new(); + let connections = tokio_util::task::task_tracker::TaskTracker::new(); - loop { - tokio::select! 
{ - accept_result = listener.accept() => { - let (socket, peer_addr) = accept_result?; + while let Some(accept_result) = + run_until_cancelled(listener.accept(), &cancellation_token).await + { + let (socket, peer_addr) = accept_result?; - let session_id = uuid::Uuid::new_v4(); - let tls_config = Arc::clone(&tls_config); - let dest_suffix = Arc::clone(&dest_suffix); + let session_id = uuid::Uuid::new_v4(); + let tls_config = Arc::clone(&tls_config); + let dest_suffix = Arc::clone(&dest_suffix); - connections.spawn( - async move { - socket - .set_nodelay(true) - .context("failed to set socket option")?; + connections.spawn( + async move { + socket + .set_nodelay(true) + .context("failed to set socket option")?; - info!(%peer_addr, "serving"); - handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await - } - .unwrap_or_else(|e| { - // Acknowledge that the task has finished with an error. - error!("per-client task finished with an error: {e:#}"); - }) - .instrument(tracing::info_span!("handle_client", ?session_id)) - ); + info!(%peer_addr, "serving"); + handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await } - // Don't modify this unless you read https://docs.rs/tokio/latest/tokio/macro.select.html carefully. - // If this future completes and the pattern doesn't match, this branch is disabled for this call to `select!`. - // This only counts for this loop and it will be enabled again on next `select!`. - // - // Prior code had this as `Some(Err(e))` which _looks_ equivalent to the current setup, but it's not. - // When `connections.join_next()` returned `Some(Ok(()))` (which we expect), it would disable the join_next and it would - // not get called again, even if there are more connections to remove. 
- Some(res) = connections.join_next() => { - if let Err(e) = res { - if !e.is_panic() && !e.is_cancelled() { - warn!("unexpected error from joined connection task: {e:?}"); - } - } - } - _ = cancellation_token.cancelled() => { - drop(listener); - break; - } - } + .unwrap_or_else(|e| { + // Acknowledge that the task has finished with an error. + error!("per-client task finished with an error: {e:#}"); + }) + .instrument(tracing::info_span!("handle_client", ?session_id)), + ); } - // Drain connections - info!("waiting for all client connections to finish"); - while let Some(res) = connections.join_next().await { - if let Err(e) = res { - if !e.is_panic() && !e.is_cancelled() { - warn!("unexpected error from joined connection task: {e:?}"); - } - } - } + connections.close(); + drop(listener); + + connections.wait().await; + info!("all client connections have finished"); Ok(()) } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 7cf3ed5b8a..4dbffa850a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -277,6 +277,21 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy = Lazy::new(|| { .unwrap() }); +pub async fn run_until_cancelled( + f: F, + cancellation_token: &CancellationToken, +) -> Option { + match futures::future::select( + std::pin::pin!(f), + std::pin::pin!(cancellation_token.cancelled()), + ) + .await + { + futures::future::Either::Left((f, _)) => Some(f), + futures::future::Either::Right(((), _)) => None, + } +} + pub async fn task_main( config: &'static ProxyConfig, listener: tokio::net::TcpListener, @@ -290,71 +305,62 @@ pub async fn task_main( // will be inherited by all accepted client sockets. socket2::SockRef::from(&listener).set_keepalive(true)?; - let mut connections = tokio::task::JoinSet::new(); + let connections = tokio_util::task::task_tracker::TaskTracker::new(); let cancel_map = Arc::new(CancelMap::default()); - loop { - tokio::select! 
{ - accept_result = listener.accept() => { - let (socket, peer_addr) = accept_result?; + while let Some(accept_result) = + run_until_cancelled(listener.accept(), &cancellation_token).await + { + let (socket, peer_addr) = accept_result?; - let session_id = uuid::Uuid::new_v4(); - let cancel_map = Arc::clone(&cancel_map); - connections.spawn( - async move { - info!("accepted postgres client connection"); + let session_id = uuid::Uuid::new_v4(); + let cancel_map = Arc::clone(&cancel_map); + connections.spawn( + async move { + info!("accepted postgres client connection"); - let mut socket = WithClientIp::new(socket); - let mut peer_addr = peer_addr; - if let Some(ip) = socket.wait_for_addr().await? { - peer_addr = ip; - tracing::Span::current().record("peer_addr", &tracing::field::display(ip)); - } else if config.require_client_ip { - bail!("missing required client IP"); - } - - socket - .inner - .set_nodelay(true) - .context("failed to set socket option")?; - - handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp, peer_addr.ip()).await - } - .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty)) - .unwrap_or_else(move |e| { - // Acknowledge that the task has finished with an error. - error!(?session_id, "per-client task finished with an error: {e:#}"); - }), - ); - } - // Don't modify this unless you read https://docs.rs/tokio/latest/tokio/macro.select.html carefully. - // If this future completes and the pattern doesn't match, this branch is disabled for this call to `select!`. - // This only counts for this loop and it will be enabled again on next `select!`. - // - // Prior code had this as `Some(Err(e))` which _looks_ equivalent to the current setup, but it's not. - // When `connections.join_next()` returned `Some(Ok(()))` (which we expect), it would disable the join_next and it would - // not get called again, even if there are more connections to remove. 
- Some(res) = connections.join_next() => { - if let Err(e) = res { - if !e.is_panic() && !e.is_cancelled() { - warn!("unexpected error from joined connection task: {e:?}"); - } + let mut socket = WithClientIp::new(socket); + let mut peer_addr = peer_addr; + if let Some(ip) = socket.wait_for_addr().await? { + peer_addr = ip; + tracing::Span::current().record("peer_addr", &tracing::field::display(ip)); + } else if config.require_client_ip { + bail!("missing required client IP"); } + + socket + .inner + .set_nodelay(true) + .context("failed to set socket option")?; + + handle_client( + config, + &cancel_map, + session_id, + socket, + ClientMode::Tcp, + peer_addr.ip(), + ) + .await } - _ = cancellation_token.cancelled() => { - drop(listener); - break; - } - } + .instrument(info_span!( + "handle_client", + ?session_id, + peer_addr = tracing::field::Empty + )) + .unwrap_or_else(move |e| { + // Acknowledge that the task has finished with an error. + error!(?session_id, "per-client task finished with an error: {e:#}"); + }), + ); } + + connections.close(); + drop(listener); + // Drain connections - while let Some(res) = connections.join_next().await { - if let Err(e) = res { - if !e.is_panic() && !e.is_cancelled() { - warn!("unexpected error from joined connection task: {e:?}"); - } - } - } + connections.wait().await; + Ok(()) } diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 5a992d6461..cd496ff01e 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -10,6 +10,7 @@ use anyhow::bail; use hyper::StatusCode; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use tokio_util::task::TaskTracker; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER}; @@ -70,6 +71,9 @@ pub async fn task_main( incoming: addr_incoming, }; + let ws_connections = 
tokio_util::task::task_tracker::TaskTracker::new(); + ws_connections.close(); // allows `ws_connections.wait` to complete + let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { if let Err(err) = conn { error!("failed to accept TLS connection for websockets: {err:?}"); @@ -86,6 +90,7 @@ pub async fn task_main( let remote_addr = io.inner.remote_addr(); let sni_name = tls.server_name().map(|s| s.to_string()); let conn_pool = conn_pool.clone(); + let ws_connections = ws_connections.clone(); async move { let peer_addr = match client_addr { @@ -97,6 +102,7 @@ pub async fn task_main( move |req: Request| { let sni_name = sni_name.clone(); let conn_pool = conn_pool.clone(); + let ws_connections = ws_connections.clone(); async move { let cancel_map = Arc::new(CancelMap::default()); @@ -106,6 +112,7 @@ pub async fn task_main( req, config, conn_pool, + ws_connections, cancel_map, session_id, sni_name, @@ -129,6 +136,9 @@ pub async fn task_main( .with_graceful_shutdown(cancellation_token.cancelled()) .await?; + // await websocket connections + ws_connections.wait().await; + Ok(()) } @@ -170,10 +180,12 @@ where } } +#[allow(clippy::too_many_arguments)] async fn request_handler( mut request: Request, config: &'static ProxyConfig, conn_pool: Arc, + ws_connections: TaskTracker, cancel_map: Arc, session_id: uuid::Uuid, sni_hostname: Option, @@ -193,7 +205,7 @@ async fn request_handler( let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None) .map_err(|e| ApiError::BadRequest(e.into()))?; - tokio::spawn( + ws_connections.spawn( async move { if let Err(e) = websocket::serve_websocket( websocket, diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 82945dfacb..3653643d7e 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -65,7 +65,7 @@ subtle = { version = "2" } time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] } tokio = { version = "1", features = ["fs", "io-std", 
"io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] } tokio-rustls = { version = "0.24" } -tokio-util = { version = "0.7", features = ["codec", "compat", "io"] } +tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } toml_edit = { version = "0.19", features = ["serde"] } tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }