diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index ae8b294841..72ebc5d3be 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -9,7 +9,7 @@ use crate::{
     console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
     http::StatusCode,
     protocol2::WithClientIp,
-    rate_limiter::EndpointRateLimiter,
+    rate_limiter::{EndpointRateLimiter, RateBucketInfo},
     stream::{PqStream, Stream},
     usage_metrics::{Ids, USAGE_METRICS},
 };
@@ -308,7 +308,10 @@ pub async fn task_main(
     let connections = tokio_util::task::task_tracker::TaskTracker::new();
     let cancel_map = Arc::new(CancelMap::default());
 
-    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit));
+    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new([RateBucketInfo::new(
+        config.endpoint_rps_limit,
+        time::Duration::from_secs(1),
+    )]));
 
     while let Some(accept_result) =
         run_until_cancelled(listener.accept(), &cancellation_token).await
diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs
index f40b8dbd1c..b26386d159 100644
--- a/proxy/src/rate_limiter.rs
+++ b/proxy/src/rate_limiter.rs
@@ -3,5 +3,5 @@ mod limit_algorithm;
 mod limiter;
 pub use aimd::Aimd;
 pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
-pub use limiter::EndpointRateLimiter;
 pub use limiter::Limiter;
+pub use limiter::{EndpointRateLimiter, RateBucketInfo};
diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs
index 9d28bb67b3..e493082796 100644
--- a/proxy/src/rate_limiter/limiter.rs
+++ b/proxy/src/rate_limiter/limiter.rs
@@ -1,16 +1,13 @@
-use std::{
-    sync::{
-        atomic::{AtomicUsize, Ordering},
-        Arc,
-    },
-    time::Duration,
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc,
 };
 
 use dashmap::DashMap;
-use parking_lot::Mutex;
+use rand::{thread_rng, Rng};
 use smol_str::SmolStr;
 use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
-use tokio::time::{timeout, Instant};
+use tokio::time::{timeout, Duration, Instant};
 use tracing::info;
 
 use super::{
@@ -32,57 +29,106 @@ use super::{
 //
 // TODO: add a better bucketing here, e.g. not more than 300 requests per second,
 // and not more than 1000 requests per 10 seconds, etc. Short bursts of reconnects
-// are noramal during redeployments, so we should not block them.
+// are normal during redeployments, so we should not block them.
 pub struct EndpointRateLimiter {
-    map: DashMap<SmolStr, Arc<Mutex<(chrono::NaiveTime, u32)>>>,
-    max_rps: u32,
+    map: DashMap<SmolStr, Vec<RateBucket>>,
+    info: Vec<RateBucketInfo>,
     access_count: AtomicUsize,
 }
 
-impl EndpointRateLimiter {
-    pub fn new(max_rps: u32) -> Self {
+#[derive(Clone, Copy)]
+struct RateBucket {
+    start: Instant,
+    count: u32,
+}
+
+impl RateBucket {
+    fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
+        if now - self.start < info.interval {
+            self.count < info.max_rpi
+        } else {
+            // bucket expired, reset
+            self.count = 0;
+            self.start = now;
+
+            true
+        }
+    }
+
+    fn inc(&mut self) {
+        self.count += 1;
+    }
+}
+
+pub struct RateBucketInfo {
+    interval: Duration,
+    // requests per interval
+    max_rpi: u32,
+}
+
+impl RateBucketInfo {
+    pub fn new(max_rps: u32, interval: Duration) -> Self {
         Self {
-            map: DashMap::new(),
-            max_rps,
+            interval,
+            max_rpi: max_rps * interval.as_millis() as u32 / 1000,
+        }
+    }
+}
+
+impl EndpointRateLimiter {
+    pub fn new(info: impl IntoIterator<Item = RateBucketInfo>) -> Self {
+        Self {
+            info: info.into_iter().collect(),
+            map: DashMap::with_shard_amount(64),
             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
         }
     }
 
     /// Check that number of connections to the endpoint is below `max_rps` rps.
     pub fn check(&self, endpoint: SmolStr) -> bool {
-        // do GC every 100k requests (worst case memory usage is about 10MB)
-        if self.access_count.fetch_add(1, Ordering::AcqRel) % 100_000 == 0 {
+        // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
+        // worst case memory usage is about:
+        //   = 2 * 2048 * 64 * (48B + 72B)
+        //   = 30MB
+        if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
             self.do_gc();
         }
 
-        let now = chrono::Utc::now().naive_utc().time();
-        let entry = self
-            .map
-            .entry(endpoint)
-            .or_insert_with(|| Arc::new(Mutex::new((now, 0))));
-        let mut entry = entry.lock();
-        let (last_time, count) = *entry;
+        let now = Instant::now();
+        let mut entry = self.map.entry(endpoint).or_insert_with(|| {
+            vec![
+                RateBucket {
+                    start: now,
+                    count: 0,
+                };
+                self.info.len()
+            ]
+        });
 
-        if now - last_time < chrono::Duration::seconds(1) {
-            if count >= self.max_rps {
-                return false;
-            }
-            *entry = (last_time, count + 1);
-        } else {
-            *entry = (now, 1);
+        let should_allow_request = entry
+            .iter_mut()
+            .zip(&self.info)
+            .all(|(bucket, info)| bucket.should_allow_request(info, now));
+
+        if should_allow_request {
+            // only increment the bucket counts if the request will actually be accepted
+            entry.iter_mut().for_each(RateBucket::inc);
         }
-        true
+
+        should_allow_request
     }
 
-    /// Clean the map. Simple strategy: remove all entries. At worst, we'll
-    /// double the effective max_rps during the cleanup. But that way deletion
-    /// does not aquire mutex on each entry access.
+    /// Clean the map. Simple strategy: remove all entries in a random shard.
+    /// At worst, we'll double the effective max_rps during the cleanup.
+    /// But that way deletion does not acquire a mutex on each entry access.
     pub fn do_gc(&self) {
         info!(
             "cleaning up endpoint rate limiter, current size = {}",
             self.map.len()
         );
-        self.map.clear();
+        let n = self.map.shards().len();
+        let shard = thread_rng().gen_range(0..n);
+        self.map.shards()[shard].write().clear();
     }
 }
diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs
index 92d6e2d851..daac396ed6 100644
--- a/proxy/src/serverless.rs
+++ b/proxy/src/serverless.rs
@@ -10,11 +10,12 @@ use anyhow::bail;
 use hyper::StatusCode;
 pub use reqwest_middleware::{ClientWithMiddleware, Error};
 pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
+use tokio::time;
 use tokio_util::task::TaskTracker;
 
 use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
 use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER};
-use crate::rate_limiter::EndpointRateLimiter;
+use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
 use crate::{cancellation::CancelMap, config::ProxyConfig};
 use futures::StreamExt;
 use hyper::{
@@ -44,7 +45,10 @@ pub async fn task_main(
     }
 
     let conn_pool = conn_pool::GlobalConnPool::new(config);
-    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit));
+    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new([RateBucketInfo::new(
+        config.endpoint_rps_limit,
+        time::Duration::from_secs(1),
+    )]));
 
     // shutdown the connection pool
     tokio::spawn({
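Usage sketch, not part of the patch: since `EndpointRateLimiter::new` now takes any `IntoIterator<Item = RateBucketInfo>`, the layered limits from the TODO in `limiter.rs` can be expressed directly, and `check` admits a request only while every bucket still has headroom. A hypothetical test module for the bottom of `proxy/src/rate_limiter/limiter.rs`, using the TODO's example figures (300 per second, 1000 per 10 seconds); the module and endpoint names are illustrative:

```rust
#[cfg(test)]
mod tests {
    use super::{EndpointRateLimiter, RateBucketInfo};
    use tokio::time::Duration;

    #[test]
    fn rate_buckets_compose() {
        // hypothetical layered config from the TODO comment:
        // burst up to 300 rps, but at most 1000 requests per 10 seconds
        let limiter = EndpointRateLimiter::new([
            RateBucketInfo::new(300, Duration::from_secs(1)),
            RateBucketInfo::new(100, Duration::from_secs(10)),
        ]);

        // a short reconnect burst fits both buckets
        // (300 <= 300 rpi and 300 <= 1000 rpi), so it is not blocked...
        for _ in 0..300 {
            assert!(limiter.check("ep".into()));
        }
        // ...but the 301st request in the same second exceeds the 1s bucket
        // and is rejected without incrementing any bucket counts.
        assert!(!limiter.check("ep".into()));
    }
}
```

Note that `RateBucketInfo::new(100, Duration::from_secs(10))` stores `max_rpi = 100 * 10_000 / 1000 = 1000`, i.e. the 10-second bucket admits up to 1000 requests per interval, matching the TODO.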
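The worst-case memory figure in the new GC comment can be reproduced directly. A sketch of the arithmetic, taking the comment's constants (2× headroom, 2048 checks per GC, 64 shards, 48 B key + 72 B bucket vector per entry) as given:

```rust
fn main() {
    // One random shard of 64 is cleared every 2048 checks, so per the
    // comment's bound, up to 2 * 2048 * 64 entries may accumulate
    // between clears of any particular shard.
    let entries = 2 * 2048 * 64;
    let bytes_per_entry = 48 + 72; // SmolStr key + Vec<RateBucket>, per the comment
    println!("{} MiB", entries * bytes_per_entry / (1024 * 1024)); // prints "30 MiB"
}
```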