mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-28 02:20:42 +00:00
add authentication rate limiting (#6865)
## Problem

https://github.com/neondatabase/cloud/issues/9642

## Summary of changes

1. Make `EndpointRateLimiter` generic and rename it to `BucketRateLimiter`.
2. Add support for claiming multiple tokens at once.
3. Add the `AuthRateLimiter` alias.
4. Check the `(Endpoint, IP)` pair during authentication, weighted by the number of hashes the proxy would have to compute.

TODO: handle IPv6 subnets. This will be done in a separate PR.
This commit is contained in:
@@ -40,7 +40,7 @@ macro_rules! register_hll {
|
||||
}};
|
||||
|
||||
($N:literal, $NAME:expr, $HELP:expr $(,)?) => {{
|
||||
$crate::register_hll!($N, $crate::opts!($NAME, $HELP), $LABELS_NAMES)
|
||||
$crate::register_hll!($N, $crate::opts!($NAME, $HELP))
|
||||
}};
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
|
||||
use crate::console::{AuthSecret, NodeInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::{AUTH_RATE_LIMIT_HITS, ENDPOINTS_AUTH_RATE_LIMITED};
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::stream::Stream;
|
||||
@@ -28,7 +30,7 @@ use crate::{
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
|
||||
pub enum MaybeOwned<'a, T> {
|
||||
@@ -174,6 +176,52 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthenticationConfig {
|
||||
pub fn check_rate_limit(
|
||||
&self,
|
||||
|
||||
ctx: &mut RequestMonitoring,
|
||||
secret: AuthSecret,
|
||||
endpoint: &EndpointId,
|
||||
is_cleartext: bool,
|
||||
) -> auth::Result<AuthSecret> {
|
||||
// we have validated the endpoint exists, so let's intern it.
|
||||
let endpoint_int = EndpointIdInt::from(endpoint);
|
||||
|
||||
// only count the full hash count if password hack or websocket flow.
|
||||
// in other words, if proxy needs to run the hashing
|
||||
let password_weight = if is_cleartext {
|
||||
match &secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => 1,
|
||||
AuthSecret::Scram(s) => s.iterations + 1,
|
||||
}
|
||||
} else {
|
||||
// validating scram takes just 1 hmac_sha_256 operation.
|
||||
1
|
||||
};
|
||||
|
||||
let limit_not_exceeded = self
|
||||
.rate_limiter
|
||||
.check((endpoint_int, ctx.peer_addr), password_weight);
|
||||
|
||||
if !limit_not_exceeded {
|
||||
warn!(
|
||||
enabled = self.rate_limiter_enabled,
|
||||
"rate limiting authentication"
|
||||
);
|
||||
AUTH_RATE_LIMIT_HITS.inc();
|
||||
ENDPOINTS_AUTH_RATE_LIMITED.measure(endpoint);
|
||||
|
||||
if self.rate_limiter_enabled {
|
||||
return Err(auth::AuthError::too_many_connections());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(secret)
|
||||
}
|
||||
}
|
||||
|
||||
/// True to its name, this function encapsulates our current auth trade-offs.
|
||||
/// Here, we choose the appropriate auth flow based on circumstances.
|
||||
///
|
||||
@@ -214,14 +262,24 @@ async fn auth_quirks(
|
||||
Some(secret) => secret,
|
||||
None => api.get_role_secret(ctx, &info).await?,
|
||||
};
|
||||
let (cached_entry, secret) = cached_secret.take_value();
|
||||
|
||||
let secret = match secret {
|
||||
Some(secret) => config.check_rate_limit(
|
||||
ctx,
|
||||
secret,
|
||||
&info.endpoint,
|
||||
unauthenticated_password.is_some() || allow_cleartext,
|
||||
)?,
|
||||
None => {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
|
||||
}
|
||||
};
|
||||
|
||||
let secret = cached_secret.value.clone().unwrap_or_else(|| {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random()))
|
||||
});
|
||||
match authenticate_with_secret(
|
||||
ctx,
|
||||
secret,
|
||||
@@ -237,7 +295,7 @@ async fn auth_quirks(
|
||||
Err(e) => {
|
||||
if e.is_auth_failed() {
|
||||
// The password could have been changed, so we invalidate the cache.
|
||||
cached_secret.invalidate();
|
||||
cached_entry.invalidate();
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
@@ -415,6 +473,7 @@ mod tests {
|
||||
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_protocol::{
|
||||
authentication::sasl::{ChannelBinding, ScramSha256},
|
||||
message::{backend::Message as PgMessage, frontend},
|
||||
@@ -432,6 +491,7 @@ mod tests {
|
||||
},
|
||||
context::RequestMonitoring,
|
||||
proxy::NeonOptions,
|
||||
rate_limiter::{AuthRateLimiter, RateBucketInfo},
|
||||
scram::ServerSecret,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
@@ -473,9 +533,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
static CONFIG: &AuthenticationConfig = &AuthenticationConfig {
|
||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
};
|
||||
rate_limiter_enabled: true,
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
});
|
||||
|
||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||
loop {
|
||||
@@ -544,7 +606,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, CONFIG)
|
||||
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, &CONFIG)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -584,7 +646,7 @@ mod tests {
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
|
||||
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
|
||||
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -624,7 +686,7 @@ mod tests {
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
|
||||
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
|
||||
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ use proxy::console;
|
||||
use proxy::context::parquet::ParquetUploadArgs;
|
||||
use proxy::http;
|
||||
use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT;
|
||||
use proxy::rate_limiter::AuthRateLimiter;
|
||||
use proxy::rate_limiter::EndpointRateLimiter;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
use proxy::rate_limiter::RateLimiterConfig;
|
||||
@@ -141,10 +142,16 @@ struct ProxyCliArgs {
|
||||
///
|
||||
/// Provided in the form '<Requests Per Second>@<Bucket Duration Size>'.
|
||||
/// Can be given multiple times for different bucket sizes.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
|
||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Whether the auth rate limiter actually takes effect (for testing)
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
auth_rate_limit_enabled: bool,
|
||||
/// Authentication rate limiter max number of hashes per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
|
||||
auth_rate_limit: Vec<RateBucketInfo>,
|
||||
/// Redis rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
|
||||
redis_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
|
||||
#[clap(long, default_value_t = 100)]
|
||||
@@ -510,6 +517,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
};
|
||||
|
||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||
|
||||
10
proxy/src/cache/common.rs
vendored
10
proxy/src/cache/common.rs
vendored
@@ -43,6 +43,16 @@ impl<C: Cache, V> Cached<C, V> {
|
||||
Self { token: None, value }
|
||||
}
|
||||
|
||||
pub fn take_value(self) -> (Cached<C, ()>, V) {
|
||||
(
|
||||
Cached {
|
||||
token: self.token,
|
||||
value: (),
|
||||
},
|
||||
self.value,
|
||||
)
|
||||
}
|
||||
|
||||
/// Drop this entry from a cache if it's still there.
|
||||
pub fn invalidate(self) -> V {
|
||||
if let Some((cache, info)) = &self.token {
|
||||
|
||||
30
proxy/src/cache/project_info.rs
vendored
30
proxy/src/cache/project_info.rs
vendored
@@ -373,10 +373,7 @@ mod tests {
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user1.as_str(),
|
||||
[1; 32],
|
||||
)));
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = None;
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
@@ -395,10 +392,7 @@ mod tests {
|
||||
|
||||
// Shouldn't add more than 2 roles.
|
||||
let user3: RoleName = "user3".into();
|
||||
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user3.as_str(),
|
||||
[3; 32],
|
||||
)));
|
||||
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
|
||||
|
||||
@@ -431,14 +425,8 @@ mod tests {
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user1.as_str(),
|
||||
[1; 32],
|
||||
)));
|
||||
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user2.as_str(),
|
||||
[2; 32],
|
||||
)));
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
@@ -486,14 +474,8 @@ mod tests {
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user1.as_str(),
|
||||
[1; 32],
|
||||
)));
|
||||
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user2.as_str(),
|
||||
[2; 32],
|
||||
)));
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
|
||||
let allowed_ips = Arc::new(vec![
|
||||
"127.0.0.1".parse().unwrap(),
|
||||
"127.0.0.2".parse().unwrap(),
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
use crate::{auth, rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions};
|
||||
use crate::{
|
||||
auth,
|
||||
rate_limiter::{AuthRateLimiter, RateBucketInfo},
|
||||
serverless::GlobalConnPoolOptions,
|
||||
};
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use itertools::Itertools;
|
||||
use rustls::{
|
||||
@@ -50,6 +54,8 @@ pub struct HttpConfig {
|
||||
|
||||
/// Settings that control how the proxy authenticates clients.
pub struct AuthenticationConfig {
|
||||
/// Timeout applied to the SCRAM protocol exchange.
pub scram_protocol_timeout: tokio::time::Duration,
|
||||
/// When `false`, rate-limit hits are only logged and counted in metrics;
/// when `true`, they are rejected with a "too many connections" error.
pub rate_limiter_enabled: bool,
|
||||
/// Limiter keyed by `(endpoint, peer IP)` that is charged per authentication attempt.
pub rate_limiter: AuthRateLimiter,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
|
||||
@@ -4,7 +4,10 @@ use ::metrics::{
|
||||
register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec,
|
||||
IntCounterVec, IntGauge, IntGaugeVec,
|
||||
};
|
||||
use metrics::{register_int_counter, register_int_counter_pair, IntCounter, IntCounterPair};
|
||||
use metrics::{
|
||||
register_hll, register_int_counter, register_int_counter_pair, HyperLogLog, IntCounter,
|
||||
IntCounterPair,
|
||||
};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::time::{self, Instant};
|
||||
@@ -358,3 +361,20 @@ pub static TLS_HANDSHAKE_FAILURES: Lazy<IntCounter> = Lazy::new(|| {
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
/// HyperLogLog metric (`proxy_endpoints_auth_rate_limits`) for the number of endpoints
/// affected by authentication rate limits.
///
/// Registered lazily on first access; the `unwrap` panics if registration fails.
pub static ENDPOINTS_AUTH_RATE_LIMITED: Lazy<HyperLogLog<32>> = Lazy::new(|| {
|
||||
register_hll!(
|
||||
32,
|
||||
"proxy_endpoints_auth_rate_limits",
|
||||
"Number of endpoints affected by authentication rate limits",
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
/// Counter (`proxy_requests_auth_rate_limits_total`) of connection requests affected by
/// authentication rate limits.
///
/// Registered lazily on first access; the `unwrap` panics if registration fails.
pub static AUTH_RATE_LIMIT_HITS: Lazy<IntCounter> = Lazy::new(|| {
|
||||
register_int_counter!(
|
||||
"proxy_requests_auth_rate_limits_total",
|
||||
"Number of connection requests affected by authentication rate limits",
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
@@ -280,7 +280,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
// check rate limit
|
||||
if let Some(ep) = user_info.get_endpoint() {
|
||||
if !endpoint_rate_limiter.check(ep) {
|
||||
if !endpoint_rate_limiter.check(ep, 1) {
|
||||
return stream
|
||||
.throw_error(auth::AuthError::too_many_connections())
|
||||
.await?;
|
||||
|
||||
@@ -142,8 +142,8 @@ impl Scram {
|
||||
Ok(Scram(secret))
|
||||
}
|
||||
|
||||
fn mock(user: &str) -> Self {
|
||||
Scram(scram::ServerSecret::mock(user, rand::random()))
|
||||
fn mock() -> Self {
|
||||
Scram(scram::ServerSecret::mock(rand::random()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,11 +330,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
Scram::mock("user"),
|
||||
));
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
|
||||
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
let password: String = rand::thread_rng()
|
||||
|
||||
@@ -4,4 +4,4 @@ mod limiter;
|
||||
pub use aimd::Aimd;
|
||||
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
|
||||
pub use limiter::Limiter;
|
||||
pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
|
||||
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::hash_map::RandomState,
|
||||
hash::BuildHasher,
|
||||
hash::{BuildHasher, Hash},
|
||||
net::IpAddr,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc, Mutex,
|
||||
@@ -15,7 +17,7 @@ use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
|
||||
use tokio::time::{timeout, Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use crate::EndpointId;
|
||||
use crate::{intern::EndpointIdInt, EndpointId};
|
||||
|
||||
use super::{
|
||||
limit_algorithm::{LimitAlgorithm, Sample},
|
||||
@@ -49,11 +51,11 @@ impl RedisRateLimiter {
|
||||
.data
|
||||
.iter_mut()
|
||||
.zip(self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now));
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
|
||||
|
||||
if should_allow_request {
|
||||
// only increment the bucket counts if the request will actually be accepted
|
||||
self.data.iter_mut().for_each(RateBucket::inc);
|
||||
self.data.iter_mut().for_each(|b| b.inc(1));
|
||||
}
|
||||
|
||||
should_allow_request
|
||||
@@ -71,9 +73,14 @@ impl RedisRateLimiter {
|
||||
// saw SNI, before doing TLS handshake. User-side error messages in that case
|
||||
// does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
|
||||
// I went with a more expensive way that yields user-friendlier error messages.
|
||||
pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
|
||||
map: DashMap<EndpointId, Vec<RateBucket>, Hasher>,
|
||||
info: &'static [RateBucketInfo],
|
||||
pub type EndpointRateLimiter = BucketRateLimiter<EndpointId, StdRng, RandomState>;
|
||||
|
||||
// This can't be just per IP because that would limit some PaaS that share IP addresses
|
||||
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, IpAddr), StdRng, RandomState>;
|
||||
|
||||
/// Rate limiter that keeps an independent set of rate buckets for every `Key`.
pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
|
||||
/// Per-key bucket state; an entry is created the first time a key is checked.
map: DashMap<Key, Vec<RateBucket>, Hasher>,
|
||||
/// Bucket definitions, paired positionally with each key's buckets when checking.
info: Cow<'static, [RateBucketInfo]>,
|
||||
// NOTE(review): presumably counts `check` calls to trigger the periodic partial GC;
// the code that uses it is not visible here, so confirm.
access_count: AtomicUsize,
|
||||
// NOTE(review): presumably the RNG used to choose what the GC cleans up; confirm.
rand: Mutex<Rand>,
|
||||
}
|
||||
@@ -85,9 +92,9 @@ struct RateBucket {
|
||||
}
|
||||
|
||||
impl RateBucket {
|
||||
fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool {
|
||||
fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool {
|
||||
if now - self.start < info.interval {
|
||||
self.count < info.max_rpi
|
||||
self.count + n <= info.max_rpi
|
||||
} else {
|
||||
// bucket expired, reset
|
||||
self.count = 0;
|
||||
@@ -97,8 +104,8 @@ impl RateBucket {
|
||||
}
|
||||
}
|
||||
|
||||
fn inc(&mut self) {
|
||||
self.count += 1;
|
||||
fn inc(&mut self, n: u32) {
|
||||
self.count += n;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,7 +118,7 @@ pub struct RateBucketInfo {
|
||||
|
||||
impl std::fmt::Display for RateBucketInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32;
|
||||
let rps = (self.max_rpi as u64) * 1000 / self.interval.as_millis() as u64;
|
||||
write!(f, "{rps}@{}", humantime::format_duration(self.interval))
|
||||
}
|
||||
}
|
||||
@@ -136,12 +143,25 @@ impl std::str::FromStr for RateBucketInfo {
|
||||
}
|
||||
|
||||
impl RateBucketInfo {
|
||||
pub const DEFAULT_SET: [Self; 3] = [
|
||||
pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
|
||||
Self::new(300, Duration::from_secs(1)),
|
||||
Self::new(200, Duration::from_secs(60)),
|
||||
Self::new(100, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
/// All of these are per endpoint-ip pair.
|
||||
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
|
||||
///
|
||||
/// First bucket: 300mcpus total per endpoint-ip pair
|
||||
/// * 1228800 requests per second with 1 hash round. (endpoint rate limiter will catch this first)
|
||||
/// * 300 requests per second with 4096 hash rounds.
|
||||
/// * 2 requests per second with 600000 hash rounds.
|
||||
pub const DEFAULT_AUTH_SET: [Self; 3] = [
|
||||
Self::new(300 * 4096, Duration::from_secs(1)),
|
||||
Self::new(200 * 4096, Duration::from_secs(60)),
|
||||
Self::new(100 * 4096, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
|
||||
info.sort_unstable_by_key(|info| info.interval);
|
||||
let invalid = info
|
||||
@@ -150,7 +170,7 @@ impl RateBucketInfo {
|
||||
.find(|(a, b)| a.max_rpi > b.max_rpi);
|
||||
if let Some((a, b)) = invalid {
|
||||
bail!(
|
||||
"invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
|
||||
"invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})",
|
||||
b.max_rpi,
|
||||
a.max_rpi,
|
||||
);
|
||||
@@ -162,19 +182,24 @@ impl RateBucketInfo {
|
||||
pub const fn new(max_rps: u32, interval: Duration) -> Self {
|
||||
Self {
|
||||
interval,
|
||||
max_rpi: max_rps * interval.as_millis() as u32 / 1000,
|
||||
max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EndpointRateLimiter {
|
||||
pub fn new(info: &'static [RateBucketInfo]) -> Self {
|
||||
impl<K: Hash + Eq> BucketRateLimiter<K> {
|
||||
pub fn new(info: impl Into<Cow<'static, [RateBucketInfo]>>) -> Self {
|
||||
Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
|
||||
fn new_with_rand_and_hasher(info: &'static [RateBucketInfo], rand: R, hasher: S) -> Self {
|
||||
impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
|
||||
fn new_with_rand_and_hasher(
|
||||
info: impl Into<Cow<'static, [RateBucketInfo]>>,
|
||||
rand: R,
|
||||
hasher: S,
|
||||
) -> Self {
|
||||
let info = info.into();
|
||||
info!(buckets = ?info, "endpoint rate limiter");
|
||||
Self {
|
||||
info,
|
||||
@@ -185,7 +210,7 @@ impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
|
||||
}
|
||||
|
||||
/// Check that claiming `n` more tokens for `key` stays within the limit of every bucket;
/// the tokens are only counted when the request is allowed.
|
||||
pub fn check(&self, endpoint: EndpointId) -> bool {
|
||||
pub fn check(&self, key: K, n: u32) -> bool {
|
||||
// do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
|
||||
// worst case memory usage is about:
|
||||
// = 2 * 2048 * 64 * (48B + 72B)
|
||||
@@ -195,7 +220,7 @@ impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let mut entry = self.map.entry(endpoint).or_insert_with(|| {
|
||||
let mut entry = self.map.entry(key).or_insert_with(|| {
|
||||
vec![
|
||||
RateBucket {
|
||||
start: now,
|
||||
@@ -207,12 +232,12 @@ impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
|
||||
|
||||
let should_allow_request = entry
|
||||
.iter_mut()
|
||||
.zip(self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now));
|
||||
.zip(&*self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now, n));
|
||||
|
||||
if should_allow_request {
|
||||
// only increment the bucket counts if the request will actually be accepted
|
||||
entry.iter_mut().for_each(RateBucket::inc);
|
||||
entry.iter_mut().for_each(|b| b.inc(n));
|
||||
}
|
||||
|
||||
should_allow_request
|
||||
@@ -223,7 +248,7 @@ impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
|
||||
/// But that way deletion does not acquire mutex on each entry access.
|
||||
pub fn do_gc(&self) {
|
||||
info!(
|
||||
"cleaning up endpoint rate limiter, current size = {}",
|
||||
"cleaning up bucket rate limiter, current size = {}",
|
||||
self.map.len()
|
||||
);
|
||||
let n = self.map.shards().len();
|
||||
@@ -534,7 +559,7 @@ mod tests {
|
||||
use rustc_hash::FxHasher;
|
||||
use tokio::time;
|
||||
|
||||
use super::{EndpointRateLimiter, Limiter, Outcome};
|
||||
use super::{BucketRateLimiter, EndpointRateLimiter, Limiter, Outcome};
|
||||
use crate::{
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
|
||||
EndpointId,
|
||||
@@ -672,12 +697,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn default_rate_buckets() {
|
||||
let mut defaults = RateBucketInfo::DEFAULT_SET;
|
||||
let mut defaults = RateBucketInfo::DEFAULT_ENDPOINT_SET;
|
||||
RateBucketInfo::validate(&mut defaults[..]).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
|
||||
#[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"]
|
||||
fn rate_buckets_validate() {
|
||||
let mut rates: Vec<RateBucketInfo> = ["300@1s", "10@10s"]
|
||||
.into_iter()
|
||||
@@ -693,42 +718,42 @@ mod tests {
|
||||
.map(|s| s.parse().unwrap())
|
||||
.collect();
|
||||
RateBucketInfo::validate(&mut rates).unwrap();
|
||||
let limiter = EndpointRateLimiter::new(Vec::leak(rates));
|
||||
let limiter = EndpointRateLimiter::new(rates);
|
||||
|
||||
let endpoint = EndpointId::from("ep-my-endpoint-1234");
|
||||
|
||||
time::pause();
|
||||
|
||||
for _ in 0..100 {
|
||||
assert!(limiter.check(endpoint.clone()));
|
||||
assert!(limiter.check(endpoint.clone(), 1));
|
||||
}
|
||||
// more connections fail
|
||||
assert!(!limiter.check(endpoint.clone()));
|
||||
assert!(!limiter.check(endpoint.clone(), 1));
|
||||
|
||||
// fail even after 500ms as it's in the same bucket
|
||||
time::advance(time::Duration::from_millis(500)).await;
|
||||
assert!(!limiter.check(endpoint.clone()));
|
||||
assert!(!limiter.check(endpoint.clone(), 1));
|
||||
|
||||
// after a full 1s, 100 requests are allowed again
|
||||
time::advance(time::Duration::from_millis(500)).await;
|
||||
for _ in 1..6 {
|
||||
for _ in 0..100 {
|
||||
assert!(limiter.check(endpoint.clone()));
|
||||
for _ in 0..50 {
|
||||
assert!(limiter.check(endpoint.clone(), 2));
|
||||
}
|
||||
time::advance(time::Duration::from_millis(1000)).await;
|
||||
}
|
||||
|
||||
// more connections after 600 will exceed the 20rps@30s limit
|
||||
assert!(!limiter.check(endpoint.clone()));
|
||||
assert!(!limiter.check(endpoint.clone(), 1));
|
||||
|
||||
// will still fail before the 30 second limit
|
||||
time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
|
||||
assert!(!limiter.check(endpoint.clone()));
|
||||
assert!(!limiter.check(endpoint.clone(), 1));
|
||||
|
||||
// after the full 30 seconds, 100 requests are allowed again
|
||||
time::advance(time::Duration::from_millis(1)).await;
|
||||
for _ in 0..100 {
|
||||
assert!(limiter.check(endpoint.clone()));
|
||||
assert!(limiter.check(endpoint.clone(), 1));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -738,14 +763,41 @@ mod tests {
|
||||
let rand = rand::rngs::StdRng::from_seed([1; 32]);
|
||||
let hasher = BuildHasherDefault::<FxHasher>::default();
|
||||
|
||||
let limiter = EndpointRateLimiter::new_with_rand_and_hasher(
|
||||
&RateBucketInfo::DEFAULT_SET,
|
||||
let limiter = BucketRateLimiter::new_with_rand_and_hasher(
|
||||
&RateBucketInfo::DEFAULT_ENDPOINT_SET,
|
||||
rand,
|
||||
hasher,
|
||||
);
|
||||
for i in 0..1_000_000 {
|
||||
limiter.check(format!("{i}").into());
|
||||
limiter.check(i, 1);
|
||||
}
|
||||
assert!(limiter.map.len() < 150_000);
|
||||
}
|
||||
|
||||
/// Pins the per-interval limits of the default auth buckets and checks that each one
/// survives a `Display` -> `FromStr` round trip.
#[test]
fn test_default_auth_set() {
    // these values used to exceed u32::MAX
    let expected = [
        (1, 300 * 4096),
        (60, 200 * 4096 * 60),
        (600, 100 * 4096 * 600),
    ]
    .map(|(secs, max_rpi)| RateBucketInfo {
        interval: Duration::from_secs(secs),
        max_rpi,
    });
    assert_eq!(RateBucketInfo::DEFAULT_AUTH_SET, expected);

    for bucket in RateBucketInfo::DEFAULT_AUTH_SET {
        let rendered = bucket.to_string();
        let reparsed: RateBucketInfo = rendered.parse().unwrap();
        assert_eq!(bucket, reparsed);
    }
}
|
||||
}
|
||||
|
||||
@@ -50,13 +50,13 @@ impl ServerSecret {
|
||||
/// To avoid revealing information to an attacker, we use a
|
||||
/// mocked server secret even if the user doesn't exist.
|
||||
/// See `auth-scram.c : mock_scram_secret` for details.
|
||||
pub fn mock(user: &str, nonce: [u8; 32]) -> Self {
|
||||
// Refer to `auth-scram.c : scram_mock_salt`.
|
||||
let mocked_salt = super::sha256([user.as_bytes(), &nonce]);
|
||||
|
||||
pub fn mock(nonce: [u8; 32]) -> Self {
|
||||
Self {
|
||||
iterations: 4096,
|
||||
salt_base64: base64::encode(mocked_salt),
|
||||
// this doesn't reveal much information as we're going to use
|
||||
// iteration count 1 for our generated passwords going forward.
|
||||
// PG16 users can set iteration count=1 already today.
|
||||
iterations: 1,
|
||||
salt_base64: base64::encode(nonce),
|
||||
stored_key: ScramKey::default(),
|
||||
server_key: ScramKey::default(),
|
||||
doomed: true,
|
||||
|
||||
@@ -42,7 +42,12 @@ impl PoolingBackend {
|
||||
};
|
||||
|
||||
let secret = match cached_secret.value.clone() {
|
||||
Some(secret) => secret,
|
||||
Some(secret) => self.config.authentication_config.check_rate_limit(
|
||||
ctx,
|
||||
secret,
|
||||
&user_info.endpoint,
|
||||
true,
|
||||
)?,
|
||||
None => {
|
||||
// If we don't have an authentication secret, for the http flow we can just return an error.
|
||||
info!("authentication info not found");
|
||||
|
||||
Reference in New Issue
Block a user