diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs
index f757a15fbb..67c4dd019e 100644
--- a/proxy/src/auth/backend.rs
+++ b/proxy/src/auth/backend.rs
@@ -717,8 +717,10 @@ mod tests {
                 _ => panic!("wrong message"),
             }
         });
-        let endpoint_rate_limiter =
-            Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
+        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
+            EndpointRateLimiter::DEFAULT,
+            64,
+        ));
 
         let _creds = auth_quirks(
             &mut ctx,
@@ -767,8 +769,10 @@ mod tests {
             frontend::password_message(b"my-secret-password", &mut write).unwrap();
             client.write_all(&write).await.unwrap();
         });
-        let endpoint_rate_limiter =
-            Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
+        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
+            EndpointRateLimiter::DEFAULT,
+            64,
+        ));
 
         let _creds = auth_quirks(
             &mut ctx,
@@ -818,8 +822,10 @@ mod tests {
            client.write_all(&write).await.unwrap();
        });
 
-        let endpoint_rate_limiter =
-            Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
+        let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
+            EndpointRateLimiter::DEFAULT,
+            64,
+        ));
 
         let creds = auth_quirks(
             &mut ctx,
diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs
index 7314710508..c1fd6dfd80 100644
--- a/proxy/src/bin/proxy.rs
+++ b/proxy/src/bin/proxy.rs
@@ -22,7 +22,9 @@ use proxy::http;
 use proxy::http::health_server::AppMetrics;
 use proxy::metrics::Metrics;
 use proxy::rate_limiter::EndpointRateLimiter;
+use proxy::rate_limiter::LeakyBucketConfig;
 use proxy::rate_limiter::RateBucketInfo;
+use proxy::rate_limiter::WakeComputeRateLimiter;
 use proxy::redis::cancellation_publisher::RedisPublisherClient;
 use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
 use proxy::redis::elasticache;
@@ -390,9 +392,24 @@ async fn main() -> anyhow::Result<()> {
         proxy::metrics::CancellationSource::FromClient,
     ));
 
-    let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
-    RateBucketInfo::validate(&mut endpoint_rps_limit)?;
-    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
+    // bit of a hack - find the min rps and max rps supported and turn it into
+    // leaky bucket config instead
+    let max = args
+        .endpoint_rps_limit
+        .iter()
+        .map(|x| x.rps())
+        .max_by(f64::total_cmp)
+        .unwrap_or(EndpointRateLimiter::DEFAULT.max);
+    let rps = args
+        .endpoint_rps_limit
+        .iter()
+        .map(|x| x.rps())
+        .min_by(f64::total_cmp)
+        .unwrap_or(EndpointRateLimiter::DEFAULT.rps);
+    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
+        LeakyBucketConfig { rps, max },
+        64,
+    ));
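+    // with this mapping, bursts up to the largest configured bucket's rate are
+    // admitted, sustained load drains at the smallest bucket's rate, and any
+    // buckets configured in between no longer take effect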
 
     // client facing tasks. these will exit on error or on cancellation
     // cancellation returns Ok(())
@@ -594,7 +611,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
     RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
     let wake_compute_endpoint_rate_limiter =
-        Arc::new(EndpointRateLimiter::new(wake_compute_rps_limit));
+        Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
     let api = console::provider::neon::Api::new(
         endpoint,
         caches,
diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs
index a6e67be22f..768cd2fdfa 100644
--- a/proxy/src/console/provider/neon.rs
+++ b/proxy/src/console/provider/neon.rs
@@ -12,7 +12,7 @@ use crate::{
     console::messages::{ColdStartInfo, Reason},
     http,
     metrics::{CacheOutcome, Metrics},
-    rate_limiter::EndpointRateLimiter,
+    rate_limiter::WakeComputeRateLimiter,
     scram, EndpointCacheKey,
 };
 use crate::{cache::Cached, context::RequestMonitoring};
@@ -26,7 +26,7 @@ pub struct Api {
     endpoint: http::Endpoint,
     pub caches: &'static ApiCaches,
     pub locks: &'static ApiLocks,
-    pub wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
+    pub wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
     jwt: String,
 }
 
@@ -36,7 +36,7 @@ impl Api {
         endpoint: http::Endpoint,
         caches: &'static ApiCaches,
         locks: &'static ApiLocks,
-        wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
+        wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
     ) -> Self {
         let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
             Ok(v) => v,
diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs
index be9072dd8c..222cd431d2 100644
--- a/proxy/src/rate_limiter.rs
+++ b/proxy/src/rate_limiter.rs
@@ -3,4 +3,8 @@ mod limiter;
 pub use limit_algorithm::{
     aimd::Aimd, DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
 };
-pub use limiter::{BucketRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo};
+pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
+mod leaky_bucket;
+pub use leaky_bucket::{
+    EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter, LeakyBucketState,
+};
diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs
new file mode 100644
index 0000000000..2d5e056540
--- /dev/null
+++ b/proxy/src/rate_limiter/leaky_bucket.rs
@@ -0,0 +1,171 @@
+use std::{
+    hash::Hash,
+    sync::atomic::{AtomicUsize, Ordering},
+};
+
+use ahash::RandomState;
+use dashmap::DashMap;
+use rand::{thread_rng, Rng};
+use tokio::time::Instant;
+use tracing::info;
+
+use crate::intern::EndpointIdInt;
+
+// Simple per-endpoint rate limiter.
+pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
+
+pub struct LeakyBucketRateLimiter<K> {
+    map: DashMap<K, LeakyBucketState, RandomState>,
+    config: LeakyBucketConfig,
+    access_count: AtomicUsize,
+}
+
+impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
+    pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
+        rps: 600.0,
+        max: 1500.0,
+    };
+
+    pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
+        Self {
+            map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
+            config,
+            access_count: AtomicUsize::new(0),
+        }
+    }
+
+    /// Check that the number of requests from this endpoint stays within its rate limit.
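+    /// Leaky-bucket semantics: each allowed call adds `n` tokens to the bucket,
+    /// which drains at `config.rps` tokens per second and holds at most
+    /// `config.max` tokens, so short bursts up to `max` pass while sustained
+    /// throughput is capped at `rps`.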
+    pub fn check(&self, key: K, n: u32) -> bool {
+        let now = Instant::now();
+
+        if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
+            self.do_gc(now);
+        }
+
+        let mut entry = self.map.entry(key).or_insert_with(|| LeakyBucketState {
+            time: now,
+            filled: 0.0,
+        });
+
+        entry.check(&self.config, now, n as f64)
+    }
+
+    fn do_gc(&self, now: Instant) {
+        info!(
+            "cleaning up bucket rate limiter, current size = {}",
+            self.map.len()
+        );
+        let n = self.map.shards().len();
+        let shard = thread_rng().gen_range(0..n);
+        self.map.shards()[shard]
+            .write()
+            .retain(|_, value| !value.get_mut().update(&self.config, now));
+    }
+}
+
+pub struct LeakyBucketConfig {
+    pub rps: f64,
+    pub max: f64,
+}
+
+pub struct LeakyBucketState {
+    filled: f64,
+    time: Instant,
+}
+
+impl LeakyBucketConfig {
+    pub fn new(rps: f64, max: f64) -> Self {
+        assert!(rps > 0.0, "rps must be positive");
+        assert!(max > 0.0, "max must be positive");
+        Self { rps, max }
+    }
+}
+
+impl LeakyBucketState {
+    pub fn new() -> Self {
+        Self {
+            filled: 0.0,
+            time: Instant::now(),
+        }
+    }
+
+    /// Updates the timer and returns true if the bucket is empty.
+    fn update(&mut self, info: &LeakyBucketConfig, now: Instant) -> bool {
+        let drain = now.duration_since(self.time);
+        let drain = drain.as_secs_f64() * info.rps;
+
+        self.filled = (self.filled - drain).clamp(0.0, info.max);
+        self.time = now;
+
+        self.filled == 0.0
+    }
+
+    pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool {
+        self.update(info, now);
+
+        if self.filled + n > info.max {
+            return false;
+        }
+        self.filled += n;
+
+        true
+    }
+}
+
+impl Default for LeakyBucketState {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use tokio::time::Instant;
+
+    use super::{LeakyBucketConfig, LeakyBucketState};
+
+    #[tokio::test(start_paused = true)]
+    async fn check() {
+        let info = LeakyBucketConfig::new(500.0, 2000.0);
+        let mut bucket = LeakyBucketState::new();
+
+        // should work for 2000 requests this second
+        for _ in 0..2000 {
+            assert!(bucket.check(&info, Instant::now(), 1.0));
+        }
+        assert!(!bucket.check(&info, Instant::now(), 1.0));
+        assert_eq!(bucket.filled, 2000.0);
+
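+        // a failed check still runs update(), so the fractional drain below is
+        // kept between calls (e.g. filled == 1999.5 after the first 1ms advance)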
+        // in 1ms we should drain 0.5 tokens.
+        // make sure we don't lose any tokens
+        tokio::time::advance(Duration::from_millis(1)).await;
+        assert!(!bucket.check(&info, Instant::now(), 1.0));
+        tokio::time::advance(Duration::from_millis(1)).await;
+        assert!(bucket.check(&info, Instant::now(), 1.0));
+
+        // in 10ms we should drain 5 tokens
+        tokio::time::advance(Duration::from_millis(10)).await;
+        for _ in 0..5 {
+            assert!(bucket.check(&info, Instant::now(), 1.0));
+        }
+        assert!(!bucket.check(&info, Instant::now(), 1.0));
+
+        // in 10s we should drain 5000 tokens
+        // but cap is only 2000
+        tokio::time::advance(Duration::from_secs(10)).await;
+        for _ in 0..2000 {
+            assert!(bucket.check(&info, Instant::now(), 1.0));
+        }
+        assert!(!bucket.check(&info, Instant::now(), 1.0));
+
+        // should sustain 500rps
+        for _ in 0..2000 {
+            tokio::time::advance(Duration::from_millis(10)).await;
+            for _ in 0..5 {
+                assert!(bucket.check(&info, Instant::now(), 1.0));
+            }
+        }
+    }
+}
diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs
index b8c9490696..5db4efed37 100644
--- a/proxy/src/rate_limiter/limiter.rs
+++ b/proxy/src/rate_limiter/limiter.rs
@@ -61,7 +61,7 @@ impl GlobalRateLimiter {
 // Purposefully ignore user name and database name as clients can reconnect
 // with different names, so we'll end up sending some http requests to
 // the control plane.
-pub type EndpointRateLimiter = BucketRateLimiter<EndpointIdInt>;
+pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt>;
 
 pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
     map: DashMap<Key, Vec<RateBucket>, Hasher>,
@@ -103,7 +103,7 @@ pub struct RateBucketInfo {
 
 impl std::fmt::Display for RateBucketInfo {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let rps = (self.max_rpi as u64) * 1000 / self.interval.as_millis() as u64;
+        let rps = self.rps().floor() as u64;
         write!(f, "{rps}@{}", humantime::format_duration(self.interval))
     }
 }
@@ -140,6 +140,10 @@
         Self::new(200, Duration::from_secs(600)),
     ];
 
+    pub fn rps(&self) -> f64 {
+        (self.max_rpi as f64) / self.interval.as_secs_f64()
+    }
+
     pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
         info.sort_unstable_by_key(|info| info.interval);
         let invalid = info
@@ -245,7 +249,7 @@ mod tests {
     use rustc_hash::FxHasher;
     use tokio::time;
 
-    use super::{BucketRateLimiter, EndpointRateLimiter};
+    use super::{BucketRateLimiter, WakeComputeRateLimiter};
     use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId};
 
     #[test]
@@ -293,7 +297,7 @@ mod tests {
             .map(|s| s.parse().unwrap())
             .collect();
         RateBucketInfo::validate(&mut rates).unwrap();
-        let limiter = EndpointRateLimiter::new(rates);
+        let limiter = WakeComputeRateLimiter::new(rates);
 
         let endpoint = EndpointId::from("ep-my-endpoint-1234");
         let endpoint = EndpointIdInt::from(endpoint);