diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 7ceb1e6814..2e3013ead0 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -14,12 +14,13 @@ use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info}; -use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; +use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; +use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, RoleAccessControl, @@ -230,11 +231,8 @@ async fn auth_quirks( config.is_vpc_acccess_proxy, )?; - let endpoint = EndpointIdInt::from(&info.endpoint); - let rate_limit_config = None; - if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) { - return Err(AuthError::too_many_connections()); - } + access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?; + let role_access = api .get_role_access_control(ctx, &info.endpoint, &info.user) .await?; @@ -401,6 +399,7 @@ impl Backend<'_, ComputeUserInfo> { allowed_ips: Arc::new(vec![]), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }), } } @@ -439,6 +438,7 @@ mod tests { use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; + use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, }; @@ -477,6 +477,7 @@ mod tests { allowed_ips: Arc::new(self.ips.clone()), allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()), flags: self.access_blocker_flags, + rate_limits: EndpointRateLimitConfig::default(), }) } diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 9a4be2f904..d37c107323 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -364,6 +364,7 @@ mod tests { use std::sync::Arc; use super::*; + use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; use crate::types::ProjectId; @@ -399,6 +400,7 @@ mod tests { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret1.clone(), @@ -414,6 +416,7 @@ mod tests { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret2.clone(), @@ -439,6 +442,7 @@ mod tests { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret3.clone(), diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index cf2d9fba14..8c76d034f7 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -146,6 +146,7 @@ impl NeonControlPlaneClient { public_access_blocked: block_public_connections, vpc_access_blocked: block_vpc_connections, }, + rate_limits: body.rate_limits, }) } .inspect_err(|e| tracing::debug!(error = ?e)) @@ -312,6 +313,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { allowed_ips: Arc::new(auth_info.allowed_ips), allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), flags: auth_info.access_blocker_flags, + rate_limits: auth_info.rate_limits, }; let role_control = RoleAccessControl { secret: auth_info.secret, @@ -357,6 +359,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { allowed_ips: Arc::new(auth_info.allowed_ips), allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), flags: auth_info.access_blocker_flags, + rate_limits: auth_info.rate_limits, }; let role_control = RoleAccessControl { secret: auth_info.secret, diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index aeea57f2fc..b84dba6b09 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -20,7 +20,7 @@ use crate::context::RequestContext; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; -use crate::control_plane::messages::MetricsAuxInfo; +use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo}; use crate::control_plane::{ AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, RoleAccessControl, @@ -130,6 +130,7 @@ impl MockControlPlane { project_id: None, account_id: None, access_blocker_flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }) } @@ -233,6 +234,7 @@ impl super::ControlPlaneApi for MockControlPlane { allowed_ips: Arc::new(info.allowed_ips), allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids), flags: info.access_blocker_flags, + rate_limits: info.rate_limits, }) } diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index 9b9d1e25ea..4e5f5c7899 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -10,6 +10,7 @@ use clashmap::ClashMap; use tokio::time::Instant; use tracing::{debug, info}; +use super::{EndpointAccessControl, RoleAccessControl}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError}; use crate::cache::endpoints::EndpointsCache; @@ -22,8 +23,6 @@ use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; use crate::types::EndpointId; -use super::{EndpointAccessControl, RoleAccessControl}; - #[non_exhaustive] #[derive(Clone)] pub enum ControlPlaneClient { diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index ec4554eab5..f0314f91f0 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -227,12 +227,35 @@ pub(crate) struct UserFacingMessage { #[derive(Deserialize)] pub(crate) struct GetEndpointAccessControl { pub(crate) role_secret: Box, - pub(crate) allowed_ips: Option>, - pub(crate) allowed_vpc_endpoint_ids: Option>, + pub(crate) project_id: Option, pub(crate) account_id: Option, + + pub(crate) allowed_ips: Option>, + pub(crate) allowed_vpc_endpoint_ids: Option>, pub(crate) block_public_connections: Option, pub(crate) block_vpc_connections: Option, + + #[serde(default)] + pub(crate) rate_limits: EndpointRateLimitConfig, +} + +#[derive(Copy, Clone, Deserialize, Default)] +pub struct EndpointRateLimitConfig { + pub connection_attempts: ConnectionAttemptsLimit, +} + +#[derive(Copy, Clone, Deserialize, Default)] +pub struct ConnectionAttemptsLimit { + pub tcp: Option, + pub ws: Option, + pub http: Option, +} + +#[derive(Copy, Clone, Deserialize)] +pub struct LeakyBucketSetting { + pub rps: f64, + pub burst: f64, } /// Response which holds compute node's `host:port` pair. diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index ad10cf4257..ed83e98bfe 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -11,6 +11,8 @@ pub(crate) mod errors; use std::sync::Arc; +use messages::EndpointRateLimitConfig; + use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; @@ -18,8 +20,9 @@ use crate::cache::{Cached, TimedLru}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; -use crate::intern::{AccountIdInt, ProjectIdInt}; +use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt}; use crate::protocol2::ConnectionInfoExtra; +use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig}; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{compute, scram}; @@ -56,6 +59,8 @@ pub(crate) struct AuthInfo { pub(crate) account_id: Option, /// Are public connections or VPC connections blocked? pub(crate) access_blocker_flags: AccessBlockerFlags, + /// The rate limits for this endpoint. + pub(crate) rate_limits: EndpointRateLimitConfig, } /// Info for establishing a connection to a compute node. @@ -101,6 +106,8 @@ pub struct EndpointAccessControl { pub allowed_ips: Arc>, pub allowed_vpce: Arc>, pub flags: AccessBlockerFlags, + + pub rate_limits: EndpointRateLimitConfig, } impl EndpointAccessControl { @@ -139,6 +146,36 @@ impl EndpointAccessControl { Ok(()) } + + pub fn connection_attempt_rate_limit( + &self, + ctx: &RequestContext, + endpoint: &EndpointId, + rate_limiter: &EndpointRateLimiter, + ) -> Result<(), AuthError> { + let endpoint = EndpointIdInt::from(endpoint); + + let limits = &self.rate_limits.connection_attempts; + let config = match ctx.protocol() { + crate::metrics::Protocol::Http => limits.http, + crate::metrics::Protocol::Ws => limits.ws, + crate::metrics::Protocol::Tcp => limits.tcp, + crate::metrics::Protocol::SniRouter => return Ok(()), + }; + let config = config.and_then(|config| { + if config.rps <= 0.0 || config.burst <= 0.0 { + return None; + } + + Some(LeakyBucketConfig::new(config.rps, config.burst)) + }); + + if !rate_limiter.check(endpoint, config, 1) { + return Err(AuthError::too_many_connections()); + } + + Ok(()) + } } /// This will allocate per each call, but the http requests alone diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 0c79b5e92f..f7e54ebfe7 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -69,9 +69,8 @@ pub struct LeakyBucketConfig { pub max: f64, } -#[cfg(test)] impl LeakyBucketConfig { - pub(crate) fn new(rps: f64, max: f64) -> Self { + pub fn new(rps: f64, max: f64) -> Self { assert!(rps > 0.0, "rps must be positive"); assert!(max > 0.0, "max must be positive"); Self { rps, max } diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 9d700c1b52..0cd539188a 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -12,11 +12,10 @@ use rand::{Rng, SeedableRng}; use tokio::time::{Duration, Instant}; use tracing::info; +use super::LeakyBucketConfig; use crate::ext::LockExt; use crate::intern::EndpointIdInt; -use super::LeakyBucketConfig; - pub struct GlobalRateLimiter { data: Vec, info: Vec, diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 316e038344..26269d0a6e 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -68,17 +68,20 @@ impl PoolingBackend { self.config.authentication_config.is_vpc_acccess_proxy, )?; - let ep = EndpointIdInt::from(&user_info.endpoint); - let rate_limit_config = None; - if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) { - return Err(AuthError::too_many_connections()); - } + access_control.connection_attempt_rate_limit( + ctx, + &user_info.endpoint, + &self.endpoint_rate_limiter, + )?; + let role_access = backend.get_role_secret(ctx).await?; let Some(secret) = role_access.secret else { // If we don't have an authentication secret, for the http flow we can just return an error. info!("authentication info not found"); return Err(AuthError::password_failed(&*user_info.user)); }; + + let ep = EndpointIdInt::from(&user_info.endpoint); let auth_outcome = crate::auth::validate_password_and_exchange( &self.config.authentication_config.thread_pool, ep,