mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-10 23:12:54 +00:00
proxy: add more rates to endpoint limiter (#6130)
## Problem Single rate bucket is limited in usefulness ## Summary of changes Introduce a secondary bucket allowing an average of 200 requests per second over 1 minute, and a tertiary bucket allowing an average of 100 requests per second over 10 minutes. Configured by using a format like ```sh proxy --endpoint-rps-limit 300@1s --endpoint-rps-limit 100@10s --endpoint-rps-limit 50@1m ``` If the bucket limits are inconsistent, an error is returned on startup ``` $ proxy --endpoint-rps-limit 300@1s --endpoint-rps-limit 10@10s Error: invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300) ```
This commit is contained in:
@@ -7,6 +7,8 @@ use proxy::console;
|
||||
use proxy::console::provider::AllowedIpsCache;
|
||||
use proxy::console::provider::NodeInfoCache;
|
||||
use proxy::http;
|
||||
use proxy::rate_limiter::EndpointRateLimiter;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
use proxy::rate_limiter::RateLimiterConfig;
|
||||
use proxy::usage_metrics;
|
||||
|
||||
@@ -14,6 +16,7 @@ use anyhow::bail;
|
||||
use proxy::config::{self, ProxyConfig};
|
||||
use proxy::serverless;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, net::SocketAddr};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinSet;
|
||||
@@ -113,8 +116,11 @@ struct ProxyCliArgs {
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
rate_limiter_timeout: tokio::time::Duration,
|
||||
/// Endpoint rate limiter max number of requests per second.
|
||||
#[clap(long, default_value_t = 300)]
|
||||
endpoint_rps_limit: u32,
|
||||
///
|
||||
/// Provided in the form '<Requests Per Second>@<Bucket Duration Size>'.
|
||||
/// Can be given multiple times for different bucket sizes.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
|
||||
#[clap(long, default_value_t = 100)]
|
||||
initial_limit: usize,
|
||||
@@ -157,6 +163,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
|
||||
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
@@ -164,6 +172,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
@@ -177,6 +186,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
config,
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -311,6 +321,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let authentication_config = AuthenticationConfig {
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
};
|
||||
|
||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend,
|
||||
@@ -320,8 +334,35 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
authentication_config,
|
||||
require_client_ip: args.require_client_ip,
|
||||
disable_ip_check_for_http: args.disable_ip_check_for_http,
|
||||
endpoint_rps_limit: args.endpoint_rps_limit,
|
||||
endpoint_rps_limit,
|
||||
}));
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use clap::Parser;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
|
||||
#[test]
|
||||
fn parse_endpoint_rps_limit() {
|
||||
let config = super::ProxyCliArgs::parse_from([
|
||||
"proxy",
|
||||
"--endpoint-rps-limit",
|
||||
"100@1s",
|
||||
"--endpoint-rps-limit",
|
||||
"20@30s",
|
||||
]);
|
||||
|
||||
assert_eq!(
|
||||
config.endpoint_rps_limit,
|
||||
vec![
|
||||
RateBucketInfo::new(100, Duration::from_secs(1)),
|
||||
RateBucketInfo::new(20, Duration::from_secs(30)),
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::auth;
|
||||
use crate::{auth, rate_limiter::RateBucketInfo};
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use rustls::{sign, Certificate, PrivateKey};
|
||||
use sha2::{Digest, Sha256};
|
||||
@@ -20,7 +20,7 @@ pub struct ProxyConfig {
|
||||
pub authentication_config: AuthenticationConfig,
|
||||
pub require_client_ip: bool,
|
||||
pub disable_ip_check_for_http: bool,
|
||||
pub endpoint_rps_limit: u32,
|
||||
pub endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::{
|
||||
console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
|
||||
http::StatusCode,
|
||||
protocol2::WithClientIp,
|
||||
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
stream::{PqStream, Stream},
|
||||
usage_metrics::{Ids, USAGE_METRICS},
|
||||
};
|
||||
@@ -297,6 +297,7 @@ pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
@@ -308,10 +309,6 @@ pub async fn task_main(
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new([RateBucketInfo::new(
|
||||
config.endpoint_rps_limit,
|
||||
time::Duration::from_secs(1),
|
||||
)]));
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
|
||||
@@ -3,7 +3,9 @@ use std::sync::{
|
||||
Arc,
|
||||
};
|
||||
|
||||
use anyhow::bail;
|
||||
use dashmap::DashMap;
|
||||
use itertools::Itertools;
|
||||
use rand::{thread_rng, Rng};
|
||||
use smol_str::SmolStr;
|
||||
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
|
||||
@@ -26,13 +28,9 @@ use super::{
|
||||
// saw SNI, before doing TLS handshake. User-side error messages in that case
|
||||
// does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
|
||||
// I went with a more expensive way that yields user-friendlier error messages.
|
||||
//
|
||||
// TODO: add a better bucketing here, e.g. not more than 300 requests per second,
|
||||
// and not more than 1000 requests per 10 seconds, etc. Short bursts of reconnects
|
||||
// are normal during redeployments, so we should not block them.
|
||||
pub struct EndpointRateLimiter {
|
||||
map: DashMap<SmolStr, Vec<RateBucket>>,
|
||||
info: Vec<RateBucketInfo>,
|
||||
info: &'static [RateBucketInfo],
|
||||
access_count: AtomicUsize,
|
||||
}
|
||||
|
||||
@@ -60,25 +58,76 @@ impl RateBucket {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
pub struct RateBucketInfo {
|
||||
interval: Duration,
|
||||
pub interval: Duration,
|
||||
// requests per interval
|
||||
max_rpi: u32,
|
||||
pub max_rpi: u32,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RateBucketInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32;
|
||||
write!(f, "{rps}@{}", humantime::format_duration(self.interval))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RateBucketInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{self}")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for RateBucketInfo {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let Some((max_rps, interval)) = s.split_once('@') else {
|
||||
bail!("invalid rate info")
|
||||
};
|
||||
let max_rps = max_rps.parse()?;
|
||||
let interval = humantime::parse_duration(interval)?;
|
||||
Ok(Self::new(max_rps, interval))
|
||||
}
|
||||
}
|
||||
|
||||
impl RateBucketInfo {
|
||||
pub fn new(max_rps: u32, interval: Duration) -> Self {
|
||||
pub const DEFAULT_SET: [Self; 3] = [
|
||||
Self::new(300, Duration::from_secs(1)),
|
||||
Self::new(200, Duration::from_secs(60)),
|
||||
Self::new(100, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
|
||||
info.sort_unstable_by_key(|info| info.interval);
|
||||
let invalid = info
|
||||
.iter()
|
||||
.tuple_windows()
|
||||
.find(|(a, b)| a.max_rpi > b.max_rpi);
|
||||
if let Some((a, b)) = invalid {
|
||||
bail!(
|
||||
"invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
|
||||
b.max_rpi,
|
||||
a.max_rpi,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub const fn new(max_rps: u32, interval: Duration) -> Self {
|
||||
Self {
|
||||
interval,
|
||||
max_rpi: max_rps * 1000 / interval.as_millis() as u32,
|
||||
max_rpi: max_rps * interval.as_millis() as u32 / 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EndpointRateLimiter {
|
||||
pub fn new(info: impl IntoIterator<Item = RateBucketInfo>) -> Self {
|
||||
pub fn new(info: &'static [RateBucketInfo]) -> Self {
|
||||
info!(buckets = ?info, "endpoint rate limiter");
|
||||
Self {
|
||||
info: info.into_iter().collect(),
|
||||
info,
|
||||
map: DashMap::with_shard_amount(64),
|
||||
access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
|
||||
}
|
||||
@@ -107,7 +156,7 @@ impl EndpointRateLimiter {
|
||||
|
||||
let should_allow_request = entry
|
||||
.iter_mut()
|
||||
.zip(&self.info)
|
||||
.zip(self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now));
|
||||
|
||||
if should_allow_request {
|
||||
@@ -444,9 +493,11 @@ mod tests {
|
||||
use std::{pin::pin, task::Context, time::Duration};
|
||||
|
||||
use futures::{task::noop_waker_ref, Future};
|
||||
use smol_str::SmolStr;
|
||||
use tokio::time;
|
||||
|
||||
use super::{Limiter, Outcome};
|
||||
use crate::rate_limiter::RateLimitAlgorithm;
|
||||
use super::{EndpointRateLimiter, Limiter, Outcome};
|
||||
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm};
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_works() {
|
||||
@@ -555,4 +606,88 @@ mod tests {
|
||||
limiter.release(token1, None).await;
|
||||
limiter.release(token2, None).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_bucket_rpi() {
|
||||
let rate_bucket = RateBucketInfo::new(50, Duration::from_secs(5));
|
||||
assert_eq!(rate_bucket.max_rpi, 50 * 5);
|
||||
|
||||
let rate_bucket = RateBucketInfo::new(50, Duration::from_millis(500));
|
||||
assert_eq!(rate_bucket.max_rpi, 50 / 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_bucket_parse() {
|
||||
let rate_bucket: RateBucketInfo = "100@10s".parse().unwrap();
|
||||
assert_eq!(rate_bucket.interval, Duration::from_secs(10));
|
||||
assert_eq!(rate_bucket.max_rpi, 100 * 10);
|
||||
assert_eq!(rate_bucket.to_string(), "100@10s");
|
||||
|
||||
let rate_bucket: RateBucketInfo = "100@1m".parse().unwrap();
|
||||
assert_eq!(rate_bucket.interval, Duration::from_secs(60));
|
||||
assert_eq!(rate_bucket.max_rpi, 100 * 60);
|
||||
assert_eq!(rate_bucket.to_string(), "100@1m");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_rate_buckets() {
|
||||
let mut defaults = RateBucketInfo::DEFAULT_SET;
|
||||
RateBucketInfo::validate(&mut defaults[..]).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
|
||||
fn rate_buckets_validate() {
|
||||
let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
|
||||
.into_iter()
|
||||
.map(|s| s.parse().unwrap())
|
||||
.collect();
|
||||
RateBucketInfo::validate(&mut rates).unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rate_limits() {
|
||||
let mut rates: Vec<RateBucketInfo> = ["100@1s", "20@30s"]
|
||||
.into_iter()
|
||||
.map(|s| s.parse().unwrap())
|
||||
.collect();
|
||||
RateBucketInfo::validate(&mut rates).unwrap();
|
||||
let limiter = EndpointRateLimiter::new(Vec::leak(rates));
|
||||
|
||||
let endpoint = SmolStr::from("ep-my-endpoint-1234");
|
||||
|
||||
time::pause();
|
||||
|
||||
for _ in 0..100 {
|
||||
assert!(limiter.check(endpoint.clone()));
|
||||
}
|
||||
// more connections fail
|
||||
assert!(!limiter.check(endpoint.clone()));
|
||||
|
||||
// fail even after 500ms as it's in the same bucket
|
||||
time::advance(time::Duration::from_millis(500)).await;
|
||||
assert!(!limiter.check(endpoint.clone()));
|
||||
|
||||
// after a full 1s, 100 requests are allowed again
|
||||
time::advance(time::Duration::from_millis(500)).await;
|
||||
for _ in 1..6 {
|
||||
for _ in 0..100 {
|
||||
assert!(limiter.check(endpoint.clone()));
|
||||
}
|
||||
time::advance(time::Duration::from_millis(1000)).await;
|
||||
}
|
||||
|
||||
// more connections after 600 will exceed the 20rps@30s limit
|
||||
assert!(!limiter.check(endpoint.clone()));
|
||||
|
||||
// will still fail before the 30 second limit
|
||||
time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
|
||||
assert!(!limiter.check(endpoint.clone()));
|
||||
|
||||
// after the full 30 seconds, 100 requests are allowed again
|
||||
time::advance(time::Duration::from_millis(1)).await;
|
||||
for _ in 0..100 {
|
||||
assert!(limiter.check(endpoint.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,12 +10,11 @@ use anyhow::bail;
|
||||
use hyper::StatusCode;
|
||||
pub use reqwest_middleware::{ClientWithMiddleware, Error};
|
||||
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use tokio::time;
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||
use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER};
|
||||
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::{cancellation::CancelMap, config::ProxyConfig};
|
||||
use futures::StreamExt;
|
||||
use hyper::{
|
||||
@@ -39,16 +38,13 @@ pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("websocket server has shut down");
|
||||
}
|
||||
|
||||
let conn_pool = conn_pool::GlobalConnPool::new(config);
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new([RateBucketInfo::new(
|
||||
config.endpoint_rps_limit,
|
||||
time::Duration::from_secs(1),
|
||||
)]));
|
||||
|
||||
// shutdown the connection pool
|
||||
tokio::spawn({
|
||||
|
||||
Reference in New Issue
Block a user