diff --git a/Cargo.lock b/Cargo.lock
index 0a434c6ee6..738771f88b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6483,6 +6483,7 @@ dependencies = [
  "clap",
  "clap_builder",
  "crossbeam-utils",
+ "dashmap",
  "either",
  "fail",
  "futures",
diff --git a/Cargo.toml b/Cargo.toml
index 5b719d776b..e528489f1e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -67,7 +67,7 @@ comfy-table = "6.1"
 const_format = "0.2"
 crc32c = "0.6"
 crossbeam-utils = "0.8.5"
-dashmap = "5.5.0"
+dashmap = { version = "5.5.0", features = ["raw-api"] }
 either = "1.8"
 enum-map = "2.4.2"
 enumset = "1.0.12"
diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs
index 28d0d95c5b..7d1b7eaaae 100644
--- a/proxy/src/bin/proxy.rs
+++ b/proxy/src/bin/proxy.rs
@@ -80,6 +80,9 @@ struct ProxyCliArgs {
     /// cache for `wake_compute` api method (use `size=0` to disable)
     #[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)]
     wake_compute_cache: String,
+    /// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s" (use `permits=0` to disable)
+    #[clap(long, default_value = config::WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
+    wake_compute_lock: String,
     /// Allow self-signed certificates for compute nodes (for testing)
     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
     allow_self_signed_compute: bool,
@@ -220,10 +223,23 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
                 node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
             }));
 
+            let config::WakeComputeLockOptions {
+                shards,
+                permits,
+                epoch,
+                timeout,
+            } = args.wake_compute_lock.parse()?;
+            info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
+            let locks = Box::leak(Box::new(
+                console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
+                    .unwrap(),
+            ));
+            tokio::spawn(locks.garbage_collect_worker(epoch));
+
             let url = args.auth_endpoint.parse()?;
             let endpoint = http::Endpoint::new(url, http::new_client());
-            let api = console::provider::neon::Api::new(endpoint, caches);
+            let api = console::provider::neon::Api::new(endpoint, caches, locks);
 
             auth::BackendType::Console(Cow::Owned(api), ())
         }
         AuthBackend::Postgres => {
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index 9607ecd153..bd00123905 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -264,6 +264,79 @@ impl FromStr for CacheOptions {
     }
 }
 
+/// Helper for cmdline lock options parsing.
+pub struct WakeComputeLockOptions {
+    /// The number of shards the lock map should have
+    pub shards: usize,
+    /// The number of allowed concurrent requests for each endpoint
+    pub permits: usize,
+    /// Garbage collection epoch
+    pub epoch: Duration,
+    /// Lock timeout
+    pub timeout: Duration,
+}
+
+impl WakeComputeLockOptions {
+    /// Default options for [`crate::console::provider::ApiLocks`].
+    pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
+
+    // pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
+
+    /// Parse lock options passed via cmdline.
+    /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
+    fn parse(options: &str) -> anyhow::Result<Self> {
+        let mut shards = None;
+        let mut permits = None;
+        let mut epoch = None;
+        let mut timeout = None;
+
+        for option in options.split(',') {
+            let (key, value) = option
+                .split_once('=')
+                .with_context(|| format!("bad key-value pair: {option}"))?;
+
+            match key {
+                "shards" => shards = Some(value.parse()?),
+                "permits" => permits = Some(value.parse()?),
+                "epoch" => epoch = Some(humantime::parse_duration(value)?),
+                "timeout" => timeout = Some(humantime::parse_duration(value)?),
+                unknown => bail!("unknown key: {unknown}"),
+            }
+        }
+
+        // these don't matter if the lock is disabled
+        if let Some(0) = permits {
+            timeout = Some(Duration::default());
+            epoch = Some(Duration::default());
+            shards = Some(2);
+        }
+
+        let out = Self {
+            shards: shards.context("missing `shards`")?,
+            permits: permits.context("missing `permits`")?,
+            epoch: epoch.context("missing `epoch`")?,
+            timeout: timeout.context("missing `timeout`")?,
+        };
+
+        ensure!(out.shards > 1, "shard count must be > 1");
+        ensure!(
+            out.shards.is_power_of_two(),
+            "shard count must be a power of two"
+        );
+
+        Ok(out)
+    }
+}
+
+impl FromStr for WakeComputeLockOptions {
+    type Err = anyhow::Error;
+
+    fn from_str(options: &str) -> Result<Self, Self::Err> {
+        let error = || format!("failed to parse cache lock options '{options}'");
+        Self::parse(options).with_context(error)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -288,4 +361,42 @@ mod tests {
         Ok(())
     }
+
+    #[test]
+    fn test_parse_lock_options() -> anyhow::Result<()> {
+        let WakeComputeLockOptions {
+            epoch,
+            permits,
+            shards,
+            timeout,
+        } = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
+        assert_eq!(epoch, Duration::from_secs(10 * 60));
+        assert_eq!(timeout, Duration::from_secs(1));
+        assert_eq!(shards, 32);
+        assert_eq!(permits, 4);
+
+        let WakeComputeLockOptions {
+            epoch,
+            permits,
+            shards,
+            timeout,
+        } = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
+        assert_eq!(epoch, Duration::from_secs(60));
+        assert_eq!(timeout, Duration::from_millis(100));
+        assert_eq!(shards, 16);
+        assert_eq!(permits, 8);
+
+        let WakeComputeLockOptions {
+            epoch,
+            permits,
+            shards,
+            timeout,
+        } = "permits=0".parse()?;
+        assert_eq!(epoch, Duration::ZERO);
+        assert_eq!(timeout, Duration::ZERO);
+        assert_eq!(shards, 2);
+        assert_eq!(permits, 0);
+
+        Ok(())
+    }
 }
diff --git a/proxy/src/console.rs b/proxy/src/console.rs
index 0e5eaaf845..6da627389e 100644
--- a/proxy/src/console.rs
+++ b/proxy/src/console.rs
@@ -13,5 +13,10 @@ pub mod caches {
     pub use super::provider::{ApiCaches, NodeInfoCache};
 }
 
+/// Various lock-related types.
+pub mod locks {
+    pub use super::provider::ApiLocks;
+}
+
 /// Console's management API.
 pub mod mgmt;
diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs
index c7cfc88c75..54bcd1f081 100644
--- a/proxy/src/console/provider.rs
+++ b/proxy/src/console/provider.rs
@@ -8,7 +8,13 @@ use crate::{
     compute, scram,
 };
 use async_trait::async_trait;
-use std::sync::Arc;
+use dashmap::DashMap;
+use std::{sync::Arc, time::Duration};
+use tokio::{
+    sync::{OwnedSemaphorePermit, Semaphore},
+    time::Instant,
+};
+use tracing::info;
 
 pub mod errors {
     use crate::{
@@ -149,6 +155,9 @@ pub mod errors {
 
         #[error(transparent)]
         ApiError(ApiError),
+
+        #[error("Timeout waiting to acquire wake compute lock")]
+        TimeoutError,
     }
 
     // This allows more useful interactions than `#[from]`.
@@ -158,6 +167,17 @@ pub mod errors {
         }
     }
 
+    impl From<tokio::sync::AcquireError> for WakeComputeError {
+        fn from(_: tokio::sync::AcquireError) -> Self {
+            WakeComputeError::TimeoutError
+        }
+    }
+    impl From<tokio::time::error::Elapsed> for WakeComputeError {
+        fn from(_: tokio::time::error::Elapsed) -> Self {
+            WakeComputeError::TimeoutError
+        }
+    }
+
     impl UserFacingError for WakeComputeError {
         fn to_string_client(&self) -> String {
             use WakeComputeError::*;
@@ -167,6 +187,8 @@ pub mod errors {
                 BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
                 // However, API might return a meaningful error.
                 ApiError(e) => e.to_string_client(),
+
+                TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
             }
         }
     }
@@ -233,3 +255,145 @@ pub struct ApiCaches {
     /// Cache for the `wake_compute` API method.
     pub node_info: NodeInfoCache,
 }
+
+/// Various locks for [`console`](super).
+pub struct ApiLocks {
+    name: &'static str,
+    node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
+    permits: usize,
+    timeout: Duration,
+    registered: prometheus::IntCounter,
+    unregistered: prometheus::IntCounter,
+    reclamation_lag: prometheus::Histogram,
+    lock_acquire_lag: prometheus::Histogram,
+}
+
+impl ApiLocks {
+    pub fn new(
+        name: &'static str,
+        permits: usize,
+        shards: usize,
+        timeout: Duration,
+    ) -> prometheus::Result<Self> {
+        let registered = prometheus::IntCounter::with_opts(
+            prometheus::Opts::new(
+                "semaphores_registered",
+                "Number of semaphores registered in this api lock",
+            )
+            .namespace(name),
+        )?;
+        prometheus::register(Box::new(registered.clone()))?;
+        let unregistered = prometheus::IntCounter::with_opts(
+            prometheus::Opts::new(
+                "semaphores_unregistered",
+                "Number of semaphores unregistered in this api lock",
+            )
+            .namespace(name),
+        )?;
+        prometheus::register(Box::new(unregistered.clone()))?;
+        let reclamation_lag = prometheus::Histogram::with_opts(
+            prometheus::HistogramOpts::new(
+                "reclamation_lag_seconds",
+                "Time it takes to reclaim unused semaphores in the api lock",
+            )
+            .namespace(name)
+            // 1us -> 65ms
+            // benchmarks on my mac indicate it's usually in the range of 256us and 512us
+            .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
+        )?;
+        prometheus::register(Box::new(reclamation_lag.clone()))?;
+        let lock_acquire_lag = prometheus::Histogram::with_opts(
+            prometheus::HistogramOpts::new(
+                "semaphore_acquire_seconds",
+                "Time it takes to acquire a semaphore permit in the api lock",
+            )
+            .namespace(name)
+            // 0.1ms -> 6s
+            .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
+        )?;
+        prometheus::register(Box::new(lock_acquire_lag.clone()))?;
+
+        Ok(Self {
+            name,
+            node_locks: DashMap::with_shard_amount(shards),
+            permits,
+            timeout,
+            lock_acquire_lag,
+            registered,
+            unregistered,
+            reclamation_lag,
+        })
+    }
+
+    pub async fn get_wake_compute_permit(
+        &self,
+        key: &Arc<str>,
+    ) -> Result<WakeComputePermit, errors::WakeComputeError> {
+        if self.permits == 0 {
+            return Ok(WakeComputePermit { permit: None });
+        }
+        let now = Instant::now();
+        let semaphore = {
+            // fast path: the semaphore for this key already exists
+            if let Some(semaphore) = self.node_locks.get(key) {
+                semaphore.clone()
+            } else {
+                self.node_locks
+                    .entry(key.clone())
+                    .or_insert_with(|| {
+                        self.registered.inc();
+                        Arc::new(Semaphore::new(self.permits))
+                    })
+                    .clone()
+            }
+        };
+        let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
+
+        self.lock_acquire_lag
+            .observe((Instant::now() - now).as_secs_f64());
+
+        Ok(WakeComputePermit {
+            permit: Some(permit??),
+        })
+    }
+
+    pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
+        if self.permits == 0 {
+            return;
+        }
+
+        let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
+        loop {
+            for (i, shard) in self.node_locks.shards().iter().enumerate() {
+                interval.tick().await;
+                // temporarily lock a single shard and then clear any semaphores that aren't currently checked out
+                // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked,
+                // therefore releasing it is safe from race conditions
+                info!(
+                    name = self.name,
+                    shard = i,
+                    "performing epoch reclamation on api lock"
+                );
+                let mut lock = shard.write();
+                let timer = self.reclamation_lag.start_timer();
+                let count = lock
+                    .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
+                    .count();
+                drop(lock);
+                self.unregistered.inc_by(count as u64);
+                timer.observe_duration()
+            }
+        }
+    }
+}
+
+pub struct WakeComputePermit {
+    // None if the lock is disabled
+    permit: Option<OwnedSemaphorePermit>,
+}
+
+impl WakeComputePermit {
+    pub fn should_check_cache(&self) -> bool {
+        self.permit.is_some()
+    }
+}
diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs
index 6229840c46..0dc7c71534 100644
--- a/proxy/src/console/provider/neon.rs
+++ b/proxy/src/console/provider/neon.rs
@@ -3,12 +3,12 @@
 use super::{
     super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
     errors::{ApiError, GetAuthInfoError, WakeComputeError},
-    ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
+    ApiCaches, ApiLocks, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
 };
 use crate::{auth::ClientCredentials, compute, http, scram};
 use async_trait::async_trait;
 use futures::TryFutureExt;
-use std::net::SocketAddr;
+use std::{net::SocketAddr, sync::Arc};
 use tokio::time::Instant;
 use tokio_postgres::config::SslMode;
 use tracing::{error, info, info_span, warn, Instrument};
@@ -17,12 +17,17 @@ use tracing::{error, info, info_span, warn, Instrument};
 pub struct Api {
     endpoint: http::Endpoint,
     caches: &'static ApiCaches,
+    locks: &'static ApiLocks,
     jwt: String,
 }
 
 impl Api {
     /// Construct an API object containing the auth parameters.
-    pub fn new(endpoint: http::Endpoint, caches: &'static ApiCaches) -> Self {
+    pub fn new(
+        endpoint: http::Endpoint,
+        caches: &'static ApiCaches,
+        locks: &'static ApiLocks,
+    ) -> Self {
         let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
             Ok(v) => v,
             Err(_) => "".to_string(),
@@ -30,6 +35,7 @@ impl Api {
         Self {
             endpoint,
             caches,
+            locks,
             jwt,
         }
     }
@@ -163,9 +169,22 @@ impl super::Api for Api {
             return Ok(cached);
         }
 
+        let key: Arc<str> = key.into();
+
+        let permit = self.locks.get_wake_compute_permit(&key).await?;
+
+        // after getting back a permit - it's possible the cache was filled
+        // double check before calling the API
+        if permit.should_check_cache() {
+            if let Some(cached) = self.caches.node_info.get(&key) {
+                info!(key = &*key, "found cached compute node info");
+                return Ok(cached);
+            }
+        }
+
         let node = self.do_wake_compute(extra, creds).await?;
-        let (_, cached) = self.caches.node_info.insert(key.into(), node);
-        info!(key = key, "created a cache entry for compute node info");
+        let (_, cached) = self.caches.node_info.insert(key.clone(), node);
+        info!(key = &*key, "created a cache entry for compute node info");
 
         Ok(cached)
     }
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index 54c3503c93..a1ebf03545 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -570,6 +570,7 @@ fn report_error(e: &WakeComputeError, retry: bool) {
             "api_console_other_server_error"
         }
         WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
+        WakeComputeError::TimeoutError => "timeout_error",
     };
     NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
 }
diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml
index e2a65ad150..a088f1868b 100644
--- a/workspace_hack/Cargo.toml
+++ b/workspace_hack/Cargo.toml
@@ -25,6 +25,7 @@ chrono = { version = "0.4", default-features = false, features = ["clock", "serd
 clap = { version = "4", features = ["derive", "string"] }
 clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] }
 crossbeam-utils = { version = "0.8" }
+dashmap = { version = "5", default-features = false, features = ["raw-api"] }
 either = { version = "1" }
 fail = { version = "0.5", default-features = false, features = ["failpoints"] }
 futures = { version = "0.3" }
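
The heart of this change is a per-endpoint concurrency limit around `wake_compute`: one `tokio::sync::Semaphore` per endpoint key stored in a `DashMap`, a bounded wait for a permit, and a second cache lookup once the permit is granted. The sketch below only illustrates that pattern and is not code from this patch; the `KeyedLimiter` name and the example key are made up, and it assumes the `tokio` (with `sync`, `time`, `macros`, `rt-multi-thread` features) and `dashmap` crates.

```rust
use std::{sync::Arc, time::Duration};

use dashmap::DashMap;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

struct KeyedLimiter {
    permits: usize,
    timeout: Duration,
    locks: DashMap<Arc<str>, Arc<Semaphore>>,
}

impl KeyedLimiter {
    fn new(permits: usize, timeout: Duration) -> Self {
        Self {
            permits,
            timeout,
            locks: DashMap::new(),
        }
    }

    /// Wait for a permit for `key`, giving up after `self.timeout`.
    async fn acquire(&self, key: &Arc<str>) -> Option<OwnedSemaphorePermit> {
        // One semaphore per key, created lazily on first use.
        let semaphore = self
            .locks
            .entry(key.clone())
            .or_insert_with(|| Arc::new(Semaphore::new(self.permits)))
            .clone();
        tokio::time::timeout(self.timeout, semaphore.acquire_owned())
            .await
            .ok()? // timed out waiting for a permit
            .ok() // semaphore closed (never happens here)
    }
}

#[tokio::main]
async fn main() {
    let limiter = KeyedLimiter::new(4, Duration::from_secs(1));
    let key: Arc<str> = "endpoint-123".into();
    if let Some(_permit) = limiter.acquire(&key).await {
        // The proxy re-checks its node-info cache at this point before calling
        // the control plane: another task holding a permit may already have
        // woken the compute node and filled the cache while we were waiting.
    }
}
```

The double check after acquiring the permit is what makes the lock pay off: while a task waits, another task holding a permit may already have woken the compute node and filled the node-info cache, so the console API call can be skipped entirely.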
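The epoch reclamation in `garbage_collect_worker` leans on a single invariant: while a `DashMap` shard is write-locked, no new `Arc<Semaphore>` clone can be handed out from that shard, so an entry whose `Arc::strong_count` is 1 is not checked out by any task and can be dropped safely. A standalone illustration of that rule follows (the `reclaimable` helper is hypothetical, not part of the patch):

```rust
use std::sync::Arc;

use tokio::sync::Semaphore;

/// An entry may be reclaimed once the map holds the only remaining `Arc`,
/// i.e. no task currently has the semaphore checked out.
fn reclaimable(entry: &Arc<Semaphore>) -> bool {
    // While a shard is write-locked, no new clone can be handed out through
    // the map, so this check cannot race with a concurrent acquire.
    Arc::strong_count(entry) == 1
}

fn main() {
    let sem = Arc::new(Semaphore::new(4));
    assert!(reclaimable(&sem)); // only the map's Arc exists

    let checked_out = sem.clone();
    assert!(!reclaimable(&sem)); // a task still holds a clone

    drop(checked_out);
    assert!(reclaimable(&sem)); // safe to drop again
}
```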
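Sweep pacing follows from the `epoch` option and the shard count: the worker ticks every `epoch / shard_count` and visits one shard per tick, so each shard is reclaimed once per epoch and the cost of write-locking is spread out over time. A small sketch of that arithmetic, with example values rather than the defaults from this patch:

```rust
use std::time::Duration;

fn main() {
    // Example values: epoch=10m with 64 shards.
    let epoch = Duration::from_secs(10 * 60);
    let shards: u32 = 64;

    // One shard is swept per tick, so every shard is visited once per epoch.
    let tick = epoch / shards;
    assert_eq!(tick, Duration::from_micros(9_375_000)); // 9.375s between sweeps
}
```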