lock free semaphore map

This commit is contained in:
Conrad Ludgate
2025-07-30 17:58:56 +01:00
parent eb2741758b
commit abdee0524e
6 changed files with 60 additions and 86 deletions

View File

@@ -225,21 +225,14 @@ pub async fn run() -> anyhow::Result<()> {
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.connect_compute_lock.parse()?;
info!(
?limiter,
shards,
?epoch,
"Using NodeLocks (connect_compute)"
);
info!(?limiter, ?epoch, "Using NodeLocks (connect_compute)");
let connect_compute_locks = ApiLocks::new(
"connect_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,

View File

@@ -658,21 +658,14 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
};
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.connect_compute_lock.parse()?;
info!(
?limiter,
shards,
?epoch,
"Using NodeLocks (connect_compute)"
);
info!(?limiter, ?epoch, "Using NodeLocks (connect_compute)");
let connect_compute_locks = control_plane::locks::ApiLocks::new(
"connect_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,
@@ -796,16 +789,14 @@ fn build_auth_backend(
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
info!(?limiter, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
@@ -874,16 +865,14 @@ fn build_auth_backend(
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
info!(?limiter, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,

View File

@@ -290,8 +290,6 @@ impl RetryConfig {
/// Helper for cmdline concurrency lock options parsing.
#[derive(serde::Deserialize)]
pub struct ConcurrencyLockOptions {
/// The number of shards the lock map should have
pub shards: usize,
/// The number of allowed concurrent requests for each endpoint
#[serde(flatten)]
pub limiter: RateLimiterConfig,
@@ -308,7 +306,7 @@ impl ConcurrencyLockOptions {
pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
/// Default options for [`crate::control_plane::client::ApiLocks`].
pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
"shards=64,permits=100,epoch=10m,timeout=10ms";
"permits=100,epoch=1m,timeout=10ms";
// pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "shards=32,permits=4,epoch=10m,timeout=1s";
@@ -320,7 +318,6 @@ impl ConcurrencyLockOptions {
return Ok(serde_json::from_str(options)?);
}
let mut shards = None;
let mut permits = None;
let mut epoch = None;
let mut timeout = None;
@@ -331,7 +328,8 @@ impl ConcurrencyLockOptions {
.with_context(|| format!("bad key-value pair: {option}"))?;
match key {
"shards" => shards = Some(value.parse()?),
// removed
"shards" => {}
"permits" => permits = Some(value.parse()?),
"epoch" => epoch = Some(humantime::parse_duration(value)?),
"timeout" => timeout = Some(humantime::parse_duration(value)?),
@@ -343,12 +341,10 @@ impl ConcurrencyLockOptions {
if let Some(0) = permits {
timeout = Some(Duration::default());
epoch = Some(Duration::default());
shards = Some(2);
}
let permits = permits.context("missing `permits`")?;
let out = Self {
shards: shards.context("missing `shards`")?,
limiter: RateLimiterConfig {
algorithm: RateLimitAlgorithm::Fixed,
initial_limit: permits,
@@ -357,12 +353,6 @@ impl ConcurrencyLockOptions {
timeout: timeout.context("missing `timeout`")?,
};
ensure!(out.shards > 1, "shard count must be > 1");
ensure!(
out.shards.is_power_of_two(),
"shard count must be a power of two"
);
Ok(out)
}
}
@@ -552,36 +542,30 @@ mod tests {
let ConcurrencyLockOptions {
epoch,
limiter,
shards,
timeout,
} = "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
assert_eq!(epoch, Duration::from_secs(10 * 60));
assert_eq!(timeout, Duration::from_secs(1));
assert_eq!(shards, 32);
assert_eq!(limiter.initial_limit, 4);
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
let ConcurrencyLockOptions {
epoch,
limiter,
shards,
timeout,
} = "epoch=60s,shards=16,timeout=100ms,permits=8".parse()?;
assert_eq!(epoch, Duration::from_secs(60));
assert_eq!(timeout, Duration::from_millis(100));
assert_eq!(shards, 16);
assert_eq!(limiter.initial_limit, 8);
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
let ConcurrencyLockOptions {
epoch,
limiter,
shards,
timeout,
} = "permits=0".parse()?;
assert_eq!(epoch, Duration::ZERO);
assert_eq!(timeout, Duration::ZERO);
assert_eq!(shards, 2);
assert_eq!(limiter.initial_limit, 0);
assert_eq!(limiter.algorithm, RateLimitAlgorithm::Fixed);
@@ -593,13 +577,11 @@ mod tests {
let ConcurrencyLockOptions {
epoch,
limiter,
shards,
timeout,
} = r#"{"shards":32,"initial_limit":44,"aimd":{"min":5,"max":500,"inc":10,"dec":0.9,"utilisation":0.8},"epoch":"10m","timeout":"1s"}"#
.parse()?;
assert_eq!(epoch, Duration::from_secs(10 * 60));
assert_eq!(timeout, Duration::from_secs(1));
assert_eq!(shards, 32);
assert_eq!(limiter.initial_limit, 44);
assert_eq!(
limiter.algorithm,

View File

@@ -6,7 +6,6 @@ use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use clashmap::ClashMap;
use tokio::time::Instant;
use tracing::{debug, info};
@@ -138,7 +137,7 @@ impl ApiCaches {
/// Per-key concurrency locks for [`control_plane`](super).
pub struct ApiLocks<K> {
name: &'static str,
node_locks: ClashMap<K, Arc<DynamicLimiter>>,
node_locks: papaya::HashMap<K, Arc<DynamicLimiter>>,
config: RateLimiterConfig,
timeout: Duration,
epoch: std::time::Duration,
@@ -163,14 +162,13 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
pub fn new(
name: &'static str,
config: RateLimiterConfig,
shards: usize,
timeout: Duration,
epoch: std::time::Duration,
metrics: &'static ApiLockMetrics,
) -> Self {
Self {
name,
node_locks: ClashMap::with_shard_amount(shards),
node_locks: papaya::HashMap::new(),
config,
timeout,
epoch,
@@ -184,21 +182,17 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
permit: Token::disabled(),
});
}
let now = Instant::now();
let semaphore = {
// get fast path
if let Some(semaphore) = self.node_locks.get(key) {
semaphore.clone()
} else {
self.node_locks
.entry(key.clone())
.or_insert_with(|| {
self.metrics.semaphores_registered.inc();
DynamicLimiter::new(self.config)
})
.clone()
}
};
let semaphore = self
.node_locks
.pin()
.get_or_insert_with(key.clone(), || {
self.metrics.semaphores_registered.inc();
DynamicLimiter::new(self.config)
})
.clone();
let permit = semaphore.acquire_timeout(self.timeout).await;
self.metrics
@@ -217,28 +211,28 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
if self.config.initial_limit == 0 {
return;
}
let mut interval =
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
let mut interval = tokio::time::interval(self.epoch);
loop {
for (i, shard) in self.node_locks.shards().iter().enumerate() {
interval.tick().await;
// temporarily lock a single shard and then clear any semaphores that aren't currently checked out
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
// therefore releasing it is safe from race conditions
info!(
name = self.name,
shard = i,
"performing epoch reclamation on api lock"
);
let mut lock = shard.write();
let timer = self.metrics.reclamation_lag_seconds.start_timer();
let count = lock
.extract_if(|(_, semaphore)| Arc::strong_count(semaphore) == 1)
.count();
drop(lock);
self.metrics.semaphores_unregistered.inc_by(count as u64);
timer.observe();
interval.tick().await;
info!(name = self.name, "performing epoch reclamation on api lock");
let timer = self.metrics.reclamation_lag_seconds.start_timer();
let mut count = 0;
let guard = self.node_locks.pin();
for (key, sem) in &guard {
// check if we might be able to remove
if Arc::strong_count(sem) == 1 {
// try and atomically remove
let res = guard.remove_if(key, |_key, sem| Arc::strong_count(sem) == 1);
if let Ok(Some(..)) = res {
count += 1;
}
}
}
drop(guard);
timer.observe();
self.metrics.semaphores_unregistered.inc_by(count as u64);
}
}
}

View File

@@ -160,25 +160,27 @@ impl DynamicLimiter {
/// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
pub(crate) async fn acquire_timeout(
self: &Arc<Self>,
self: Arc<Self>,
duration: Duration,
) -> Result<Token, Elapsed> {
tokio::time::timeout(duration, self.acquire()).await?
}
/// Try to acquire a concurrency [Token].
async fn acquire(self: &Arc<Self>) -> Result<Token, Elapsed> {
async fn acquire(self: Arc<Self>) -> Result<Token, Elapsed> {
if self.config.initial_limit == 0 {
// If the rate limiter is disabled, we can always acquire a token.
Ok(Token::disabled())
} else {
return Ok(Token::disabled());
}
{
let mut notified = pin!(self.ready.notified());
let mut ready = notified.as_mut().enable();
loop {
if ready {
let mut inner = self.inner.lock();
if inner.take(&self.ready).is_some() {
break Ok(Token::new(self.clone()));
break;
}
notified.set(self.ready.notified());
}
@@ -186,6 +188,8 @@ impl DynamicLimiter {
ready = true;
}
}
Ok(Token::new(self))
}
/// Return the concurrency [Token], along with the outcome of the job.

View File

@@ -89,6 +89,7 @@ mod tests {
let limiter = DynamicLimiter::new(config);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
@@ -97,6 +98,7 @@ mod tests {
assert_eq!(limiter.state().limit(), 2);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
@@ -104,6 +106,7 @@ mod tests {
assert_eq!(limiter.state().limit(), 2);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
@@ -111,6 +114,7 @@ mod tests {
assert_eq!(limiter.state().limit(), 1);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
@@ -136,6 +140,7 @@ mod tests {
let limiter = DynamicLimiter::new(config);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(100))
.await
.unwrap();
@@ -162,11 +167,13 @@ mod tests {
let limiter = DynamicLimiter::new(config);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
let now = tokio::time::Instant::now();
limiter
.clone()
.acquire_timeout(Duration::from_secs(1))
.await
.err()
@@ -197,14 +204,17 @@ mod tests {
let limiter = DynamicLimiter::new(config);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
let _token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
let _token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
@@ -231,6 +241,7 @@ mod tests {
let limiter = DynamicLimiter::new(config);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();
@@ -261,6 +272,7 @@ mod tests {
let limiter = DynamicLimiter::new(config);
let token = limiter
.clone()
.acquire_timeout(Duration::from_millis(1))
.await
.unwrap();