diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index 222cd431d2..97752062ee 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -5,6 +5,4 @@ pub use limit_algorithm::{ }; pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; mod leaky_bucket; -pub use leaky_bucket::{ - EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter, LeakyBucketState, -}; +pub use leaky_bucket::{EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter}; diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 2d5e056540..91d6967767 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -1,6 +1,7 @@ use std::{ hash::Hash, sync::atomic::{AtomicUsize, Ordering}, + time::Duration, }; use ahash::RandomState; @@ -16,7 +17,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter; pub struct LeakyBucketRateLimiter { map: DashMap, - config: LeakyBucketConfig, + config: LeakyBucketConfigInner, access_count: AtomicUsize, } @@ -29,7 +30,7 @@ impl LeakyBucketRateLimiter { pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self { Self { map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards), - config, + config: config.into(), access_count: AtomicUsize::new(0), } } @@ -42,10 +43,10 @@ impl LeakyBucketRateLimiter { self.do_gc(now); } - let mut entry = self.map.entry(key).or_insert_with(|| LeakyBucketState { - time: now, - filled: 0.0, - }); + let mut entry = self + .map + .entry(key) + .or_insert_with(|| LeakyBucketState::new(now)); entry.check(&self.config, now, n as f64) } @@ -59,7 +60,7 @@ impl LeakyBucketRateLimiter { let shard = thread_rng().gen_range(0..n); self.map.shards()[shard] .write() - .retain(|_, value| !value.get_mut().update(&self.config, now)); + .retain(|_, value| value.get().should_retain(now)); } } @@ -68,11 +69,6 @@ pub struct LeakyBucketConfig { pub max: f64, } -pub struct LeakyBucketState { - filled: f64, - time: Instant, -} - impl LeakyBucketConfig { pub fn new(rps: f64, max: f64) -> Self { assert!(rps > 0.0, "rps must be positive"); @@ -81,40 +77,76 @@ impl LeakyBucketConfig { } } -impl LeakyBucketState { - pub fn new() -> Self { +struct LeakyBucketConfigInner { + /// "time cost" of a single request unit. + /// loosely represents how long it takes to handle a request unit in active CPU time. + time_cost: Duration, + bucket_width: Duration, +} + +impl From for LeakyBucketConfigInner { + fn from(config: LeakyBucketConfig) -> Self { + // seconds-per-request = 1/(request-per-second) + let spr = config.rps.recip(); Self { - filled: 0.0, - time: Instant::now(), + time_cost: Duration::from_secs_f64(spr), + bucket_width: Duration::from_secs_f64(config.max * spr), } } - - /// updates the timer and returns true if the bucket is empty - fn update(&mut self, info: &LeakyBucketConfig, now: Instant) -> bool { - let drain = now.duration_since(self.time); - let drain = drain.as_secs_f64() * info.rps; - - self.filled = (self.filled - drain).clamp(0.0, info.max); - self.time = now; - - self.filled == 0.0 - } - - pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool { - self.update(info, now); - - if self.filled + n > info.max { - return false; - } - self.filled += n; - - true - } } -impl Default for LeakyBucketState { - fn default() -> Self { - Self::new() +struct LeakyBucketState { + /// Bucket is represented by `start..end` where `start = end - config.bucket_width`. + /// + /// At any given time, `end-now` represents the number of tokens in the bucket, multiplied by the "time_cost". + /// Adding `n` tokens to the bucket is done by moving `end` forward by `n * config.time_cost`. + /// If `now < start`, the bucket is considered filled and cannot accept any more tokens. + /// Draining the bucket will happen naturally as `now` moves forward. + /// + /// Let `n` be some "time cost" for the request, + /// If now is after end, the bucket is empty and the end is reset to now, + /// If now is within the `bucket window + n`, we are within time budget. + /// If now is before the `bucket window + n`, we have run out of budget. + /// + /// This is inspired by the generic cell rate algorithm (GCRA) and works + /// exactly the same as a leaky-bucket. + end: Instant, +} + +impl LeakyBucketState { + fn new(now: Instant) -> Self { + Self { end: now } + } + + fn should_retain(&self, now: Instant) -> bool { + // if self.end is after now, the bucket is not empty + now < self.end + } + + fn check(&mut self, config: &LeakyBucketConfigInner, now: Instant, n: f64) -> bool { + let start = self.end - config.bucket_width; + + let n = config.time_cost.mul_f64(n); + + // start end + // | start+n | end+n + // | / | / + // ------{o-[---------o-}--]----o---- + // now1 ^ now2 ^ ^ now3 + // + // at now1, the bucket would be completely filled if we add n tokens. + // at now2, the bucket would be partially filled if we add n tokens. + // at now3, the bucket would start completely empty before we add n tokens. + + if self.end + n <= now { + self.end = now + n; + true + } else if start + n <= now { + self.end += n; + true + } else { + false + } } } @@ -124,47 +156,50 @@ mod tests { use tokio::time::Instant; - use super::{LeakyBucketConfig, LeakyBucketState}; + use super::{LeakyBucketConfig, LeakyBucketConfigInner, LeakyBucketState}; #[tokio::test(start_paused = true)] async fn check() { - let info = LeakyBucketConfig::new(500.0, 2000.0); - let mut bucket = LeakyBucketState::new(); + let config: LeakyBucketConfigInner = LeakyBucketConfig::new(500.0, 2000.0).into(); + assert_eq!(config.time_cost, Duration::from_millis(2)); + assert_eq!(config.bucket_width, Duration::from_secs(4)); + + let mut bucket = LeakyBucketState::new(Instant::now()); // should work for 2000 requests this second for _ in 0..2000 { - assert!(bucket.check(&info, Instant::now(), 1.0)); + assert!(bucket.check(&config, Instant::now(), 1.0)); } - assert!(!bucket.check(&info, Instant::now(), 1.0)); - assert_eq!(bucket.filled, 2000.0); + assert!(!bucket.check(&config, Instant::now(), 1.0)); + assert_eq!(bucket.end - Instant::now(), config.bucket_width); // in 1ms we should drain 0.5 tokens. // make sure we don't lose any tokens tokio::time::advance(Duration::from_millis(1)).await; - assert!(!bucket.check(&info, Instant::now(), 1.0)); + assert!(!bucket.check(&config, Instant::now(), 1.0)); tokio::time::advance(Duration::from_millis(1)).await; - assert!(bucket.check(&info, Instant::now(), 1.0)); + assert!(bucket.check(&config, Instant::now(), 1.0)); // in 10ms we should drain 5 tokens tokio::time::advance(Duration::from_millis(10)).await; for _ in 0..5 { - assert!(bucket.check(&info, Instant::now(), 1.0)); + assert!(bucket.check(&config, Instant::now(), 1.0)); } - assert!(!bucket.check(&info, Instant::now(), 1.0)); + assert!(!bucket.check(&config, Instant::now(), 1.0)); // in 10s we should drain 5000 tokens // but cap is only 2000 tokio::time::advance(Duration::from_secs(10)).await; for _ in 0..2000 { - assert!(bucket.check(&info, Instant::now(), 1.0)); + assert!(bucket.check(&config, Instant::now(), 1.0)); } - assert!(!bucket.check(&info, Instant::now(), 1.0)); + assert!(!bucket.check(&config, Instant::now(), 1.0)); // should sustain 500rps for _ in 0..2000 { tokio::time::advance(Duration::from_millis(10)).await; for _ in 0..5 { - assert!(bucket.check(&info, Instant::now(), 1.0)); + assert!(bucket.check(&config, Instant::now(), 1.0)); } } }