mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-05 20:42:54 +00:00
proxy: limit concurrent wake_compute requests per endpoint (#5799)
## Problem A user can perform many database connections at the same instant of time - these will all cache miss and materialise as requests to the control plane. #5705 ## Summary of changes I am using a `DashMap` (a sharded `RwLock<HashMap>`) of endpoints -> semaphores to apply a limiter. If the limiter is enabled (permits > 0), the semaphore will be retrieved per endpoint and a permit will be awaited before continuing to call the wake_compute endpoint. ### Important details This dashmap would grow uncontrollably without maintenance. It's not a cache so I don't think an LRU-based reclamation makes sense. Instead, I've made use of the sharding functionality of DashMap to lock a single shard and clear out unused semaphores periodically. I ran a test in release, using 128 tokio tasks among 12 threads each pushing 1000 entries into the map per second, clearing a shard every 2 seconds (64 second epoch with 32 shards). The endpoint names were sampled from a gamma distribution to make sure some overlap would occur, and each permit was held for 1ms. The histogram for time to clear each shard settled between 256-512us without any variance in my testing. Holding a lock for under a millisecond for 1 of the shards does not concern me as blocking
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -6483,6 +6483,7 @@ dependencies = [
|
||||
"clap",
|
||||
"clap_builder",
|
||||
"crossbeam-utils",
|
||||
"dashmap",
|
||||
"either",
|
||||
"fail",
|
||||
"futures",
|
||||
|
||||
@@ -67,7 +67,7 @@ comfy-table = "6.1"
|
||||
const_format = "0.2"
|
||||
crc32c = "0.6"
|
||||
crossbeam-utils = "0.8.5"
|
||||
dashmap = "5.5.0"
|
||||
dashmap = { version = "5.5.0", features = ["raw-api"] }
|
||||
either = "1.8"
|
||||
enum-map = "2.4.2"
|
||||
enumset = "1.0.12"
|
||||
|
||||
@@ -80,6 +80,9 @@ struct ProxyCliArgs {
|
||||
/// cache for `wake_compute` api method (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::DEFAULT_OPTIONS_NODE_INFO)]
|
||||
wake_compute_cache: String,
|
||||
/// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
|
||||
#[clap(long, default_value = config::WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
|
||||
wake_compute_lock: String,
|
||||
/// Allow self-signed certificates for compute nodes (for testing)
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
allow_self_signed_compute: bool,
|
||||
@@ -220,10 +223,23 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
node_info: console::caches::NodeInfoCache::new("node_info_cache", size, ttl),
|
||||
}));
|
||||
|
||||
let config::WakeComputeLockOptions {
|
||||
shards,
|
||||
permits,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.wake_compute_lock.parse()?;
|
||||
info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
|
||||
let locks = Box::leak(Box::new(
|
||||
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
|
||||
.unwrap(),
|
||||
));
|
||||
tokio::spawn(locks.garbage_collect_worker(epoch));
|
||||
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let endpoint = http::Endpoint::new(url, http::new_client());
|
||||
|
||||
let api = console::provider::neon::Api::new(endpoint, caches);
|
||||
let api = console::provider::neon::Api::new(endpoint, caches, locks);
|
||||
auth::BackendType::Console(Cow::Owned(api), ())
|
||||
}
|
||||
AuthBackend::Postgres => {
|
||||
|
||||
@@ -264,6 +264,79 @@ impl FromStr for CacheOptions {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for cmdline `wake_compute` lock options parsing.
///
/// Built from a comma-separated `key=value` string such as
/// `shards=32,permits=4,epoch=10m,timeout=1s`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WakeComputeLockOptions {
    /// The number of shards the lock map should have
    pub shards: usize,
    /// The number of allowed concurrent requests for each endpoint
    /// (`0` disables the lock)
    pub permits: usize,
    /// Garbage collection epoch
    pub epoch: Duration,
    /// Lock timeout
    pub timeout: Duration,
}
|
||||
|
||||
impl WakeComputeLockOptions {
|
||||
/// Default options for [`crate::console::provider::ApiLocks`].
|
||||
pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
|
||||
|
||||
// pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
|
||||
|
||||
/// Parse lock options passed via cmdline.
|
||||
/// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let mut shards = None;
|
||||
let mut permits = None;
|
||||
let mut epoch = None;
|
||||
let mut timeout = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"shards" => shards = Some(value.parse()?),
|
||||
"permits" => permits = Some(value.parse()?),
|
||||
"epoch" => epoch = Some(humantime::parse_duration(value)?),
|
||||
"timeout" => timeout = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
// these dont matter if lock is disabled
|
||||
if let Some(0) = permits {
|
||||
timeout = Some(Duration::default());
|
||||
epoch = Some(Duration::default());
|
||||
shards = Some(2);
|
||||
}
|
||||
|
||||
let out = Self {
|
||||
shards: shards.context("missing `shards`")?,
|
||||
permits: permits.context("missing `permits`")?,
|
||||
epoch: epoch.context("missing `epoch`")?,
|
||||
timeout: timeout.context("missing `timeout`")?,
|
||||
};
|
||||
|
||||
ensure!(out.shards > 1, "shard count must be > 1");
|
||||
ensure!(
|
||||
out.shards.is_power_of_two(),
|
||||
"shard count must be a power of two"
|
||||
);
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for WakeComputeLockOptions {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(options: &str) -> Result<Self, Self::Err> {
|
||||
let error = || format!("failed to parse cache lock options '{options}'");
|
||||
Self::parse(options).with_context(error)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -288,4 +361,42 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
fn test_parse_lock_options() -> anyhow::Result<()> {
    // All four keys, in the order used by the documented example.
    let opts: WakeComputeLockOptions = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
    assert_eq!(opts.epoch, Duration::from_secs(10 * 60));
    assert_eq!(opts.timeout, Duration::from_secs(1));
    assert_eq!(opts.shards, 32);
    assert_eq!(opts.permits, 4);

    // Key order must not matter.
    let opts: WakeComputeLockOptions = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
    assert_eq!(opts.epoch, Duration::from_secs(60));
    assert_eq!(opts.timeout, Duration::from_millis(100));
    assert_eq!(opts.shards, 16);
    assert_eq!(opts.permits, 8);

    // A disabled lock needs only `permits=0`; the rest gets placeholder values.
    let opts: WakeComputeLockOptions = "permits=0".parse()?;
    assert_eq!(opts.epoch, Duration::ZERO);
    assert_eq!(opts.timeout, Duration::ZERO);
    assert_eq!(opts.shards, 2);
    assert_eq!(opts.permits, 0);

    Ok(())
}
|
||||
}
|
||||
|
||||
@@ -13,5 +13,10 @@ pub mod caches {
|
||||
pub use super::provider::{ApiCaches, NodeInfoCache};
|
||||
}
|
||||
|
||||
/// Various cache-related types.
|
||||
pub mod locks {
|
||||
pub use super::provider::ApiLocks;
|
||||
}
|
||||
|
||||
/// Console's management API.
|
||||
pub mod mgmt;
|
||||
|
||||
@@ -8,7 +8,13 @@ use crate::{
|
||||
compute, scram,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use dashmap::DashMap;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
sync::{OwnedSemaphorePermit, Semaphore},
|
||||
time::Instant,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
pub mod errors {
|
||||
use crate::{
|
||||
@@ -149,6 +155,9 @@ pub mod errors {
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
|
||||
#[error("Timeout waiting to acquire wake compute lock")]
|
||||
TimeoutError,
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
@@ -158,6 +167,17 @@ pub mod errors {
|
||||
}
|
||||
}
|
||||
|
||||
// A failed semaphore acquisition is reported to callers as
// `WakeComputeError::TimeoutError`; the original error carries no payload.
impl From<tokio::sync::AcquireError> for WakeComputeError {
    fn from(_: tokio::sync::AcquireError) -> Self {
        Self::TimeoutError
    }
}
|
||||
impl From<tokio::time::error::Elapsed> for WakeComputeError {
|
||||
fn from(_: tokio::time::error::Elapsed) -> Self {
|
||||
WakeComputeError::TimeoutError
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use WakeComputeError::*;
|
||||
@@ -167,6 +187,8 @@ pub mod errors {
|
||||
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
|
||||
TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,3 +255,145 @@ pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub node_info: NodeInfoCache,
|
||||
}
|
||||
|
||||
/// Various caches for [`console`](super).
|
||||
pub struct ApiLocks {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
|
||||
permits: usize,
|
||||
timeout: Duration,
|
||||
registered: prometheus::IntCounter,
|
||||
unregistered: prometheus::IntCounter,
|
||||
reclamation_lag: prometheus::Histogram,
|
||||
lock_acquire_lag: prometheus::Histogram,
|
||||
}
|
||||
|
||||
impl ApiLocks {
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
permits: usize,
|
||||
shards: usize,
|
||||
timeout: Duration,
|
||||
) -> prometheus::Result<Self> {
|
||||
let registered = prometheus::IntCounter::with_opts(
|
||||
prometheus::Opts::new(
|
||||
"semaphores_registered",
|
||||
"Number of semaphores registered in this api lock",
|
||||
)
|
||||
.namespace(name),
|
||||
)?;
|
||||
prometheus::register(Box::new(registered.clone()))?;
|
||||
let unregistered = prometheus::IntCounter::with_opts(
|
||||
prometheus::Opts::new(
|
||||
"semaphores_unregistered",
|
||||
"Number of semaphores unregistered in this api lock",
|
||||
)
|
||||
.namespace(name),
|
||||
)?;
|
||||
prometheus::register(Box::new(unregistered.clone()))?;
|
||||
let reclamation_lag = prometheus::Histogram::with_opts(
|
||||
prometheus::HistogramOpts::new(
|
||||
"reclamation_lag_seconds",
|
||||
"Time it takes to reclaim unused semaphores in the api lock",
|
||||
)
|
||||
.namespace(name)
|
||||
// 1us -> 65ms
|
||||
// benchmarks on my mac indicate it's usually in the range of 256us and 512us
|
||||
.buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
|
||||
)?;
|
||||
prometheus::register(Box::new(reclamation_lag.clone()))?;
|
||||
let lock_acquire_lag = prometheus::Histogram::with_opts(
|
||||
prometheus::HistogramOpts::new(
|
||||
"semaphore_acquire_seconds",
|
||||
"Time it takes to reclaim unused semaphores in the api lock",
|
||||
)
|
||||
.namespace(name)
|
||||
// 0.1ms -> 6s
|
||||
.buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
|
||||
)?;
|
||||
prometheus::register(Box::new(lock_acquire_lag.clone()))?;
|
||||
|
||||
Ok(Self {
|
||||
name,
|
||||
node_locks: DashMap::with_shard_amount(shards),
|
||||
permits,
|
||||
timeout,
|
||||
lock_acquire_lag,
|
||||
registered,
|
||||
unregistered,
|
||||
reclamation_lag,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_wake_compute_permit(
|
||||
&self,
|
||||
key: &Arc<str>,
|
||||
) -> Result<WakeComputePermit, errors::WakeComputeError> {
|
||||
if self.permits == 0 {
|
||||
return Ok(WakeComputePermit { permit: None });
|
||||
}
|
||||
let now = Instant::now();
|
||||
let semaphore = {
|
||||
// get fast path
|
||||
if let Some(semaphore) = self.node_locks.get(key) {
|
||||
semaphore.clone()
|
||||
} else {
|
||||
self.node_locks
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| {
|
||||
self.registered.inc();
|
||||
Arc::new(Semaphore::new(self.permits))
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
};
|
||||
let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
|
||||
|
||||
self.lock_acquire_lag
|
||||
.observe((Instant::now() - now).as_secs_f64());
|
||||
|
||||
Ok(WakeComputePermit {
|
||||
permit: Some(permit??),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
|
||||
if self.permits == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
|
||||
loop {
|
||||
for (i, shard) in self.node_locks.shards().iter().enumerate() {
|
||||
interval.tick().await;
|
||||
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
|
||||
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
|
||||
// therefore releasing it is safe from race conditions
|
||||
info!(
|
||||
name = self.name,
|
||||
shard = i,
|
||||
"performing epoch reclamation on api lock"
|
||||
);
|
||||
let mut lock = shard.write();
|
||||
let timer = self.reclamation_lag.start_timer();
|
||||
let count = lock
|
||||
.extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
|
||||
.count();
|
||||
drop(lock);
|
||||
self.unregistered.inc_by(count as u64);
|
||||
timer.observe_duration()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WakeComputePermit {
|
||||
// None if the lock is disabled
|
||||
permit: Option<OwnedSemaphorePermit>,
|
||||
}
|
||||
|
||||
impl WakeComputePermit {
|
||||
pub fn should_check_cache(&self) -> bool {
|
||||
self.permit.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
use super::{
|
||||
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
ApiCaches, ApiLocks, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::ClientCredentials, compute, http, scram};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use std::net::SocketAddr;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
@@ -17,12 +17,17 @@ use tracing::{error, info, info_span, warn, Instrument};
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
locks: &'static ApiLocks,
|
||||
jwt: String,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub fn new(endpoint: http::Endpoint, caches: &'static ApiCaches) -> Self {
|
||||
pub fn new(
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
locks: &'static ApiLocks,
|
||||
) -> Self {
|
||||
let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
|
||||
Ok(v) => v,
|
||||
Err(_) => "".to_string(),
|
||||
@@ -30,6 +35,7 @@ impl Api {
|
||||
Self {
|
||||
endpoint,
|
||||
caches,
|
||||
locks,
|
||||
jwt,
|
||||
}
|
||||
}
|
||||
@@ -163,9 +169,22 @@ impl super::Api for Api {
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let key: Arc<str> = key.into();
|
||||
|
||||
let permit = self.locks.get_wake_compute_permit(&key).await?;
|
||||
|
||||
// after getting back a permit - it's possible the cache was filled
|
||||
// double check
|
||||
if permit.should_check_cache() {
|
||||
if let Some(cached) = self.caches.node_info.get(&key) {
|
||||
info!(key = &*key, "found cached compute node info");
|
||||
return Ok(cached);
|
||||
}
|
||||
}
|
||||
|
||||
let node = self.do_wake_compute(extra, creds).await?;
|
||||
let (_, cached) = self.caches.node_info.insert(key.into(), node);
|
||||
info!(key = key, "created a cache entry for compute node info");
|
||||
let (_, cached) = self.caches.node_info.insert(key.clone(), node);
|
||||
info!(key = &*key, "created a cache entry for compute node info");
|
||||
|
||||
Ok(cached)
|
||||
}
|
||||
|
||||
@@ -570,6 +570,7 @@ fn report_error(e: &WakeComputeError, retry: bool) {
|
||||
"api_console_other_server_error"
|
||||
}
|
||||
WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
|
||||
WakeComputeError::TimeoutError => "timeout_error",
|
||||
};
|
||||
NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ chrono = { version = "0.4", default-features = false, features = ["clock", "serd
|
||||
clap = { version = "4", features = ["derive", "string"] }
|
||||
clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] }
|
||||
crossbeam-utils = { version = "0.8" }
|
||||
dashmap = { version = "5", default-features = false, features = ["raw-api"] }
|
||||
either = { version = "1" }
|
||||
fail = { version = "0.5", default-features = false, features = ["failpoints"] }
|
||||
futures = { version = "0.3" }
|
||||
|
||||
Reference in New Issue
Block a user