diff --git a/Cargo.lock b/Cargo.lock index 2b56095bc8..773a000898 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2838,17 +2838,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" -[[package]] -name = "leaky-bucket" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eb491abd89e9794d50f93c8db610a29509123e3fbbc9c8c67a528e9391cd853" -dependencies = [ - "parking_lot 0.12.1", - "tokio", - "tracing", -] - [[package]] name = "libc" version = "0.2.150" @@ -3575,7 +3564,6 @@ dependencies = [ "humantime-serde", "hyper 0.14.26", "itertools", - "leaky-bucket", "md5", "metrics", "nix 0.27.1", @@ -6777,7 +6765,6 @@ dependencies = [ "humantime", "hyper 0.14.26", "jsonwebtoken", - "leaky-bucket", "metrics", "nix 0.27.1", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index 7749378114..e1f11d5620 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,7 +107,6 @@ ipnet = "2.9.0" itertools = "0.10" jsonwebtoken = "9" lasso = "0.7" -leaky-bucket = "1.0.1" libc = "0.2" md5 = "0.7.0" measured = { version = "0.0.22", features=["lasso"] } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index ec05f849cf..6b86de12f4 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -26,7 +26,6 @@ hyper = { workspace = true, features = ["full"] } fail.workspace = true futures = { workspace = true} jsonwebtoken.workspace = true -leaky-bucket.workspace = true nix.workspace = true once_cell.workspace = true pin-project-lite.workspace = true diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 0d9343d643..f2e2f1ee2e 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -36,7 +36,6 @@ humantime.workspace = true humantime-serde.workspace = true hyper.workspace = true itertools.workspace = true -leaky-bucket.workspace = true md5.workspace = true nix.workspace = true # hack to get the number of worker threads tokio uses diff --git a/pageserver/src/tenant/throttle.rs b/pageserver/src/tenant/throttle.rs index f3f3d5e3ae..a2f3649e9f 100644 --- a/pageserver/src/tenant/throttle.rs +++ b/pageserver/src/tenant/throttle.rs @@ -9,6 +9,7 @@ use std::{ use arc_swap::ArcSwap; use enumset::EnumSet; +use tokio::sync::Notify; use tracing::{error, warn}; use crate::{context::RequestContext, task_mgr::TaskKind}; @@ -33,7 +34,7 @@ pub struct Throttle { pub struct Inner { task_kinds: EnumSet, - rate_limiter: Arc, + rate_limiter: Arc, config: Config, } @@ -96,13 +97,14 @@ where Inner { task_kinds, rate_limiter: Arc::new( - leaky_bucket::RateLimiter::builder() - .initial(*initial) - .interval(*refill_interval) - .refill(refill_amount.get()) - .max(*max) - .fair(*fair) - .build(), + RateLimiterBuilder { + initial: *initial, + interval: *refill_interval, + refill: refill_amount.get(), + max: *max, + fair: *fair, + } + .build(), ), config, } @@ -136,18 +138,9 @@ where return None; }; let start = std::time::Instant::now(); - let mut did_throttle = false; - let acquire = inner.rate_limiter.acquire(key_count); - // turn off runtime-induced preemption (aka coop) so our `did_throttle` is accurate - let acquire = tokio::task::unconstrained(acquire); - let mut acquire = std::pin::pin!(acquire); - std::future::poll_fn(|cx| { - use std::future::Future; - let poll = acquire.as_mut().poll(cx); - did_throttle = did_throttle || poll.is_pending(); - poll - }) - .await; + + let did_throttle = !inner.rate_limiter.acquire(key_count).await; + self.count_accounted.fetch_add(1, Ordering::Relaxed); if did_throttle { self.count_throttled.fetch_add(1, Ordering::Relaxed); @@ -176,3 +169,117 @@ where } } } + +struct RateLimiter { + /// "time cost" of a single request unit. + /// loosely represents how long it takes to handle a request unit in active CPU time. + time_cost: Duration, + + bucket_width: Duration, + + /// Bucket is represented by `start..end` where `start = end - config.bucket_width`. + /// + /// At any given time, `end - now` represents the number of tokens in the bucket, multiplied by the "time_cost". + /// Adding `n` tokens to the bucket is done by moving `end` forward by `n * config.time_cost`. + /// If `now < start`, the bucket is considered filled and cannot accept any more tokens. + /// Draining the bucket will happen naturally as `now` moves forward. + /// + /// Let `n` be some "time cost" for the request, + /// If now is after end, the bucket is empty and the end is reset to now, + /// If now is within the `bucket window + n`, we are within time budget. + /// If now is before the `bucket window + n`, we have run out of budget. + /// + /// This is inspired by the generic cell rate algorithm (GCRA) and works + /// exactly the same as a leaky-bucket. + end: Mutex, + + queue: Option, +} + +struct RateLimiterBuilder { + /// The max number of tokens. + max: usize, + /// The initial count of tokens. + initial: usize, + /// Tokens to add every `per` duration. + refill: usize, + /// Interval to add tokens in milliseconds. + interval: Duration, + /// If the rate limiter is fair or not. + fair: bool, +} + +impl RateLimiterBuilder { + fn build(self) -> RateLimiter { + let queue = self.fair.then(Notify::new); + + let time_cost = self.interval / self.refill as u32; + let bucket_width = time_cost * (self.max as u32); + let initial_allow = time_cost * (self.initial as u32); + let end = tokio::time::Instant::now() + bucket_width - initial_allow; + + RateLimiter { + time_cost, + bucket_width, + end: Mutex::new(end), + queue, + } + } +} + +impl RateLimiter { + /// returns true if not throttled + async fn acquire(&self, count: usize) -> bool { + let mut not_throttled = true; + + let n = self.time_cost.mul_f64(count as f64); + + // wait until we are the first in the queue + if let Some(queue) = &self.queue { + let mut notified = std::pin::pin!(queue.notified()); + if !notified.as_mut().enable() { + not_throttled = false; + notified.await; + } + } + + // notify the next waiter in the queue + scopeguard::defer! { + if let Some(queue) = &self.queue { + queue.notify_one(); + } + }; + + loop { + let now = tokio::time::Instant::now(); + let ready_at = { + // start end + // | start+n | end+n + // | / | / + // ------{o-[---------o-}--]----o---- + // now1 ^ now2 ^ ^ now3 + // + // at now1, the bucket would be completely filled if we add n tokens. + // at now2, the bucket would be partially filled if we add n tokens. + // at now3, the bucket would start completely empty before we add n tokens. + + let mut end = self.end.lock().unwrap(); + let start = *end - self.bucket_width; + let ready_at = start + n; + + if *end + n <= now { + *end = now + n; + return not_throttled; + } else if ready_at <= now { + *end += n; + return not_throttled; + } + + ready_at + }; + + not_throttled = false; + tokio::time::sleep_until(ready_at).await; + } + } +}