simplify endpoint limiter (#6122)

## Problem

1. Using chrono just to measure durations is wasteful
2. The `Arc<Mutex<_>>` around each entry added nothing, since the dashmap already synchronises access
3. Locking every shard of the dashmap on every GC pass could cause latency spikes
4. More buckets were wanted: only a single one-second bucket was supported

## Summary of changes

1. Use `Instant` instead of `NaiveTime`.
2. Remove the `Arc<Mutex<_>>` wrapper, relying on the fact that a dashmap entry already returns mutable access
3. Clear only a single, random shard per GC pass, and shorten the GC interval accordingly
4. Check multiple rate buckets before allowing access (see the sketch below)

When I benchmarked the check function across 10 million multithreaded calls, it averaged 811ns per call.
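At the core of the change is a per-endpoint `Vec` of fixed-window buckets stored directly in the dashmap: every bucket must admit a request before any of them is incremented. Below is a minimal, self-contained sketch of that idea, not the exact code from this commit; the `String` keys, the use of `std::time::Instant` rather than the `tokio::time` re-export, the absence of GC, and the limits in `main` are all illustrative simplifications.

```rust
use std::time::{Duration, Instant};

use dashmap::DashMap;

#[derive(Clone, Copy)]
struct RateBucket {
    start: Instant,
    count: u32,
}

struct RateBucketInfo {
    interval: Duration,
    max_rpi: u32, // maximum requests per interval
}

struct EndpointRateLimiter {
    map: DashMap<String, Vec<RateBucket>>,
    info: Vec<RateBucketInfo>,
}

impl EndpointRateLimiter {
    fn check(&self, endpoint: String) -> bool {
        let now = Instant::now();
        // The dashmap entry already hands back mutable access to the Vec,
        // so no extra Arc/Mutex is needed around it.
        let mut buckets = self
            .map
            .entry(endpoint)
            .or_insert_with(|| vec![RateBucket { start: now, count: 0 }; self.info.len()]);

        // Every bucket must have room; a bucket whose interval has elapsed is reset.
        let allowed = buckets.iter_mut().zip(&self.info).all(|(b, info)| {
            if now - b.start < info.interval {
                b.count < info.max_rpi
            } else {
                b.count = 0;
                b.start = now;
                true
            }
        });

        // Only charge the buckets if the request is actually admitted.
        if allowed {
            buckets.iter_mut().for_each(|b| b.count += 1);
        }
        allowed
    }
}

fn main() {
    // One bucket: at most 2 requests per second per endpoint.
    let limiter = EndpointRateLimiter {
        map: DashMap::new(),
        info: vec![RateBucketInfo {
            interval: Duration::from_secs(1),
            max_rpi: 2,
        }],
    };
    assert!(limiter.check("ep-1".into()));
    assert!(limiter.check("ep-1".into()));
    assert!(!limiter.check("ep-1".into())); // third request inside the window is rejected
}
```

Because the counters are only bumped once every bucket agrees, a request rejected by a long-interval bucket does not consume budget from a short-interval one.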
Commit c8316b7a3f (parent 8460654f61), authored by Conrad Ludgate and committed via GitHub on 2023-12-13 13:53:23 +00:00.
4 changed files with 94 additions and 41 deletions


@@ -9,7 +9,7 @@ use crate::{
     console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
     http::StatusCode,
     protocol2::WithClientIp,
-    rate_limiter::EndpointRateLimiter,
+    rate_limiter::{EndpointRateLimiter, RateBucketInfo},
     stream::{PqStream, Stream},
     usage_metrics::{Ids, USAGE_METRICS},
 };
@@ -308,7 +308,10 @@ pub async fn task_main(
     let connections = tokio_util::task::task_tracker::TaskTracker::new();
     let cancel_map = Arc::new(CancelMap::default());
-    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit));
+    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new([RateBucketInfo::new(
+        config.endpoint_rps_limit,
+        time::Duration::from_secs(1),
+    )]));
 
     while let Some(accept_result) =
         run_until_cancelled(listener.accept(), &cancellation_token).await


@@ -3,5 +3,5 @@ mod limit_algorithm;
 mod limiter;
 pub use aimd::Aimd;
 pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
-pub use limiter::EndpointRateLimiter;
 pub use limiter::Limiter;
+pub use limiter::{EndpointRateLimiter, RateBucketInfo};


@@ -1,16 +1,13 @@
-use std::{
-    sync::{
-        atomic::{AtomicUsize, Ordering},
-        Arc,
-    },
-    time::Duration,
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc,
 };
 
 use dashmap::DashMap;
-use parking_lot::Mutex;
+use rand::{thread_rng, Rng};
 use smol_str::SmolStr;
 use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
-use tokio::time::{timeout, Instant};
+use tokio::time::{timeout, Duration, Instant};
 use tracing::info;
 
 use super::{
@@ -32,57 +29,106 @@ use super::{
 //
 // TODO: add a better bucketing here, e.g. not more than 300 requests per second,
 // and not more than 1000 requests per 10 seconds, etc. Short bursts of reconnects
-// are noramal during redeployments, so we should not block them.
+// are normal during redeployments, so we should not block them.
 pub struct EndpointRateLimiter {
-    map: DashMap<SmolStr, Arc<Mutex<(chrono::NaiveTime, u32)>>>,
-    max_rps: u32,
+    map: DashMap<SmolStr, Vec<RateBucket>>,
+    info: Vec<RateBucketInfo>,
     access_count: AtomicUsize,
 }
 
-impl EndpointRateLimiter {
-    pub fn new(max_rps: u32) -> Self {
+#[derive(Clone, Copy)]
+struct RateBucket {
+    start: Instant,
+    count: u32,
+}
+
+impl RateBucket {
+    fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
+        if now - self.start < info.interval {
+            self.count < info.max_rpi
+        } else {
+            // bucket expired, reset
+            self.count = 0;
+            self.start = now;
+            true
+        }
+    }
+
+    fn inc(&mut self) {
+        self.count += 1;
+    }
+}
+
+pub struct RateBucketInfo {
+    interval: Duration,
+    // requests per interval
+    max_rpi: u32,
+}
+
+impl RateBucketInfo {
+    pub fn new(max_rps: u32, interval: Duration) -> Self {
         Self {
-            map: DashMap::new(),
-            max_rps,
+            interval,
+            max_rpi: max_rps * 1000 / interval.as_millis() as u32,
+        }
+    }
+}
+
+impl EndpointRateLimiter {
+    pub fn new(info: impl IntoIterator<Item = RateBucketInfo>) -> Self {
+        Self {
+            info: info.into_iter().collect(),
+            map: DashMap::with_shard_amount(64),
             access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
         }
     }
 
     /// Check that number of connections to the endpoint is below `max_rps` rps.
     pub fn check(&self, endpoint: SmolStr) -> bool {
-        // do GC every 100k requests (worst case memory usage is about 10MB)
-        if self.access_count.fetch_add(1, Ordering::AcqRel) % 100_000 == 0 {
+        // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
+        // worst case memory usage is about:
+        //  = 2 * 2048 * 64 * (48B + 72B)
+        //  = 30MB
+        if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
             self.do_gc();
         }
 
-        let now = chrono::Utc::now().naive_utc().time();
-        let entry = self
-            .map
-            .entry(endpoint)
-            .or_insert_with(|| Arc::new(Mutex::new((now, 0))));
-        let mut entry = entry.lock();
-        let (last_time, count) = *entry;
+        let now = Instant::now();
+        let mut entry = self.map.entry(endpoint).or_insert_with(|| {
+            vec![
+                RateBucket {
+                    start: now,
+                    count: 0,
+                };
+                self.info.len()
+            ]
+        });
 
-        if now - last_time < chrono::Duration::seconds(1) {
-            if count >= self.max_rps {
-                return false;
-            }
-            *entry = (last_time, count + 1);
-        } else {
-            *entry = (now, 1);
+        let should_allow_request = entry
+            .iter_mut()
+            .zip(&self.info)
+            .all(|(bucket, info)| bucket.should_allow_request(info, now));
+
+        if should_allow_request {
+            // only increment the bucket counts if the request will actually be accepted
+            entry.iter_mut().for_each(RateBucket::inc);
         }
 
-        true
+        should_allow_request
     }
 
-    /// Clean the map. Simple strategy: remove all entries. At worst, we'll
-    /// double the effective max_rps during the cleanup. But that way deletion
-    /// does not aquire mutex on each entry access.
+    /// Clean the map. Simple strategy: remove all entries in a random shard.
+    /// At worst, we'll double the effective max_rps during the cleanup.
+    /// But that way deletion does not aquire mutex on each entry access.
     pub fn do_gc(&self) {
         info!(
             "cleaning up endpoint rate limiter, current size = {}",
             self.map.len()
         );
-        self.map.clear();
+        let n = self.map.shards().len();
+        let shard = thread_rng().gen_range(0..n);
+        self.map.shards()[shard].write().clear();
     }
 }
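The partial GC above only ever takes the write lock of one of the 64 shards per pass. Here is a standalone sketch of that trick, assuming dashmap is built with its `raw-api` feature (which is what exposes `shards()`, as used above); the key/value types and entry count are placeholders.

```rust
use dashmap::DashMap;
use rand::{thread_rng, Rng};

/// Clear one randomly chosen shard instead of the whole map, so a cleanup
/// pass only locks a single shard while the others stay available.
fn clear_random_shard(map: &DashMap<String, u32>) {
    let n = map.shards().len();
    let shard = thread_rng().gen_range(0..n);
    map.shards()[shard].write().clear();
}

fn main() {
    let map: DashMap<String, u32> = DashMap::with_shard_amount(64);
    for i in 0..1_000u32 {
        map.insert(format!("endpoint-{i}"), i);
    }
    clear_random_shard(&map);
    // Roughly 1/64th of the entries are gone after a single pass.
    println!("remaining entries: {}", map.len());
}
```

Paired with raising the GC frequency from every 100k requests to every 2048, each pass stays cheap while the whole map is still swept at roughly the same overall rate.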


@@ -10,11 +10,12 @@ use anyhow::bail;
 use hyper::StatusCode;
 pub use reqwest_middleware::{ClientWithMiddleware, Error};
 pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
+use tokio::time;
 use tokio_util::task::TaskTracker;
 
 use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
 use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER};
-use crate::rate_limiter::EndpointRateLimiter;
+use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
 use crate::{cancellation::CancelMap, config::ProxyConfig};
 use futures::StreamExt;
 use hyper::{
@@ -44,7 +45,10 @@ pub async fn task_main(
     }
 
     let conn_pool = conn_pool::GlobalConnPool::new(config);
-    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit));
+    let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new([RateBucketInfo::new(
+        config.endpoint_rps_limit,
+        time::Duration::from_secs(1),
+    )]));
 
     // shutdown the connection pool
     tokio::spawn({