diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs
index 93f6a2b7cc..1a7d7a7e44 100644
--- a/libs/metrics/src/hll.rs
+++ b/libs/metrics/src/hll.rs
@@ -107,7 +107,7 @@ impl<const N: usize> MetricType for HyperLogLogState<N> {
 }
 
 impl<const N: usize> HyperLogLogState<N> {
-    pub fn measure(&self, item: &impl Hash) {
+    pub fn measure(&self, item: &(impl Hash + ?Sized)) {
         // changing the hasher will break compatibility with previous measurements.
         self.record(BuildHasherDefault::::default().hash_one(item));
     }
diff --git a/libs/utils/src/leaky_bucket.rs b/libs/utils/src/leaky_bucket.rs
index 2398f92766..17e96bd0a9 100644
--- a/libs/utils/src/leaky_bucket.rs
+++ b/libs/utils/src/leaky_bucket.rs
@@ -28,6 +28,7 @@ use std::time::Duration;
 use tokio::sync::Notify;
 use tokio::time::Instant;
 
+#[derive(Clone, Copy)]
 pub struct LeakyBucketConfig {
     /// This is the "time cost" of a single request unit.
     /// Should loosely represent how long it takes to handle a request unit in active resource time.
diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs
index 8c892d90a0..735cb52f47 100644
--- a/proxy/src/auth/backend/mod.rs
+++ b/proxy/src/auth/backend/mod.rs
@@ -4,38 +4,31 @@ mod hacks;
 pub mod jwt;
 pub mod local;
 
-use std::net::IpAddr;
 use std::sync::Arc;
 
 pub use console_redirect::ConsoleRedirectBackend;
 pub(crate) use console_redirect::ConsoleRedirectError;
-use ipnet::{Ipv4Net, Ipv6Net};
 use local::LocalBackend;
 use postgres_client::config::AuthKeys;
 use serde::{Deserialize, Serialize};
 use tokio::io::{AsyncRead, AsyncWrite};
-use tracing::{debug, info, warn};
+use tracing::{debug, info};
 
-use crate::auth::credentials::check_peer_addr_is_in_list;
-use crate::auth::{
-    self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
-};
+use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
 use crate::cache::Cached;
 use crate::config::AuthenticationConfig;
 use crate::context::RequestContext;
 use crate::control_plane::client::ControlPlaneClient;
 use crate::control_plane::errors::GetAuthInfoError;
 use crate::control_plane::{
-    self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
-    CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
+    self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
+    RoleAccessControl,
 };
 use crate::intern::EndpointIdInt;
-use crate::metrics::Metrics;
 use crate::pqproto::BeMessage;
-use crate::protocol2::ConnectionInfoExtra;
 use crate::proxy::NeonOptions;
 use crate::proxy::connect_compute::ComputeConnectBackend;
-use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
+use crate::rate_limiter::EndpointRateLimiter;
 use crate::stream::Stream;
 use crate::types::{EndpointCacheKey, EndpointId, RoleName};
 use crate::{scram, stream};
@@ -201,78 +194,6 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
     }
 }
 
-#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
-pub struct MaskedIp(IpAddr);
-
-impl MaskedIp {
-    fn new(value: IpAddr, prefix: u8) -> Self {
-        match value {
-            IpAddr::V4(v4) => Self(IpAddr::V4(
-                Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
-            )),
-            IpAddr::V6(v6) => Self(IpAddr::V6(
-                Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
-            )),
-        }
-    }
-}
-
-// This can't be just per IP because that would limit some PaaS that share IP addresses
-pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
-
-impl AuthenticationConfig {
-    pub(crate) fn check_rate_limit(
-        &self,
-        ctx: &RequestContext,
-        secret: AuthSecret,
-        endpoint: &EndpointId,
-        is_cleartext: bool,
-    ) -> auth::Result<AuthSecret> {
-        // we have validated the endpoint exists, so let's intern it.
-        let endpoint_int = EndpointIdInt::from(endpoint.normalize());
-
-        // only count the full hash count if password hack or websocket flow.
-        // in other words, if proxy needs to run the hashing
-        let password_weight = if is_cleartext {
-            match &secret {
-                #[cfg(any(test, feature = "testing"))]
-                AuthSecret::Md5(_) => 1,
-                AuthSecret::Scram(s) => s.iterations + 1,
-            }
-        } else {
-            // validating scram takes just 1 hmac_sha_256 operation.
-            1
-        };
-
-        let limit_not_exceeded = self.rate_limiter.check(
-            (
-                endpoint_int,
-                MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
-            ),
-            password_weight,
-        );
-
-        if !limit_not_exceeded {
-            warn!(
-                enabled = self.rate_limiter_enabled,
-                "rate limiting authentication"
-            );
-            Metrics::get().proxy.requests_auth_rate_limits_total.inc();
-            Metrics::get()
-                .proxy
-                .endpoints_auth_rate_limits
-                .get_metric()
-                .measure(endpoint);
-
-            if self.rate_limiter_enabled {
-                return Err(auth::AuthError::too_many_connections());
-            }
-        }
-
-        Ok(secret)
-    }
-}
-
 /// True to its name, this function encapsulates our current auth trade-offs.
 /// Here, we choose the appropriate auth flow based on circumstances.
 ///
@@ -285,7 +206,7 @@ async fn auth_quirks(
     allow_cleartext: bool,
     config: &'static AuthenticationConfig,
     endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
+) -> auth::Result<ComputeCredentials> {
     // If there's no project so far, that entails that client doesn't
     // support SNI or other means of passing the endpoint (project) name.
     // We now expect to see a very specific payload in the place of password.
@@ -301,55 +222,27 @@ async fn auth_quirks(
 
     debug!("fetching authentication info and allowlists");
 
-    // check allowed list
-    let allowed_ips = if config.ip_allowlist_check_enabled {
-        let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
-        if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
-            return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
-        }
-        allowed_ips
-    } else {
-        Cached::new_uncached(Arc::new(vec![]))
-    };
+    let access_controls = api
+        .get_endpoint_access_control(ctx, &info.endpoint, &info.user)
+        .await?;
 
-    // check if a VPC endpoint ID is coming in and if yes, if it's allowed
-    let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
-    if config.is_vpc_acccess_proxy {
-        if access_blocks.vpc_access_blocked {
-            return Err(AuthError::NetworkNotAllowed);
-        }
+    access_controls.check(
+        ctx,
+        config.ip_allowlist_check_enabled,
+        config.is_vpc_acccess_proxy,
+    )?;
 
-        let incoming_vpc_endpoint_id = match ctx.extra() {
-            None => return Err(AuthError::MissingEndpointName),
-            Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
-            Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
-        };
-        let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
-        // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
-        if !allowed_vpc_endpoint_ids.is_empty()
-            && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
-        {
-            return Err(AuthError::vpc_endpoint_id_not_allowed(
-                incoming_vpc_endpoint_id,
-            ));
-        }
-    } else if access_blocks.public_access_blocked {
-        return Err(AuthError::NetworkNotAllowed);
-    }
-
-    if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
+    let endpoint = EndpointIdInt::from(&info.endpoint);
+    let rate_limit_config = None;
+    if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
         return Err(AuthError::too_many_connections());
     }
 
-    let cached_secret = api.get_role_secret(ctx, &info).await?;
-    let (cached_entry, secret) = cached_secret.take_value();
+    let role_access = api
+        .get_role_access_control(ctx, &info.endpoint, &info.user)
+        .await?;
 
-    let secret = if let Some(secret) = secret {
-        config.check_rate_limit(
-            ctx,
-            secret,
-            &info.endpoint,
-            unauthenticated_password.is_some() || allow_cleartext,
-        )?
+    let secret = if let Some(secret) = role_access.secret {
+        secret
     } else {
         // If we don't have an authentication secret, we mock one to
         // prevent malicious probing (possible due to missing protocol steps).
@@ -369,14 +262,8 @@ async fn auth_quirks(
     )
     .await
     {
-        Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
-        Err(e) => {
-            if e.is_password_failed() {
-                // The password could have been changed, so we invalidate the cache.
-                cached_entry.invalidate();
-            }
-            Err(e)
-        }
+        Ok(keys) => Ok(keys),
+        Err(e) => Err(e),
     }
 }
 
@@ -439,7 +326,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
         allow_cleartext: bool,
         config: &'static AuthenticationConfig,
         endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-    ) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
+    ) -> auth::Result<Backend<'a, ComputeCredentials>> {
         let res = match self {
             Self::ControlPlane(api, user_info) => {
                 debug!(
@@ -448,17 +335,35 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
                     "performing authentication using the console"
                 );
 
-                let (credentials, ip_allowlist) = auth_quirks(
+                let auth_res = auth_quirks(
                     ctx,
                     &*api,
-                    user_info,
+                    user_info.clone(),
                     client,
                     allow_cleartext,
                     config,
                     endpoint_rate_limiter,
                 )
-                .await?;
-                Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
+                .await;
+                match auth_res {
+                    Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)),
+                    Err(e) => {
+                        // The password could have been changed, so we invalidate the cache.
+                        // We should only invalidate the cache if the TTL might have expired.
+                        if e.is_password_failed() {
+                            #[allow(irrefutable_let_patterns)]
+                            if let ControlPlaneClient::ProxyV1(api) = &*api {
+                                if let Some(ep) = &user_info.endpoint_id {
+                                    api.caches
+                                        .project_info
+                                        .maybe_invalidate_role_secret(ep, &user_info.user);
+                                }
+                            }
+                        }
+
+                        Err(e)
+                    }
+                }
             }
             Self::Local(_) => {
                 return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
@@ -475,44 +380,30 @@ impl Backend<'_, ComputeUserInfo> {
     pub(crate) async fn get_role_secret(
         &self,
         ctx: &RequestContext,
-    ) -> Result<CachedRoleSecret, GetAuthInfoError> {
-        match self {
-            Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
-            Self::Local(_) => Ok(Cached::new_uncached(None)),
-        }
-    }
-
-    pub(crate) async fn get_allowed_ips(
-        &self,
-        ctx: &RequestContext,
-    ) -> Result<CachedAllowedIps, GetAuthInfoError> {
-        match self {
-            Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
-            Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
-        }
-    }
-
-    pub(crate) async fn get_allowed_vpc_endpoint_ids(
-        &self,
-        ctx: &RequestContext,
-    ) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
+    ) -> Result<RoleAccessControl, GetAuthInfoError> {
         match self {
             Self::ControlPlane(api, user_info) => {
-                api.get_allowed_vpc_endpoint_ids(ctx, user_info).await
+                api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user)
+                    .await
             }
-            Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
+            Self::Local(_) => Ok(RoleAccessControl { secret: None }),
         }
     }
 
-    pub(crate) async fn get_block_public_or_vpc_access(
+    pub(crate) async fn get_endpoint_access_control(
         &self,
         ctx: &RequestContext,
-    ) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
+    ) -> Result<EndpointAccessControl, GetAuthInfoError> {
         match self {
             Self::ControlPlane(api, user_info) => {
-                api.get_block_public_or_vpc_access(ctx, user_info).await
+                api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user)
+                    .await
             }
-            Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())),
+            Self::Local(_) => Ok(EndpointAccessControl {
+                allowed_ips: Arc::new(vec![]),
+                allowed_vpce: Arc::new(vec![]),
+                flags: AccessBlockerFlags::default(),
+            }),
         }
     }
 }
@@ -541,9 +432,7 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
 mod tests {
     #![allow(clippy::unimplemented, clippy::unwrap_used)]
 
-    use std::net::IpAddr;
     use std::sync::Arc;
-    use std::time::Duration;
 
     use bytes::BytesMut;
     use control_plane::AuthSecret;
     use postgres_protocol::message::frontend;
     use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
 
+    use super::auth_quirks;
     use super::jwt::JwkCache;
-    use super::{AuthRateLimiter, auth_quirks};
-    use crate::auth::backend::MaskedIp;
     use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
     use crate::config::AuthenticationConfig;
     use crate::context::RequestContext;
     use crate::control_plane::{
-        self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps,
-        CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret,
+        self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
     };
     use crate::proxy::NeonOptions;
-    use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
+    use crate::rate_limiter::EndpointRateLimiter;
    use crate::scram::ServerSecret;
     use crate::scram::threadpool::ThreadPool;
     use crate::stream::{PqStream, Stream};
@@ -578,46 +465,34 @@ mod tests {
     }
 
     impl control_plane::ControlPlaneApi for Auth {
-        async fn get_role_secret(
+        async fn get_role_access_control(
             &self,
             _ctx: &RequestContext,
-            _user_info: &super::ComputeUserInfo,
-        ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
-            Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
+            _endpoint: &crate::types::EndpointId,
+            _role: &crate::types::RoleName,
+        ) -> Result<RoleAccessControl, control_plane::errors::GetAuthInfoError> {
+            Ok(RoleAccessControl {
+                secret: Some(self.secret.clone()),
+            })
         }
 
-        async fn get_allowed_ips(
+        async fn get_endpoint_access_control(
             &self,
             _ctx: &RequestContext,
-            _user_info: &super::ComputeUserInfo,
-        ) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
-            Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())))
-        }
-
-        async fn get_allowed_vpc_endpoint_ids(
-            &self,
-            _ctx: &RequestContext,
-            _user_info: &super::ComputeUserInfo,
-        ) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
-            Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(
-                self.vpc_endpoint_ids.clone(),
-            )))
-        }
-
-        async fn get_block_public_or_vpc_access(
-            &self,
-            _ctx: &RequestContext,
-            _user_info: &super::ComputeUserInfo,
-        ) -> Result<CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError> {
-            Ok(CachedAccessBlockerFlags::new_uncached(
-                self.access_blocker_flags.clone(),
-            ))
+            _endpoint: &crate::types::EndpointId,
+            _role: &crate::types::RoleName,
+        ) -> Result<EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
+            Ok(EndpointAccessControl {
+                allowed_ips: Arc::new(self.ips.clone()),
+                allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
+                flags: self.access_blocker_flags,
+            })
        }
 
         async fn get_endpoint_jwks(
             &self,
             _ctx: &RequestContext,
-            _endpoint: crate::types::EndpointId,
+            _endpoint: &crate::types::EndpointId,
         ) -> Result<Vec<AuthRule>, control_plane::errors::GetEndpointJwksError>
         {
             unimplemented!()
         }
@@ -636,9 +511,6 @@ mod tests {
             jwks_cache: JwkCache::default(),
             thread_pool: ThreadPool::new(1),
             scram_protocol_timeout: std::time::Duration::from_secs(5),
-            rate_limiter_enabled: true,
-            rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
-            rate_limit_ip_subnet: 64,
             ip_allowlist_check_enabled: true,
             is_vpc_acccess_proxy: false,
             is_auth_broker: false,
@@ -655,51 +527,6 @@ mod tests {
         }
     }
 
-    #[test]
-    fn masked_ip() {
-        let ip_a = IpAddr::V4([127, 0, 0, 1].into());
-        let ip_b = IpAddr::V4([127, 0, 0, 2].into());
-        let ip_c = IpAddr::V4([192, 168, 1, 101].into());
-        let ip_d = IpAddr::V4([192, 168, 1, 102].into());
-        let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
-        let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
-
-        assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
-        assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
-        assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
-        assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
-
-        assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
-        assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
-    }
-
-    #[test]
-    fn test_default_auth_rate_limit_set() {
-        // these values used to exceed u32::MAX
-        assert_eq!(
-            RateBucketInfo::DEFAULT_AUTH_SET,
-            [
-                RateBucketInfo {
-                    interval: Duration::from_secs(1),
-                    max_rpi: 1000 * 4096,
-                },
-                RateBucketInfo {
-                    interval: Duration::from_secs(60),
-                    max_rpi: 600 * 4096 * 60,
-                },
-                RateBucketInfo {
-                    interval: Duration::from_secs(600),
-                    max_rpi: 300 * 4096 * 600,
-                }
-            ]
-        );
-
-        for x in RateBucketInfo::DEFAULT_AUTH_SET {
-            let y = x.to_string().parse().unwrap();
-            assert_eq!(x, y);
-        }
-    }
-
     #[tokio::test]
     async fn auth_quirks_scram() {
         let (mut client, server) = tokio::io::duplex(1024);
@@ -888,7 +715,7 @@ mod tests {
             .await
             .unwrap();
 
-        assert_eq!(creds.0.info.endpoint, "my-endpoint");
+        assert_eq!(creds.info.endpoint, "my-endpoint");
 
         handle.await.unwrap();
     }
diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs
index a566383390..ba10fce7b4 100644
--- a/proxy/src/binary/local_proxy.rs
+++ b/proxy/src/binary/local_proxy.rs
@@ -32,9 +32,7 @@ use crate::ext::TaskExt;
 use crate::http::health_server::AppMetrics;
 use crate::intern::RoleNameInt;
 use crate::metrics::{Metrics, ThreadPoolMetrics};
-use crate::rate_limiter::{
-    BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
-};
+use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
 use crate::scram::threadpool::ThreadPool;
 use crate::serverless::cancel_set::CancelSet;
 use crate::serverless::{self, GlobalConnPoolOptions};
@@ -69,15 +67,6 @@ struct LocalProxyCliArgs {
     /// Can be given multiple times for different bucket sizes.
     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
     user_rps_limit: Vec<RateBucketInfo>,
-    /// Whether the auth rate limiter actually takes effect (for testing)
-    #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
-    auth_rate_limit_enabled: bool,
-    /// Authentication rate limiter max number of hashes per second.
-    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
-    auth_rate_limit: Vec<RateBucketInfo>,
-    /// The IP subnet to use when considering whether two IP addresses are considered the same.
-    #[clap(long, default_value_t = 64)]
-    auth_rate_limit_ip_subnet: u8,
     /// Whether to retry the connection to the compute node
     #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
     connect_to_compute_retry: String,
@@ -282,9 +271,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig>
         jwks_cache: JwkCache::default(),
         thread_pool: ThreadPool::new(0),
         scram_protocol_timeout: Duration::from_secs(10),
-        rate_limiter_enabled: false,
-        rate_limiter: BucketRateLimiter::new(vec![]),
-        rate_limit_ip_subnet: 64,
         ip_allowlist_check_enabled: true,
         is_vpc_acccess_proxy: false,
         is_auth_broker: false,
diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs
index 9a3903ba9a..dcae263647 100644
--- a/proxy/src/binary/proxy.rs
+++ b/proxy/src/binary/proxy.rs
@@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry;
 use utils::{project_build_tag, project_git_version};
 
 use crate::auth::backend::jwt::JwkCache;
-use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
+use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
 use crate::cancellation::{CancellationHandler, handle_cancel_messages};
 use crate::config::{
     self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
@@ -29,9 +29,7 @@ use crate::config::{
 use crate::context::parquet::ParquetUploadArgs;
 use crate::http::health_server::AppMetrics;
 use crate::metrics::Metrics;
-use crate::rate_limiter::{
-    EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
-};
+use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
 use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
 use crate::redis::kv_ops::RedisKVClient;
 use crate::redis::{elasticache, notifications};
@@ -154,15 +152,6 @@ struct ProxyCliArgs {
     /// Wake compute rate limiter max number of requests per second.
     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
     wake_compute_limit: Vec<RateBucketInfo>,
-    /// Whether the auth rate limiter actually takes effect (for testing)
-    #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
-    auth_rate_limit_enabled: bool,
-    /// Authentication rate limiter max number of hashes per second.
-    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
-    auth_rate_limit: Vec<RateBucketInfo>,
-    /// The IP subnet to use when considering whether two IP addresses are considered the same.
-    #[clap(long, default_value_t = 64)]
-    auth_rate_limit_ip_subnet: u8,
     /// Redis rate limiter max number of requests per second.
     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
     redis_rps_limit: Vec<RateBucketInfo>,
@@ -410,22 +399,9 @@ pub async fn run() -> anyhow::Result<()> {
         Some(tx_cancel),
     ));
 
-    // bit of a hack - find the min rps and max rps supported and turn it into
-    // leaky bucket config instead
-    let max = args
-        .endpoint_rps_limit
-        .iter()
-        .map(|x| x.rps())
-        .max_by(f64::total_cmp)
-        .unwrap_or(EndpointRateLimiter::DEFAULT.max);
-    let rps = args
-        .endpoint_rps_limit
-        .iter()
-        .map(|x| x.rps())
-        .min_by(f64::total_cmp)
-        .unwrap_or(EndpointRateLimiter::DEFAULT.rps);
     let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
-        LeakyBucketConfig { rps, max },
+        RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
+            .unwrap_or(EndpointRateLimiter::DEFAULT),
         64,
     ));
 
@@ -678,9 +654,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
         jwks_cache: JwkCache::default(),
         thread_pool,
         scram_protocol_timeout: args.scram_protocol_timeout,
-        rate_limiter_enabled: args.auth_rate_limit_enabled,
-        rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
-        rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
         ip_allowlist_check_enabled: !args.is_private_access_proxy,
         is_vpc_acccess_proxy: args.is_private_access_proxy,
         is_auth_broker: args.is_auth_broker,
diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs
index 60678b034d..81c88e3ddd 100644
--- a/proxy/src/cache/project_info.rs
+++ b/proxy/src/cache/project_info.rs
@@ -1,30 +1,25 @@
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet, hash_map};
 use std::convert::Infallible;
-use std::sync::Arc;
 use std::sync::atomic::AtomicU64;
 use std::time::Duration;
 
 use async_trait::async_trait;
 use clashmap::ClashMap;
+use clashmap::mapref::one::Ref;
 use rand::{Rng, thread_rng};
-use smol_str::SmolStr;
 use tokio::sync::Mutex;
 use tokio::time::Instant;
 use tracing::{debug, info};
 
-use super::{Cache, Cached};
-use crate::auth::IpPattern;
 use crate::config::ProjectInfoCacheOptions;
-use crate::control_plane::{AccessBlockerFlags, AuthSecret};
+use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
 use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
 use crate::types::{EndpointId, RoleName};
 
 #[async_trait]
 pub(crate) trait ProjectInfoCache {
-    fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
-    fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
-    fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
-    fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
+    fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
+    fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
     async fn decrement_active_listeners(&self);
     async fn increment_active_listeners(&self);
@@ -42,6 +37,10 @@ impl<T> Entry<T> {
             value,
         }
     }
+
+    pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> {
+        (valid_since < self.created_at).then_some(&self.value)
+    }
 }
 
 impl<T> From<T> for Entry<T> {
 }
 
-#[derive(Default)]
 struct EndpointInfo {
-    secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
-    allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
-    block_public_or_vpc_access: Option<Entry<AccessBlockerFlags>>,
-    allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
+    role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
+    controls: Option<Entry<EndpointAccessControl>>,
 }
 
 impl EndpointInfo {
-    fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
-        match ignore_cache_since {
-            None => false,
-            Some(t) => t < created_at,
-        }
-    }
     pub(crate) fn get_role_secret(
         &self,
         role_name: RoleNameInt,
         valid_since: Instant,
-        ignore_cache_since: Option<Instant>,
-    ) -> Option<(Option<AuthSecret>, bool)> {
-        if let Some(secret) = self.secret.get(&role_name) {
-            if valid_since < secret.created_at {
-                return Some((
-                    secret.value.clone(),
-                    Self::check_ignore_cache(ignore_cache_since, secret.created_at),
-                ));
-            }
-        }
-        None
+    ) -> Option<RoleAccessControl> {
+        let controls = self.role_controls.get(&role_name)?;
+        controls.get(valid_since).cloned()
     }
 
-    pub(crate) fn get_allowed_ips(
-        &self,
-        valid_since: Instant,
-        ignore_cache_since: Option<Instant>,
-    ) -> Option<(Arc<Vec<IpPattern>>, bool)> {
-        if let Some(allowed_ips) = &self.allowed_ips {
-            if valid_since < allowed_ips.created_at {
-                return Some((
-                    allowed_ips.value.clone(),
-                    Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
-                ));
-            }
-        }
-        None
-    }
-    pub(crate) fn get_allowed_vpc_endpoint_ids(
-        &self,
-        valid_since: Instant,
-        ignore_cache_since: Option<Instant>,
-    ) -> Option<(Arc<Vec<String>>, bool)> {
-        if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
-            if valid_since < allowed_vpc_endpoint_ids.created_at {
-                return Some((
-                    allowed_vpc_endpoint_ids.value.clone(),
-                    Self::check_ignore_cache(
-                        ignore_cache_since,
-                        allowed_vpc_endpoint_ids.created_at,
-                    ),
-                ));
-            }
-        }
-        None
-    }
-    pub(crate) fn get_block_public_or_vpc_access(
-        &self,
-        valid_since: Instant,
-        ignore_cache_since: Option<Instant>,
-    ) -> Option<(AccessBlockerFlags, bool)> {
-        if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
-            if valid_since < block_public_or_vpc_access.created_at {
-                return Some((
-                    block_public_or_vpc_access.value.clone(),
-                    Self::check_ignore_cache(
-                        ignore_cache_since,
-                        block_public_or_vpc_access.created_at,
-                    ),
-                ));
-            }
-        }
-        None
+    pub(crate) fn get_controls(&self, valid_since: Instant) -> Option<EndpointAccessControl> {
+        let controls = self.controls.as_ref()?;
+        controls.get(valid_since).cloned()
     }
 
-    pub(crate) fn invalidate_allowed_ips(&mut self) {
-        self.allowed_ips = None;
-    }
-    pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
-        self.allowed_vpc_endpoint_ids = None;
-    }
-    pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
-        self.block_public_or_vpc_access = None;
+    pub(crate) fn invalidate_endpoint(&mut self) {
+        self.controls = None;
     }
+
     pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
-        self.secret.remove(&role_name);
+        self.role_controls.remove(&role_name);
     }
 }
 
@@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl {
 
 #[async_trait]
 impl ProjectInfoCache for ProjectInfoCacheImpl {
-    fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
-        info!(
-            "invalidating allowed vpc endpoint ids for projects `{}`",
-            project_ids
-                .iter()
-                .map(|id| id.to_string())
-                .collect::<Vec<_>>()
-                .join(", ")
-        );
-        for project_id in project_ids {
-            let endpoints = self
-                .project2ep
-                .get(&project_id)
-                .map(|kv| kv.value().clone())
-                .unwrap_or_default();
-            for endpoint_id in endpoints {
-                if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
-                    endpoint_info.invalidate_allowed_vpc_endpoint_ids();
-                }
+    fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
+        info!("invalidating endpoint access for project `{project_id}`");
+        let endpoints = self
+            .project2ep
+            .get(&project_id)
+            .map(|kv| kv.value().clone())
+            .unwrap_or_default();
+        for endpoint_id in endpoints {
+            if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
+                endpoint_info.invalidate_endpoint();
             }
         }
     }
 
-    fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
-        info!(
-            "invalidating allowed vpc endpoint ids for org `{}`",
-            account_id
-        );
+    fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
+        info!("invalidating endpoint access for org `{account_id}`");
         let endpoints = self
             .account2ep
@@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
             .unwrap_or_default();
         for endpoint_id in endpoints {
             if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
-                endpoint_info.invalidate_allowed_vpc_endpoint_ids();
+                endpoint_info.invalidate_endpoint();
             }
         }
     }
 
-    fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
-        info!(
-            "invalidating block public or vpc access for project `{}`",
-            project_id
-        );
-        let endpoints = self
-            .project2ep
-            .get(&project_id)
-            .map(|kv| kv.value().clone())
-            .unwrap_or_default();
-        for endpoint_id in endpoints {
-            if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
-                endpoint_info.invalidate_block_public_or_vpc_access();
-            }
-        }
-    }
-
-    fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
-        info!("invalidating allowed ips for project `{}`", project_id);
-        let endpoints = self
-            .project2ep
-            .get(&project_id)
-            .map(|kv| kv.value().clone())
-            .unwrap_or_default();
-        for endpoint_id in endpoints {
-            if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
-                endpoint_info.invalidate_allowed_ips();
-            }
-        }
-    }
     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
         info!(
             "invalidating role secret for project_id `{}` and role_name `{}`",
@@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
             }
         }
     }
+
     async fn decrement_active_listeners(&self) {
         let mut listeners_guard = self.active_listeners_lock.lock().await;
         if *listeners_guard == 0 {
@@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl {
         }
     }
 
+    fn get_endpoint_cache(
+        &self,
+        endpoint_id: &EndpointId,
+    ) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
+        let endpoint_id = EndpointIdInt::get(endpoint_id)?;
+        self.cache.get(&endpoint_id)
+    }
+
     pub(crate) fn get_role_secret(
         &self,
         endpoint_id: &EndpointId,
         role_name: &RoleName,
-    ) -> Option<Cached<&Self, Option<AuthSecret>>> {
-        let endpoint_id = EndpointIdInt::get(endpoint_id)?;
+    ) -> Option<RoleAccessControl> {
+        let valid_since = self.get_cache_times();
         let role_name = RoleNameInt::get(role_name)?;
-        let (valid_since, ignore_cache_since) = self.get_cache_times();
-        let endpoint_info = self.cache.get(&endpoint_id)?;
-        let (value, ignore_cache) =
-            endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
-        if !ignore_cache {
-            let cached = Cached {
-                token: Some((
-                    self,
-                    CachedLookupInfo::new_role_secret(endpoint_id, role_name),
-                )),
-                value,
-            };
-            return Some(cached);
-        }
-        Some(Cached::new_uncached(value))
-    }
-    pub(crate) fn get_allowed_ips(
-        &self,
-        endpoint_id: &EndpointId,
-    ) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
-        let endpoint_id = EndpointIdInt::get(endpoint_id)?;
-        let (valid_since, ignore_cache_since) = self.get_cache_times();
-        let endpoint_info = self.cache.get(&endpoint_id)?;
-        let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
-        let (value, ignore_cache) = value?;
-        if !ignore_cache {
-            let cached = Cached {
-                token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
-                value,
-            };
-            return Some(cached);
-        }
-        Some(Cached::new_uncached(value))
-    }
-    pub(crate) fn get_allowed_vpc_endpoint_ids(
-        &self,
-        endpoint_id: &EndpointId,
-    ) -> Option<Cached<&Self, Arc<Vec<String>>>> {
-        let endpoint_id = EndpointIdInt::get(endpoint_id)?;
-        let (valid_since, ignore_cache_since) = self.get_cache_times();
-        let endpoint_info = self.cache.get(&endpoint_id)?;
-        let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since);
-        let (value, ignore_cache) = value?;
-        if !ignore_cache {
-            let cached = Cached {
-                token: Some((
-                    self,
-                    CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id),
-                )),
-                value,
-            };
-            return Some(cached);
-        }
-        Some(Cached::new_uncached(value))
-    }
-    pub(crate) fn get_block_public_or_vpc_access(
-        &self,
-        endpoint_id: &EndpointId,
-    ) -> Option<Cached<&Self, AccessBlockerFlags>> {
-        let endpoint_id = EndpointIdInt::get(endpoint_id)?;
-        let (valid_since, ignore_cache_since) = self.get_cache_times();
-        let endpoint_info = self.cache.get(&endpoint_id)?;
-        let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since);
-        let (value, ignore_cache) = value?;
-        if !ignore_cache {
-            let cached = Cached {
-                token: Some((
-                    self,
-                    CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id),
-                )),
-                value,
-            };
-            return Some(cached);
-        }
-        Some(Cached::new_uncached(value))
+        let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
+        endpoint_info.get_role_secret(role_name, valid_since)
    }
 
-    pub(crate) fn insert_role_secret(
+    pub(crate) fn get_endpoint_access(
         &self,
-        project_id: ProjectIdInt,
-        endpoint_id: EndpointIdInt,
-        role_name: RoleNameInt,
-        secret: Option<AuthSecret>,
-    ) {
-        if self.cache.len() >= self.config.size {
-            // If there are too many entries, wait until the next gc cycle.
-            return;
-        }
-        self.insert_project2endpoint(project_id, endpoint_id);
-        let mut entry = self.cache.entry(endpoint_id).or_default();
-        if entry.secret.len() < self.config.max_roles {
-            entry.secret.insert(role_name, secret.into());
-        }
+        endpoint_id: &EndpointId,
+    ) -> Option<EndpointAccessControl> {
+        let valid_since = self.get_cache_times();
+        let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
+        endpoint_info.get_controls(valid_since)
     }
 
-    pub(crate) fn insert_allowed_ips(
-        &self,
-        project_id: ProjectIdInt,
-        endpoint_id: EndpointIdInt,
-        allowed_ips: Arc<Vec<IpPattern>>,
-    ) {
-        if self.cache.len() >= self.config.size {
-            // If there are too many entries, wait until the next gc cycle.
-            return;
-        }
-        self.insert_project2endpoint(project_id, endpoint_id);
-        self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
-    }
-    pub(crate) fn insert_allowed_vpc_endpoint_ids(
+    pub(crate) fn insert_endpoint_access(
         &self,
         account_id: Option<AccountIdInt>,
         project_id: ProjectIdInt,
         endpoint_id: EndpointIdInt,
-        allowed_vpc_endpoint_ids: Arc<Vec<String>>,
+        role_name: RoleNameInt,
+        controls: EndpointAccessControl,
+        role_controls: RoleAccessControl,
     ) {
-        if self.cache.len() >= self.config.size {
-            // If there are too many entries, wait until the next gc cycle.
-            return;
-        }
         if let Some(account_id) = account_id {
             self.insert_account2endpoint(account_id, endpoint_id);
         }
         self.insert_project2endpoint(project_id, endpoint_id);
-        self.cache
-            .entry(endpoint_id)
-            .or_default()
-            .allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into());
-    }
-    pub(crate) fn insert_block_public_or_vpc_access(
-        &self,
-        project_id: ProjectIdInt,
-        endpoint_id: EndpointIdInt,
-        access_blockers: AccessBlockerFlags,
-    ) {
+
         if self.cache.len() >= self.config.size {
             // If there are too many entries, wait until the next gc cycle.
             return;
         }
-        self.insert_project2endpoint(project_id, endpoint_id);
-        self.cache
-            .entry(endpoint_id)
-            .or_default()
-            .block_public_or_vpc_access = Some(access_blockers.into());
+
+        let controls = Entry::from(controls);
+        let role_controls = Entry::from(role_controls);
+
+        match self.cache.entry(endpoint_id) {
+            clashmap::Entry::Vacant(e) => {
+                e.insert(EndpointInfo {
+                    role_controls: HashMap::from_iter([(role_name, role_controls)]),
+                    controls: Some(controls),
+                });
+            }
+            clashmap::Entry::Occupied(mut e) => {
+                let ep = e.get_mut();
+                ep.controls = Some(controls);
+                if ep.role_controls.len() < self.config.max_roles {
+                    ep.role_controls.insert(role_name, role_controls);
+                }
+            }
+        }
     }
 
     fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
@@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl {
             .insert(project_id, HashSet::from([endpoint_id]));
         }
     }
+
     fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
         if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
             endpoints.insert(endpoint_id);
@@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl {
             .insert(account_id, HashSet::from([endpoint_id]));
         }
     }
-    fn get_cache_times(&self) -> (Instant, Option<Instant>) {
-        let mut valid_since = Instant::now() - self.config.ttl;
-        // Only ignore cache if ttl is disabled.
+
+    fn ignore_ttl_since(&self) -> Option<Instant> {
         let ttl_disabled_since_us = self
             .ttl_disabled_since_us
             .load(std::sync::atomic::Ordering::Relaxed);
-        let ignore_cache_since = if ttl_disabled_since_us == u64::MAX {
-            None
-        } else {
-            let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
+
+        if ttl_disabled_since_us == u64::MAX {
+            return None;
+        }
+
+        Some(self.start_time + Duration::from_micros(ttl_disabled_since_us))
+    }
+
+    fn get_cache_times(&self) -> Instant {
+        let mut valid_since = Instant::now() - self.config.ttl;
+        if let Some(ignore_ttl_since) = self.ignore_ttl_since() {
             // We are fine if entry is not older than ttl or was added before we are getting notifications.
-            valid_since = valid_since.min(ignore_cache_since);
-            Some(ignore_cache_since)
+            valid_since = valid_since.min(ignore_ttl_since);
+        }
+        valid_since
+    }
+
+    pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
+        let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
+            return;
         };
-        (valid_since, ignore_cache_since)
+        let Some(role_name) = RoleNameInt::get(role_name) else {
+            return;
+        };
+
+        let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
+            return;
+        };
+
+        let entry = endpoint_info.role_controls.entry(role_name);
+        let hash_map::Entry::Occupied(role_controls) = entry else {
+            return;
+        };
+
+        let created_at = role_controls.get().created_at;
+        let expire = match self.ignore_ttl_since() {
+            // if ignoring TTL, we should still try and roll the password if it's old
+            // and the client gave an incorrect password. There could be some lag on the redis channel.
+            Some(_) => created_at + self.config.ttl < Instant::now(),
+            // edge case: redis is down, let's be generous and invalidate the cache immediately.
+            None => true,
+        };
+
+        if expire {
+            role_controls.remove();
+        }
     }
 
     pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
@@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl {
         }
     }
 }
 
-/// Lookup info for project info cache.
-/// This is used to invalidate cache entries.
-pub(crate) struct CachedLookupInfo {
-    /// Search by this key.
-    endpoint_id: EndpointIdInt,
-    lookup_type: LookupType,
-}
-
-impl CachedLookupInfo {
-    pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
-        Self {
-            endpoint_id,
-            lookup_type: LookupType::RoleSecret(role_name),
-        }
-    }
-    pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
-        Self {
-            endpoint_id,
-            lookup_type: LookupType::AllowedIps,
-        }
-    }
-    pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self {
-        Self {
-            endpoint_id,
-            lookup_type: LookupType::AllowedVpcEndpointIds,
-        }
-    }
-    pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self {
-        Self {
-            endpoint_id,
-            lookup_type: LookupType::BlockPublicOrVpcAccess,
-        }
-    }
-}
-
-enum LookupType {
-    RoleSecret(RoleNameInt),
-    AllowedIps,
-    AllowedVpcEndpointIds,
-    BlockPublicOrVpcAccess,
-}
-
-impl Cache for ProjectInfoCacheImpl {
-    type Key = SmolStr;
-    // Value is not really used here, but we need to specify it.
-    type Value = SmolStr;
-
-    type LookupInfo = CachedLookupInfo;
-
-    fn invalidate(&self, key: &Self::LookupInfo) {
-        match &key.lookup_type {
-            LookupType::RoleSecret(role_name) => {
-                if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
-                    endpoint_info.invalidate_role_secret(*role_name);
-                }
-            }
-            LookupType::AllowedIps => {
-                if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
-                    endpoint_info.invalidate_allowed_ips();
-                }
-            }
-            LookupType::AllowedVpcEndpointIds => {
-                if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
-                    endpoint_info.invalidate_allowed_vpc_endpoint_ids();
-                }
-            }
-            LookupType::BlockPublicOrVpcAccess => {
-                if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
-                    endpoint_info.invalidate_block_public_or_vpc_access();
-                }
-            }
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
+
     use super::*;
+    use crate::control_plane::{AccessBlockerFlags, AuthSecret};
     use crate::scram::ServerSecret;
     use crate::types::ProjectId;
 
@@ -601,6 +371,8 @@ mod tests {
         });
         let project_id: ProjectId = "project".into();
         let endpoint_id: EndpointId = "endpoint".into();
+        let account_id: Option<AccountIdInt> = None;
+
         let user1: RoleName = "user1".into();
         let user2: RoleName = "user2".into();
         let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
@@ -609,183 +381,73 @@ mod tests {
             "127.0.0.1".parse().unwrap(),
             "127.0.0.2".parse().unwrap(),
         ]);
-        cache.insert_role_secret(
+
+        cache.insert_endpoint_access(
+            account_id,
             (&project_id).into(),
             (&endpoint_id).into(),
             (&user1).into(),
-            secret1.clone(),
+            EndpointAccessControl {
+                allowed_ips: allowed_ips.clone(),
+                allowed_vpce: Arc::new(vec![]),
+                flags: AccessBlockerFlags::default(),
+            },
+            RoleAccessControl {
+                secret: secret1.clone(),
+            },
         );
-        cache.insert_role_secret(
+
+        cache.insert_endpoint_access(
+            account_id,
             (&project_id).into(),
             (&endpoint_id).into(),
             (&user2).into(),
-            secret2.clone(),
-        );
-        cache.insert_allowed_ips(
-            (&project_id).into(),
-            (&endpoint_id).into(),
-            allowed_ips.clone(),
+            EndpointAccessControl {
+                allowed_ips: allowed_ips.clone(),
+                allowed_vpce: Arc::new(vec![]),
+                flags: AccessBlockerFlags::default(),
+            },
+            RoleAccessControl {
+                secret: secret2.clone(),
+            },
         );
 
         let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
-        assert!(cached.cached());
-        assert_eq!(cached.value, secret1);
+        assert_eq!(cached.secret, secret1);
+
         let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
-        assert!(cached.cached());
-        assert_eq!(cached.value, secret2);
+        assert_eq!(cached.secret, secret2);
 
         // Shouldn't add more than 2 roles.
         let user3: RoleName = "user3".into();
         let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
-        cache.insert_role_secret(
+
+        cache.insert_endpoint_access(
+            account_id,
             (&project_id).into(),
             (&endpoint_id).into(),
             (&user3).into(),
-            secret3.clone(),
+            EndpointAccessControl {
+                allowed_ips: allowed_ips.clone(),
+                allowed_vpce: Arc::new(vec![]),
+                flags: AccessBlockerFlags::default(),
+            },
+            RoleAccessControl {
+                secret: secret3.clone(),
+            },
         );
+
         assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
 
-        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
-        assert!(cached.cached());
-        assert_eq!(cached.value, allowed_ips);
+        let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
+        assert_eq!(cached.allowed_ips, allowed_ips);
 
         tokio::time::advance(Duration::from_secs(2)).await;
         let cached = cache.get_role_secret(&endpoint_id, &user1);
         assert!(cached.is_none());
         let cached = cache.get_role_secret(&endpoint_id, &user2);
         assert!(cached.is_none());
-        let cached = cache.get_allowed_ips(&endpoint_id);
+        let cached = cache.get_endpoint_access(&endpoint_id);
         assert!(cached.is_none());
     }
-
-    #[tokio::test]
-    async fn test_project_info_cache_invalidations() {
-        tokio::time::pause();
-        let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
-            size: 2,
-            max_roles: 2,
-            ttl: Duration::from_secs(1),
-            gc_interval: Duration::from_secs(600),
-        }));
-        cache.clone().increment_active_listeners().await;
-        tokio::time::advance(Duration::from_secs(2)).await;
-
-        let project_id: ProjectId = "project".into();
-        let endpoint_id: EndpointId = "endpoint".into();
-        let user1: RoleName = "user1".into();
-        let user2: RoleName = "user2".into();
-        let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
-        let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
-        let allowed_ips = Arc::new(vec![
-            "127.0.0.1".parse().unwrap(),
-            "127.0.0.2".parse().unwrap(),
-        ]);
-        cache.insert_role_secret(
-            (&project_id).into(),
-            (&endpoint_id).into(),
-            (&user1).into(),
-            secret1.clone(),
-        );
-        cache.insert_role_secret(
-            (&project_id).into(),
-            (&endpoint_id).into(),
-            (&user2).into(),
-            secret2.clone(),
-        );
-        cache.insert_allowed_ips(
-            (&project_id).into(),
-            (&endpoint_id).into(),
-            allowed_ips.clone(),
-        );
-
-        tokio::time::advance(Duration::from_secs(2)).await;
-        // Nothing should be invalidated.
-
-        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
-        // TTL is disabled, so it should be impossible to invalidate this value.
-        assert!(!cached.cached());
-        assert_eq!(cached.value, secret1);
-
-        cached.invalidate(); // Shouldn't do anything.
-        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
-        assert_eq!(cached.value, secret1);
-
-        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
-        assert!(!cached.cached());
-        assert_eq!(cached.value, secret2);
-
-        // The only way to invalidate this value is to invalidate via the api.
- cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - assert_eq!(cached.value, allowed_ips); - } - - #[tokio::test] - async fn test_increment_active_listeners_invalidate_added_before() { - tokio::time::pause(); - let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, - max_roles: 2, - ttl: Duration::from_secs(1), - gc_interval: Duration::from_secs(600), - })); - - let project_id: ProjectId = "project".into(); - let endpoint_id: EndpointId = "endpoint".into(); - let user1: RoleName = "user1".into(); - let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); - let allowed_ips = Arc::new(vec![ - "127.0.0.1".parse().unwrap(), - "127.0.0.2".parse().unwrap(), - ]); - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user1).into(), - secret1.clone(), - ); - cache.clone().increment_active_listeners().await; - tokio::time::advance(Duration::from_millis(100)).await; - cache.insert_role_secret( - (&project_id).into(), - (&endpoint_id).into(), - (&user2).into(), - secret2.clone(), - ); - - // Added before ttl was disabled + ttl should be still cached. - let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); - assert!(cached.cached()); - let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); - assert!(cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - - // Added after ttl was disabled + ttl should not be cached. - cache.insert_allowed_ips( - (&project_id).into(), - (&endpoint_id).into(), - allowed_ips.clone(), - ); - let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); - assert!(!cached.cached()); - - tokio::time::advance(Duration::from_secs(1)).await; - // Added before ttl was disabled + ttl still should expire. - assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); - assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); - // Shouldn't be invalidated. 
-
-        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
-        assert!(!cached.cached());
-        assert_eq!(cached.value, allowed_ips);
-    }
 }
diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs
index 0bff901376..d26641db46 100644
--- a/proxy/src/cancellation.rs
+++ b/proxy/src/cancellation.rs
@@ -12,8 +12,8 @@ use tokio::net::TcpStream;
 use tokio::sync::{mpsc, oneshot};
 use tracing::{debug, error, info, warn};
 
+use crate::auth::AuthError;
 use crate::auth::backend::ComputeUserInfo;
-use crate::auth::{AuthError, check_peer_addr_is_in_list};
 use crate::config::ComputeConfig;
 use crate::context::RequestContext;
 use crate::control_plane::ControlPlaneApi;
@@ -21,7 +21,6 @@ use crate::error::ReportableError;
 use crate::ext::LockExt;
 use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
 use crate::pqproto::CancelKeyData;
-use crate::protocol2::ConnectionInfoExtra;
 use crate::rate_limiter::LeakyBucketRateLimiter;
 use crate::redis::keys::KeyPrefix;
 use crate::redis::kv_ops::RedisKVClient;
@@ -272,13 +271,7 @@ pub(crate) enum CancelError {
     #[error("rate limit exceeded")]
     RateLimit,
 
-    #[error("IP is not allowed")]
-    IpNotAllowed,
-
-    #[error("VPC endpoint id is not allowed to connect")]
-    VpcEndpointIdNotAllowed,
-
-    #[error("Authentication backend error")]
+    #[error("Authentication error")]
     AuthError(#[from] AuthError),
 
     #[error("key not found")]
@@ -297,10 +290,7 @@ impl ReportableError for CancelError {
             }
             CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
             CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
-            CancelError::IpNotAllowed
-            | CancelError::VpcEndpointIdNotAllowed
-            | CancelError::NotFound => crate::error::ErrorKind::User,
-            CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
+            CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User,
             CancelError::InternalError => crate::error::ErrorKind::Service,
         }
     }
@@ -422,7 +412,13 @@ impl CancellationHandler {
             IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use default mask here
             IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
         };
-        if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
+
+        let allowed = {
+            let rate_limit_config = None;
+            let limiter = self.limiter.lock_propagate_poison();
+            limiter.check(subnet_key, rate_limit_config, 1)
+        };
+        if !allowed {
             // log only the subnet part of the IP address to know which subnet is rate limited
             tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
             Metrics::get()
@@ -450,52 +446,13 @@ impl CancellationHandler {
             return Err(CancelError::NotFound);
         };
 
-        if check_ip_allowed {
-            let ip_allowlist = auth_backend
-                .get_allowed_ips(&ctx, &cancel_closure.user_info)
-                .await
-                .map_err(|e| CancelError::AuthError(e.into()))?;
-
-            if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
-                // log it here since cancel_session could be spawned in a task
-                tracing::warn!(
-                    "IP is not allowed to cancel the query: {key}, address: {}",
-                    ctx.peer_addr()
-                );
-                return Err(CancelError::IpNotAllowed);
-            }
-        }
-
-        // check if a VPC endpoint ID is coming in and if yes, if it's allowed
-        let access_blocks = auth_backend
-            .get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
+        let info = &cancel_closure.user_info;
+        let access_controls = auth_backend
+            .get_endpoint_access_control(&ctx, &info.endpoint, &info.user)
             .await
             .map_err(|e| CancelError::AuthError(e.into()))?;
 
-        if check_vpc_allowed {
-            if access_blocks.vpc_access_blocked {
-                return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
-            }
-
-            let incoming_vpc_endpoint_id = match ctx.extra() {
-                None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
-                Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
-                Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
-            };
-
-            let allowed_vpc_endpoint_ids = auth_backend
-                .get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
-                .await
-                .map_err(|e| CancelError::AuthError(e.into()))?;
-            // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
-            if !allowed_vpc_endpoint_ids.is_empty()
-                && !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
-            {
-                return Err(CancelError::VpcEndpointIdNotAllowed);
-            }
-        } else if access_blocks.public_access_blocked {
-            return Err(CancelError::VpcEndpointIdNotAllowed);
-        }
+        access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?;
 
         Metrics::get()
             .proxy
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index ad398c122c..a97339df9a 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -7,7 +7,6 @@ use arc_swap::ArcSwapOption;
 use clap::ValueEnum;
 use remote_storage::RemoteStorageConfig;
 
-use crate::auth::backend::AuthRateLimiter;
 use crate::auth::backend::jwt::JwkCache;
 use crate::control_plane::locks::ApiLocks;
 use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
@@ -65,9 +64,6 @@ pub struct HttpConfig {
 pub struct AuthenticationConfig {
     pub thread_pool: Arc<ThreadPool>,
     pub scram_protocol_timeout: tokio::time::Duration,
-    pub rate_limiter_enabled: bool,
-    pub rate_limiter: AuthRateLimiter,
-    pub rate_limit_ip_subnet: u8,
     pub ip_allowlist_check_enabled: bool,
     pub is_vpc_acccess_proxy: bool,
     pub jwks_cache: JwkCache,
diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs
index de4600951e..24268997ba 100644
--- a/proxy/src/context/mod.rs
+++ b/proxy/src/context/mod.rs
@@ -370,6 +370,18 @@ impl RequestContext {
         }
     }
 
+    pub(crate) fn latency_timer_pause_at(
+        &self,
+        at: tokio::time::Instant,
+        waiting_for: Waiting,
+    ) -> LatencyTimerPause<'_> {
+        LatencyTimerPause {
+            ctx: self,
+            start: at,
+            waiting_for,
+        }
+    }
+
     pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated {
         self.0
             .try_lock()
diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs
index 2765aaa462..93f4ea6cf7 100644
--- a/proxy/src/control_plane/client/cplane_proxy_v1.rs
+++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs
@@ -15,7 +15,6 @@ use tracing::{Instrument, debug, info, info_span, warn};
 use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
 use crate::auth::backend::ComputeUserInfo;
 use crate::auth::backend::jwt::AuthRule;
-use crate::cache::Cached;
 use crate::context::RequestContext;
 use crate::control_plane::caches::ApiCaches;
 use crate::control_plane::errors::{
@@ -24,12 +23,12 @@ use crate::control_plane::errors::{
 use crate::control_plane::locks::ApiLocks;
 use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
 use crate::control_plane::{
-    AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
-    CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo,
+    AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
+    RoleAccessControl,
 };
-use crate::metrics::{CacheOutcome, Metrics};
+use crate::metrics::Metrics;
 use crate::rate_limiter::WakeComputeRateLimiter;
-use crate::types::{EndpointCacheKey, EndpointId};
+use crate::types::{EndpointCacheKey, EndpointId, RoleName};
 use crate::{compute, http, scram};
 
 pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
@@ -66,65 +65,34 @@ impl NeonControlPlaneClient {
         self.endpoint.url().as_str()
     }
 
-    async fn do_get_auth_info(
-        &self,
-        ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result<AuthInfo, GetAuthInfoError> {
-        if !self
-            .caches
-            .endpoints_cache
-            .is_valid(ctx, &user_info.endpoint.normalize())
-        {
-            // TODO: refactor this because it's weird
-            // this is a failure to authenticate but we return Ok.
-            info!("endpoint is not valid, skipping the request");
-            return Ok(AuthInfo::default());
-        }
-        self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
-            .await
-    }
-
     async fn do_get_auth_req(
         &self,
-        user_info: &ComputeUserInfo,
-        session_id: &uuid::Uuid,
-        ctx: Option<&RequestContext>,
+        ctx: &RequestContext,
+        endpoint: &EndpointId,
+        role: &RoleName,
     ) -> Result<AuthInfo, GetAuthInfoError> {
-        let request_id: String = session_id.to_string();
-        let application_name = if let Some(ctx) = ctx {
-            ctx.console_application_name()
-        } else {
-            "auth_cancellation".to_string()
-        };
-
         async {
             let request = self
                 .endpoint
                 .get_path("get_endpoint_access_control")
-                .header(X_REQUEST_ID, &request_id)
+                .header(X_REQUEST_ID, ctx.session_id().to_string())
                 .header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
-                .query(&[("session_id", session_id)])
+                .query(&[("session_id", ctx.session_id())])
                 .query(&[
-                    ("application_name", application_name.as_str()),
-                    ("endpointish", user_info.endpoint.as_str()),
-                    ("role", user_info.user.as_str()),
+                    ("application_name", ctx.console_application_name().as_str()),
+                    ("endpointish", endpoint.as_str()),
+                    ("role", role.as_str()),
                 ])
                 .build()?;
 
             debug!(url = request.url().as_str(), "sending http request");
             let start = Instant::now();
-            let response = match ctx {
-                Some(ctx) => {
-                    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
-                    let rsp = self.endpoint.execute(request).await;
-                    drop(pause);
-                    rsp?
-                }
-                None => self.endpoint.execute(request).await?,
+            let response = {
+                let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
+                self.endpoint.execute(request).await?
             };
-            info!(duration = ?start.elapsed(), "received http response");
+
             let body = match parse_body::<GetEndpointAccessControl>(response).await {
                 Ok(body) => body,
                 // Error 404 is special: it's ok not to have a secret.
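
[Reviewer note — not part of the patch. The new `latency_timer_pause_at` differs from the pre-existing `latency_timer_pause` in that the caller supplies an earlier `Instant`, so the time spent building the request above is also attributed to the control-plane wait. A self-contained sketch of that backdated drop-guard pattern, using hypothetical simplified types (`LatencyTimer`, `Pause`) in place of the real `RequestContext`/`LatencyTimerPause`:

    use std::time::{Duration, Instant};

    struct LatencyTimer {
        cplane_wait: Duration,
    }

    struct Pause<'a> {
        timer: &'a mut LatencyTimer,
        start: Instant,
    }

    impl Drop for Pause<'_> {
        fn drop(&mut self) {
            // Charge everything since `start`, not merely since the guard was created.
            self.timer.cplane_wait += self.start.elapsed();
        }
    }

    impl LatencyTimer {
        fn pause_at(&mut self, at: Instant) -> Pause<'_> {
            Pause { timer: self, start: at }
        }
    }

    fn main() {
        let mut timer = LatencyTimer { cplane_wait: Duration::ZERO };
        let start = Instant::now();
        std::thread::sleep(Duration::from_millis(10)); // request-building work
        {
            let _pause = timer.pause_at(start); // backdated to `start`
            std::thread::sleep(Duration::from_millis(20)); // the actual wait
        }
        assert!(timer.cplane_wait >= Duration::from_millis(30));
        println!("cplane wait: {:?}", timer.cplane_wait);
    }

The real guard presumably records the interval into the request's latency accumulator on drop; the point is only that `start` may predate the guard's creation.]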
@@ -180,7 +148,7 @@ impl NeonControlPlaneClient {
     async fn do_get_endpoint_jwks(
         &self,
         ctx: &RequestContext,
-        endpoint: EndpointId,
+        endpoint: &EndpointId,
     ) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
         if !self
             .caches
@@ -313,225 +281,104 @@ impl NeonControlPlaneClient {
 
 impl super::ControlPlaneApi for NeonControlPlaneClient {
     #[tracing::instrument(skip_all)]
-    async fn get_role_secret(
+    async fn get_role_access_control(
         &self,
         ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result<CachedRoleSecret, GetAuthInfoError> {
-        let normalized_ep = &user_info.endpoint.normalize();
-        let user = &user_info.user;
-        if let Some(role_secret) = self
+        endpoint: &EndpointId,
+        role: &RoleName,
+    ) -> Result<RoleAccessControl, GetAuthInfoError> {
+        let normalized_ep = &endpoint.normalize();
+        if let Some(secret) = self
             .caches
             .project_info
-            .get_role_secret(normalized_ep, user)
+            .get_role_secret(normalized_ep, role)
         {
-            return Ok(role_secret);
+            return Ok(secret);
         }
-        let auth_info = self.do_get_auth_info(ctx, user_info).await?;
-        let account_id = auth_info.account_id;
+
+        if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
+            info!("endpoint is not valid, skipping the request");
+            return Err(GetAuthInfoError::UnknownEndpoint);
+        }
+
+        let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
+
+        let control = EndpointAccessControl {
+            allowed_ips: Arc::new(auth_info.allowed_ips),
+            allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
+            flags: auth_info.access_blocker_flags,
+        };
+        let role_control = RoleAccessControl {
+            secret: auth_info.secret,
+        };
+
         if let Some(project_id) = auth_info.project_id {
             let normalized_ep_int = normalized_ep.into();
-            self.caches.project_info.insert_role_secret(
+
+            self.caches.project_info.insert_endpoint_access(
+                auth_info.account_id,
                 project_id,
                 normalized_ep_int,
-                user.into(),
-                auth_info.secret.clone(),
-            );
-            self.caches.project_info.insert_allowed_ips(
-                project_id,
-                normalized_ep_int,
-                Arc::new(auth_info.allowed_ips),
-            );
-            self.caches.project_info.insert_allowed_vpc_endpoint_ids(
-                account_id,
-                project_id,
-                normalized_ep_int,
-                Arc::new(auth_info.allowed_vpc_endpoint_ids),
-            );
-            self.caches.project_info.insert_block_public_or_vpc_access(
-                project_id,
-                normalized_ep_int,
-                auth_info.access_blocker_flags,
+                role.into(),
+                control,
+                role_control.clone(),
             );
             ctx.set_project_id(project_id);
         }
-        // When we just got a secret, we don't need to invalidate it.
-        Ok(Cached::new_uncached(auth_info.secret))
+
+        Ok(role_control)
     }
 
-    async fn get_allowed_ips(
+    #[tracing::instrument(skip_all)]
+    async fn get_endpoint_access_control(
         &self,
         ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result<CachedAllowedIps, GetAuthInfoError> {
-        let normalized_ep = &user_info.endpoint.normalize();
-        if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
-            Metrics::get()
-                .proxy
-                .allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats?
-                .inc(CacheOutcome::Hit);
-            return Ok(allowed_ips);
+        endpoint: &EndpointId,
+        role: &RoleName,
+    ) -> Result {
+        let normalized_ep = &endpoint.normalize();
+        if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
+            return Ok(control);
         }
-        Metrics::get()
-            .proxy
-            .allowed_ips_cache_misses
-            .inc(CacheOutcome::Miss);
-        let auth_info = self.do_get_auth_info(ctx, user_info).await?;
-        let allowed_ips = Arc::new(auth_info.allowed_ips);
-        let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
-        let access_blocker_flags = auth_info.access_blocker_flags;
-        let user = &user_info.user;
-        let account_id = auth_info.account_id;
+
+        if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
+            info!("endpoint is not valid, skipping the request");
+            return Err(GetAuthInfoError::UnknownEndpoint);
+        }
+
+        let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
+
+        let control = EndpointAccessControl {
+            allowed_ips: Arc::new(auth_info.allowed_ips),
+            allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
+            flags: auth_info.access_blocker_flags,
+        };
+        let role_control = RoleAccessControl {
+            secret: auth_info.secret,
+        };
+
         if let Some(project_id) = auth_info.project_id {
             let normalized_ep_int = normalized_ep.into();
-            self.caches.project_info.insert_role_secret(
+
+            self.caches.project_info.insert_endpoint_access(
+                auth_info.account_id,
                 project_id,
                 normalized_ep_int,
-                user.into(),
-                auth_info.secret.clone(),
-            );
-            self.caches.project_info.insert_allowed_ips(
-                project_id,
-                normalized_ep_int,
-                allowed_ips.clone(),
-            );
-            self.caches.project_info.insert_allowed_vpc_endpoint_ids(
-                account_id,
-                project_id,
-                normalized_ep_int,
-                allowed_vpc_endpoint_ids.clone(),
-            );
-            self.caches.project_info.insert_block_public_or_vpc_access(
-                project_id,
-                normalized_ep_int,
-                access_blocker_flags,
+                role.into(),
+                control.clone(),
+                role_control,
             );
             ctx.set_project_id(project_id);
         }
-        Ok(Cached::new_uncached(allowed_vpc_endpoint_ids))
-    }
-
-    async fn get_block_public_or_vpc_access(
-        &self,
-        ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
-        let normalized_ep = &user_info.endpoint.normalize();
-        if let Some(access_blocker_flags) = self
-            .caches
-            .project_info
-            .get_block_public_or_vpc_access(normalized_ep)
-        {
-            Metrics::get()
-                .proxy
-                .access_blocker_flags_cache_stats
-                .inc(CacheOutcome::Hit);
-            return Ok(access_blocker_flags);
-        }
-
-        Metrics::get()
-            .proxy
-            .access_blocker_flags_cache_stats
-            .inc(CacheOutcome::Miss);
-
-        let auth_info = self.do_get_auth_info(ctx, user_info).await?;
-        let allowed_ips = Arc::new(auth_info.allowed_ips);
-        let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
-        let access_blocker_flags = auth_info.access_blocker_flags;
-        let user = &user_info.user;
-        let account_id = auth_info.account_id;
-        if let Some(project_id) = auth_info.project_id {
-            let normalized_ep_int = normalized_ep.into();
-            self.caches.project_info.insert_role_secret(
-                project_id,
-                normalized_ep_int,
-                user.into(),
-                auth_info.secret.clone(),
-            );
-            self.caches.project_info.insert_allowed_ips(
-                project_id,
-                normalized_ep_int,
-                allowed_ips.clone(),
-            );
-            self.caches.project_info.insert_allowed_vpc_endpoint_ids(
-                account_id,
-                project_id,
-                normalized_ep_int,
-                allowed_vpc_endpoint_ids.clone(),
-            );
-            self.caches.project_info.insert_block_public_or_vpc_access(
-                project_id,
-                normalized_ep_int,
-                access_blocker_flags.clone(),
-            );
-            ctx.set_project_id(project_id);
-        }
-        Ok(Cached::new_uncached(access_blocker_flags))
+        Ok(control)
     }

     #[tracing::instrument(skip_all)]
     async fn get_endpoint_jwks(
         &self,
         ctx: &RequestContext,
-        endpoint: EndpointId,
+        endpoint: &EndpointId,
     ) -> Result, GetEndpointJwksError> {
         self.do_get_endpoint_jwks(ctx, endpoint).await
     }
diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs
index d3ab4abd0b..ece7153fce 100644
--- a/proxy/src/control_plane/client/mock.rs
+++ b/proxy/src/control_plane/client/mock.rs
@@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo;
 use crate::auth::backend::jwt::AuthRule;
 use crate::cache::Cached;
 use crate::context::RequestContext;
-use crate::control_plane::client::{
-    CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret,
-};
 use crate::control_plane::errors::{
     ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
 };
 use crate::control_plane::messages::MetricsAuxInfo;
-use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
+use crate::control_plane::{
+    AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
+    RoleAccessControl,
+};
 use crate::intern::RoleNameInt;
 use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
 use crate::url::ApiUrl;
@@ -66,7 +66,8 @@ impl MockControlPlane {
     async fn do_get_auth_info(
         &self,
-        user_info: &ComputeUserInfo,
+        endpoint: &EndpointId,
+        role: &RoleName,
     ) -> Result {
         let (secret, allowed_ips) = async {
             // Perhaps we could persist this connection, but then we'd have to
@@ -80,7 +81,7 @@ impl MockControlPlane {
             let secret = if let Some(entry) = get_execute_postgres_query(
                 &client,
                 "select rolpassword from pg_catalog.pg_authid where rolname = $1",
-                &[&&*user_info.user],
+                &[&role.as_str()],
                 "rolpassword",
             )
             .await?
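Reviewer note: it is worth spelling out why the four `insert_*` calls collapsed into one `insert_endpoint_access` above. All four values come from the same control-plane response, so inserting them together guarantees the cached endpoint-level and role-level entries can never come from different fetches. A toy sketch of that invariant, with hypothetical types (not the real `ProjectInfoCacheImpl`; the endpoint and role names below are made up):

use std::collections::HashMap;

struct EndpointControl {
    allowed_ips: Vec<String>,
}

struct RoleControl {
    secret: Option<String>,
}

#[derive(Default)]
struct ProjectInfo {
    endpoint_access: HashMap<String, EndpointControl>,
    role_access: HashMap<(String, String), RoleControl>,
}

impl ProjectInfo {
    // One control-plane response populates both levels in a single call, so
    // the endpoint entry and the role entry always describe the same fetch.
    fn insert_endpoint_access(
        &mut self,
        endpoint: &str,
        role: &str,
        control: EndpointControl,
        role_control: RoleControl,
    ) {
        self.endpoint_access.insert(endpoint.to_string(), control);
        self.role_access
            .insert((endpoint.to_string(), role.to_string()), role_control);
    }
}

fn main() {
    let mut cache = ProjectInfo::default();
    cache.insert_endpoint_access(
        "ep-example-123456", // hypothetical endpoint ID
        "app_user",          // hypothetical role name
        EndpointControl { allowed_ips: vec!["10.0.0.0/8".to_string()] },
        RoleControl { secret: None },
    );
    assert!(cache.endpoint_access.contains_key("ep-example-123456"));
    assert!(
        cache
            .role_access
            .contains_key(&("ep-example-123456".to_string(), "app_user".to_string()))
    );
}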
@@ -89,7 +90,7 @@ impl MockControlPlane {
                 let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
                 secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
             } else {
-                warn!("user '{}' does not exist", user_info.user);
+                warn!("user '{role}' does not exist");
                 None
             };
@@ -97,7 +98,7 @@ impl MockControlPlane {
             match get_execute_postgres_query(
                 &client,
                 "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
-                &[&user_info.endpoint.as_str()],
+                &[&endpoint.as_str()],
                 "allowed_ips",
             )
             .await?
@@ -133,7 +134,7 @@ impl MockControlPlane {
     async fn do_get_endpoint_jwks(
         &self,
-        endpoint: EndpointId,
+        endpoint: &EndpointId,
     ) -> Result, GetEndpointJwksError> {
         let (client, connection) =
             tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
@@ -222,53 +223,36 @@ async fn get_execute_postgres_query(
 }

 impl super::ControlPlaneApi for MockControlPlane {
-    #[tracing::instrument(skip_all)]
-    async fn get_role_secret(
+    async fn get_endpoint_access_control(
         &self,
         _ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
-        Ok(CachedRoleSecret::new_uncached(
-            self.do_get_auth_info(user_info).await?.secret,
-        ))
+        endpoint: &EndpointId,
+        role: &RoleName,
+    ) -> Result {
+        let info = self.do_get_auth_info(endpoint, role).await?;
+        Ok(EndpointAccessControl {
+            allowed_ips: Arc::new(info.allowed_ips),
+            allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
+            flags: info.access_blocker_flags,
+        })
     }

-    async fn get_allowed_ips(
+    async fn get_role_access_control(
         &self,
         _ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
-        Ok(Cached::new_uncached(Arc::new(
-            self.do_get_auth_info(user_info).await?.allowed_ips,
-        )))
-    }
-
-    async fn get_allowed_vpc_endpoint_ids(
-        &self,
-        _ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
-        Ok(Cached::new_uncached(Arc::new(
-            self.do_get_auth_info(user_info)
-                .await?
-                .allowed_vpc_endpoint_ids,
-        )))
-    }
-
-    async fn get_block_public_or_vpc_access(
-        &self,
-        _ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
-        Ok(Cached::new_uncached(
-            self.do_get_auth_info(user_info).await?.access_blocker_flags,
-        ))
+        endpoint: &EndpointId,
+        role: &RoleName,
+    ) -> Result {
+        let info = self.do_get_auth_info(endpoint, role).await?;
+        Ok(RoleAccessControl {
+            secret: info.secret,
+        })
     }

     async fn get_endpoint_jwks(
         &self,
         _ctx: &RequestContext,
-        endpoint: EndpointId,
+        endpoint: &EndpointId,
     ) -> Result, GetEndpointJwksError> {
         self.do_get_endpoint_jwks(endpoint).await
     }
diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs
index 746595de38..9b9d1e25ea 100644
--- a/proxy/src/control_plane/client/mod.rs
+++ b/proxy/src/control_plane/client/mod.rs
@@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache;
 use crate::cache::project_info::ProjectInfoCacheImpl;
 use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
 use crate::context::RequestContext;
-use crate::control_plane::{
-    CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo,
-    CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors,
-};
+use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
 use crate::error::ReportableError;
 use crate::metrics::ApiLockMetrics;
 use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
 use crate::types::EndpointId;

+use super::{EndpointAccessControl, RoleAccessControl};
+
 #[non_exhaustive]
 #[derive(Clone)]
 pub enum ControlPlaneClient {
@@ -40,68 +39,42 @@ pub enum ControlPlaneClient {
 }

 impl ControlPlaneApi for ControlPlaneClient {
-    async fn get_role_secret(
+    async fn get_role_access_control(
         &self,
         ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
+        endpoint: &EndpointId,
+        role: &crate::types::RoleName,
+    ) -> Result {
         match self {
-            Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await,
+            Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
             #[cfg(any(test, feature = "testing"))]
-            Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
+            Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
             #[cfg(test)]
-            Self::Test(_) => {
+            Self::Test(_api) => {
                 unreachable!("this function should never be called in the test backend")
             }
         }
     }

-    async fn get_allowed_ips(
+    async fn get_endpoint_access_control(
         &self,
         ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
+        endpoint: &EndpointId,
+        role: &crate::types::RoleName,
+    ) -> Result {
         match self {
-            Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await,
+            Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
             #[cfg(any(test, feature = "testing"))]
-            Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await,
+            Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
             #[cfg(test)]
-            Self::Test(api) => api.get_allowed_ips(),
-        }
-    }
-
-    async fn get_allowed_vpc_endpoint_ids(
-        &self,
-        ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
-        match self {
-            Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
-            #[cfg(any(test, feature = "testing"))]
-            Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
-            #[cfg(test)]
-            Self::Test(api) => api.get_allowed_vpc_endpoint_ids(),
-        }
-    }
-
-    async fn get_block_public_or_vpc_access(
-        &self,
-        ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result {
-        match self {
-            Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
-            #[cfg(any(test, feature = "testing"))]
-            Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
-            #[cfg(test)]
-            Self::Test(api) => api.get_block_public_or_vpc_access(),
+            Self::Test(api) => api.get_access_control(),
         }
     }

     async fn get_endpoint_jwks(
         &self,
         ctx: &RequestContext,
-        endpoint: EndpointId,
+        endpoint: &EndpointId,
     ) -> Result, errors::GetEndpointJwksError> {
         match self {
             Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
@@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient {
 pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
     fn wake_compute(&self) -> Result;

-    fn get_allowed_ips(&self) -> Result;
-
-    fn get_allowed_vpc_endpoint_ids(
-        &self,
-    ) -> Result;
-
-    fn get_block_public_or_vpc_access(
-        &self,
-    ) -> Result;
+    fn get_access_control(&self) -> Result;

     fn dyn_clone(&self) -> Box;
 }
@@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient {
         ctx: &RequestContext,
         endpoint: EndpointId,
     ) -> Result, FetchAuthRulesError> {
-        self.get_endpoint_jwks(ctx, endpoint)
+        self.get_endpoint_jwks(ctx, &endpoint)
             .await
             .map_err(FetchAuthRulesError::GetEndpointJwks)
     }
diff --git a/proxy/src/control_plane/errors.rs b/proxy/src/control_plane/errors.rs
index 850d061333..77312c89c5 100644
--- a/proxy/src/control_plane/errors.rs
+++ b/proxy/src/control_plane/errors.rs
@@ -99,6 +99,10 @@ pub(crate) enum GetAuthInfoError {

     #[error(transparent)]
     ApiError(ControlPlaneError),
+
+    /// Proxy does not know about the endpoint in advance
+    #[error("endpoint not found in endpoint cache")]
+    UnknownEndpoint,
 }

 // This allows more useful interactions than `#[from]`.
@@ -115,6 +119,8 @@ impl UserFacingError for GetAuthInfoError {
             Self::BadSecret => REQUEST_FAILED.to_owned(),
             // However, API might return a meaningful error.
             Self::ApiError(e) => e.to_string_client(),
+            // Pretend that the control plane returned an error.
+            Self::UnknownEndpoint => REQUEST_FAILED.to_owned(),
         }
     }
 }
@@ -124,6 +130,8 @@ impl ReportableError for GetAuthInfoError {
         match self {
             Self::BadSecret => crate::error::ErrorKind::ControlPlane,
             Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
+            // We only apply endpoint filtering when the control plane is under high load.
+            Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit,
         }
     }
 }
diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs
index d592223be1..7ff093d9dc 100644
--- a/proxy/src/control_plane/mod.rs
+++ b/proxy/src/control_plane/mod.rs
@@ -11,16 +11,16 @@ pub(crate) mod errors;

 use std::sync::Arc;

-use crate::auth::IpPattern;
 use crate::auth::backend::jwt::AuthRule;
 use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
-use crate::cache::project_info::ProjectInfoCacheImpl;
+use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
 use crate::cache::{Cached, TimedLru};
 use crate::config::ComputeConfig;
 use crate::context::RequestContext;
 use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
 use crate::intern::{AccountIdInt, ProjectIdInt};
-use crate::types::{EndpointCacheKey, EndpointId};
+use crate::protocol2::ConnectionInfoExtra;
+use crate::types::{EndpointCacheKey, EndpointId, RoleName};
 use crate::{compute, scram};

 /// Various cache-related types.
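Reviewer note on the new `UnknownEndpoint` variant: internally it stays precise for logs, while `to_string_client()` deliberately degrades it to the same generic text as other control-plane failures, so a probing client cannot distinguish "endpoint not in the cache" from an ordinary cplane error. A minimal illustration of that masking pattern, assuming the thiserror crate (which the proxy's error types already use); `REQUEST_FAILED` here is a placeholder string, not the proxy's actual constant:

use thiserror::Error;

// Placeholder for the proxy's real REQUEST_FAILED constant.
const REQUEST_FAILED: &str = "Console request failed";

#[derive(Debug, Error)]
enum GetAuthInfoDemo {
    #[error("endpoint not found in endpoint cache")]
    UnknownEndpoint,
}

impl GetAuthInfoDemo {
    fn to_string_client(&self) -> String {
        match self {
            // Same generic text as other control-plane failures.
            Self::UnknownEndpoint => REQUEST_FAILED.to_owned(),
        }
    }
}

fn main() {
    let err = GetAuthInfoDemo::UnknownEndpoint;
    println!("operator log: {err}"); // detailed, for operators
    println!("client sees:  {}", err.to_string_client()); // generic, for users
}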
@@ -101,7 +101,7 @@ impl NodeInfo {
     }
 }

-#[derive(Clone, Default, Eq, PartialEq, Debug)]
+#[derive(Copy, Clone, Default)]
 pub(crate) struct AccessBlockerFlags {
     pub public_access_blocked: bool,
     pub vpc_access_blocked: bool,
@@ -110,47 +110,78 @@ pub(crate) struct AccessBlockerFlags {
 pub(crate) type NodeInfoCache = TimedLru>>;
 pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
-pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>;
-pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>;
-pub(crate) type CachedAllowedVpcEndpointIds =
-    Cached<&'static ProjectInfoCacheImpl, Arc>>;
-pub(crate) type CachedAccessBlockerFlags =
-    Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>;
+
+#[derive(Clone)]
+pub struct RoleAccessControl {
+    pub secret: Option,
+}
+
+#[derive(Clone)]
+pub struct EndpointAccessControl {
+    pub allowed_ips: Arc>,
+    pub allowed_vpce: Arc>,
+    pub flags: AccessBlockerFlags,
+}
+
+impl EndpointAccessControl {
+    pub fn check(
+        &self,
+        ctx: &RequestContext,
+        check_ip_allowed: bool,
+        check_vpc_allowed: bool,
+    ) -> Result<(), AuthError> {
+        if check_ip_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &self.allowed_ips) {
+            return Err(AuthError::IpAddressNotAllowed(ctx.peer_addr()));
+        }
+
+        // If a VPC endpoint ID came in with the connection, check that it is allowed.
+        if check_vpc_allowed {
+            if self.flags.vpc_access_blocked {
+                return Err(AuthError::NetworkNotAllowed);
+            }
+
+            let incoming_vpc_endpoint_id = match ctx.extra() {
+                None => return Err(AuthError::MissingVPCEndpointId),
+                Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
+                Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
+            };
+
+            let vpce = &self.allowed_vpce;
+            // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
+            if !vpce.is_empty() && !vpce.contains(&incoming_vpc_endpoint_id) {
+                return Err(AuthError::vpc_endpoint_id_not_allowed(
+                    incoming_vpc_endpoint_id,
+                ));
+            }
+        } else if self.flags.public_access_blocked {
+            return Err(AuthError::NetworkNotAllowed);
+        }
+
+        Ok(())
+    }
+}

 /// This will allocate per each call, but the http requests alone
 /// already require a few allocations, so it should be fine.
 pub(crate) trait ControlPlaneApi {
-    /// Get the client's auth secret for authentication.
-    /// Returns option because user not found situation is special.
-    /// We still have to mock the scram to avoid leaking information that user doesn't exist.
-    async fn get_role_secret(
+    async fn get_role_access_control(
         &self,
         ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result;
+        endpoint: &EndpointId,
+        role: &RoleName,
+    ) -> Result;

-    async fn get_allowed_ips(
+    async fn get_endpoint_access_control(
         &self,
         ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result;
-
-    async fn get_allowed_vpc_endpoint_ids(
-        &self,
-        ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result;
-
-    async fn get_block_public_or_vpc_access(
-        &self,
-        ctx: &RequestContext,
-        user_info: &ComputeUserInfo,
-    ) -> Result;
+        endpoint: &EndpointId,
+        role: &RoleName,
+    ) -> Result;

     async fn get_endpoint_jwks(
         &self,
         ctx: &RequestContext,
-        endpoint: EndpointId,
+        endpoint: &EndpointId,
    ) -> Result, errors::GetEndpointJwksError>;

     /// Wake up the compute node and return the corresponding connection info.
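Since `check()` is now the single chokepoint for network policy, here is a compact, self-contained approximation of its decision logic. Plain `&str` values stand in for `RequestContext`, `IpPattern` matching, and `ConnectionInfoExtra` (the real code matches CIDR-style IP patterns rather than comparing strings, and pulls the VPC endpoint ID out of the proxy-protocol extra data), but the branching mirrors the hunk above: the IP allowlist on the public path, the VPC endpoint allowlist plus block flags on the VPC path.

#[derive(Default, Clone, Copy)]
struct AccessBlockerFlags {
    public_access_blocked: bool,
    vpc_access_blocked: bool,
}

struct EndpointAccessControl {
    allowed_ips: Vec<String>,
    allowed_vpce: Vec<String>,
    flags: AccessBlockerFlags,
}

#[derive(Debug)]
enum AuthError {
    IpNotAllowed,
    NetworkNotAllowed,
    MissingVpcEndpointId,
    VpceNotAllowed,
}

impl EndpointAccessControl {
    fn check(
        &self,
        peer_ip: &str,
        vpce_id: Option<&str>,
        check_ip_allowed: bool,
        check_vpc_allowed: bool,
    ) -> Result<(), AuthError> {
        // Public-internet rule: peer must be on the IP allowlist.
        if check_ip_allowed && !self.allowed_ips.iter().any(|ip| ip.as_str() == peer_ip) {
            return Err(AuthError::IpNotAllowed);
        }
        if check_vpc_allowed {
            if self.flags.vpc_access_blocked {
                return Err(AuthError::NetworkNotAllowed);
            }
            let Some(vpce) = vpce_id else {
                return Err(AuthError::MissingVpcEndpointId);
            };
            // An empty allowlist still means "all allowed", as the TODO notes.
            if !self.allowed_vpce.is_empty()
                && !self.allowed_vpce.iter().any(|v| v.as_str() == vpce)
            {
                return Err(AuthError::VpceNotAllowed);
            }
        } else if self.flags.public_access_blocked {
            return Err(AuthError::NetworkNotAllowed);
        }
        Ok(())
    }
}

fn main() {
    let ctl = EndpointAccessControl {
        allowed_ips: vec!["203.0.113.7".to_string()],
        allowed_vpce: Vec::new(),
        flags: AccessBlockerFlags::default(),
    };
    // Public path: the IP allowlist is enforced and VPC rules are skipped.
    assert!(ctl.check("203.0.113.7", None, true, false).is_ok());
    assert!(ctl.check("198.51.100.1", None, true, false).is_err());
    // VPC path: a missing VPC endpoint ID is rejected outright.
    assert!(ctl.check("203.0.113.7", None, true, true).is_err());
}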
diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs
index 26ac6a89e7..ac0aca1176 100644
--- a/proxy/src/proxy/mod.rs
+++ b/proxy/src/proxy/mod.rs
@@ -345,7 +345,7 @@ pub(crate) async fn handle_client(
     };

     let user = user_info.get_user().to_owned();
-    let (user_info, _ip_allowlist) = match user_info
+    let user_info = match user_info
         .authenticate(
             ctx,
             &mut stream,
diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs
index 3cc053e0ad..61e8ee4a10 100644
--- a/proxy/src/proxy/tests/mod.rs
+++ b/proxy/src/proxy/tests/mod.rs
@@ -26,9 +26,7 @@ use crate::auth::backend::{
 use crate::config::{ComputeConfig, RetryConfig};
 use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
 use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
-use crate::control_plane::{
-    self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache,
-};
+use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
 use crate::error::ErrorKind;
 use crate::tls::client_config::compute_client_config_with_certs;
 use crate::tls::postgres_rustls::MakeRustlsConnect;
@@ -547,20 +545,9 @@ impl TestControlPlaneClient for TestConnectMechanism {
         }
     }

-    fn get_allowed_ips(&self) -> Result {
-        unimplemented!("not used in tests")
-    }
-
-    fn get_allowed_vpc_endpoint_ids(
+    fn get_access_control(
         &self,
-    ) -> Result {
-        unimplemented!("not used in tests")
-    }
-
-    fn get_block_public_or_vpc_access(
-        &self,
-    ) -> Result
-    {
+    ) -> Result {
         unimplemented!("not used in tests")
     }
diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs
index 4f27c6faef..0c79b5e92f 100644
--- a/proxy/src/rate_limiter/leaky_bucket.rs
+++ b/proxy/src/rate_limiter/leaky_bucket.rs
@@ -15,7 +15,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter;

 pub struct LeakyBucketRateLimiter {
     map: ClashMap,
-    config: utils::leaky_bucket::LeakyBucketConfig,
+    default_config: utils::leaky_bucket::LeakyBucketConfig,
     access_count: AtomicUsize,
 }

@@ -28,15 +28,17 @@ impl LeakyBucketRateLimiter {
     pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
         Self {
             map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
-            config: config.into(),
+            default_config: config.into(),
             access_count: AtomicUsize::new(0),
         }
     }

     /// Check that the number of requests to the endpoint stays below `max_rps`.
-    pub(crate) fn check(&self, key: K, n: u32) -> bool {
+    pub(crate) fn check(&self, key: K, config: Option, n: u32) -> bool {
         let now = Instant::now();

+        let config = config.map_or(self.default_config, Into::into);
+
         if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
             self.do_gc(now);
         }
@@ -46,7 +48,7 @@ impl LeakyBucketRateLimiter {
             .entry(key)
             .or_insert_with(|| LeakyBucketState { empty_at: now });

-        entry.add_tokens(&self.config, now, n as f64).is_ok()
+        entry.add_tokens(&config, now, n as f64).is_ok()
     }

     fn do_gc(&self, now: Instant) {
diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs
index 21eaa6739b..9d700c1b52 100644
--- a/proxy/src/rate_limiter/limiter.rs
+++ b/proxy/src/rate_limiter/limiter.rs
@@ -15,6 +15,8 @@ use tracing::info;
 use crate::ext::LockExt;
 use crate::intern::EndpointIdInt;

+use super::LeakyBucketConfig;
+
 pub struct GlobalRateLimiter {
     data: Vec,
     info: Vec,
@@ -144,19 +146,6 @@ impl RateBucketInfo {
         Self::new(50_000, Duration::from_secs(10)),
     ];

-    /// All of these are per endpoint-maskedip pair.
-    /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
-    ///
-    /// First bucket: 1000mcpus total per endpoint-ip pair
-    /// * 4096000 requests per second with 1 hash rounds.
-    /// * 1000 requests per second with 4096 hash rounds.
-    /// * 6.8 requests per second with 600000 hash rounds.
-    pub const DEFAULT_AUTH_SET: [Self; 3] = [
-        Self::new(1000 * 4096, Duration::from_secs(1)),
-        Self::new(600 * 4096, Duration::from_secs(60)),
-        Self::new(300 * 4096, Duration::from_secs(600)),
-    ];
-
     pub fn rps(&self) -> f64 {
         (self.max_rpi as f64) / self.interval.as_secs_f64()
     }
@@ -184,6 +173,21 @@ impl RateBucketInfo {
             max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
         }
     }
+
+    pub fn to_leaky_bucket(this: &[Self]) -> Option {
+        // A bit of a hack: find the min and max RPS supported and turn them
+        // into a leaky bucket config instead.
+
+        let mut iter = this.iter().map(|info| info.rps());
+        let first = iter.next()?;
+
+        let (min, max) = (first, first);
+        let (min, max) = iter.fold((min, max), |(min, max), rps| {
+            (f64::min(min, rps), f64::max(max, rps))
+        });
+
+        Some(LeakyBucketConfig { rps: min, max })
+    }
 }

 impl BucketRateLimiter {
diff --git a/proxy/src/rate_limiter/mod.rs b/proxy/src/rate_limiter/mod.rs
index 5f90102da3..112b95873a 100644
--- a/proxy/src/rate_limiter/mod.rs
+++ b/proxy/src/rate_limiter/mod.rs
@@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd;
 pub(crate) use limit_algorithm::{
     DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
 };
-pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
+pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs
index 769d519d94..a9d6b40603 100644
--- a/proxy/src/redis/notifications.rs
+++ b/proxy/src/redis/notifications.rs
@@ -233,29 +233,30 @@ impl MessageHandler {
 fn invalidate_cache(cache: Arc, msg: Notification) {
     match msg {
-        Notification::AllowedIpsUpdate { allowed_ips_update } => {
-            cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
+        Notification::AllowedIpsUpdate {
+            allowed_ips_update: AllowedIpsUpdate { project_id },
         }
-        Notification::BlockPublicOrVpcAccessUpdated {
-            block_public_or_vpc_access_updated,
-        } => cache.invalidate_block_public_or_vpc_access_for_project(
-            block_public_or_vpc_access_updated.project_id,
-        ),
+        | Notification::BlockPublicOrVpcAccessUpdated {
+            block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id },
+        } => cache.invalidate_endpoint_access_for_project(project_id),
         Notification::AllowedVpcEndpointsUpdatedForOrg {
-            allowed_vpc_endpoints_updated_for_org,
-        } => cache.invalidate_allowed_vpc_endpoint_ids_for_org(
-            allowed_vpc_endpoints_updated_for_org.account_id,
-        ),
+            allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id },
+        } => cache.invalidate_endpoint_access_for_org(account_id),
         Notification::AllowedVpcEndpointsUpdatedForProjects {
-            allowed_vpc_endpoints_updated_for_projects,
-        } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects(
-            allowed_vpc_endpoints_updated_for_projects.project_ids,
-        ),
-        Notification::PasswordUpdate { password_update } => cache
-            .invalidate_role_secret_for_project(
-                password_update.project_id,
-                password_update.role_name,
-            ),
+            allowed_vpc_endpoints_updated_for_projects:
+                AllowedVpcEndpointsUpdatedForProjects { project_ids },
+        } => {
+            for project in project_ids {
+                cache.invalidate_endpoint_access_for_project(project);
+            }
+        }
+        Notification::PasswordUpdate {
+            password_update:
+                PasswordUpdate {
+                    project_id,
+                    role_name,
+                },
+        } => cache.invalidate_role_secret_for_project(project_id, role_name),
         Notification::UnknownTopic => unreachable!(),
     }
 }
diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs
index 13058f08f1..bf640c05e9 100644
--- a/proxy/src/serverless/backend.rs
+++ b/proxy/src/serverless/backend.rs
@@ -22,7 +22,7 @@ use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
 use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
 use crate::auth::backend::local::StaticAuthRules;
 use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
-use crate::auth::{self, AuthError, check_peer_addr_is_in_list};
+use crate::auth::{self, AuthError};
 use crate::compute;
 use crate::compute_ctl::{
     ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -35,7 +35,6 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
 use crate::control_plane::locks::ApiLocks;
 use crate::error::{ErrorKind, ReportableError, UserFacingError};
 use crate::intern::EndpointIdInt;
-use crate::protocol2::ConnectionInfoExtra;
 use crate::proxy::connect_compute::ConnectMechanism;
 use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
 use crate::rate_limiter::EndpointRateLimiter;
@@ -63,63 +62,24 @@ impl PoolingBackend {
         let user_info = user_info.clone();
         let backend = self.auth_backend.as_ref().map(|()| user_info.clone());

-        let allowed_ips = backend.get_allowed_ips(ctx).await?;
+        let access_control = backend.get_endpoint_access_control(ctx).await?;
+        access_control.check(
+            ctx,
+            self.config.authentication_config.ip_allowlist_check_enabled,
+            self.config.authentication_config.is_vpc_acccess_proxy,
+        )?;

-        if self.config.authentication_config.ip_allowlist_check_enabled
-            && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
-        {
-            return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
-        }
-
-        let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?;
-        if self.config.authentication_config.is_vpc_acccess_proxy {
-            if access_blocker_flags.vpc_access_blocked {
-                return Err(AuthError::NetworkNotAllowed);
-            }
-
-            let extra = ctx.extra();
-            let incoming_endpoint_id = match extra {
-                None => String::new(),
-                Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
-                Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
-            };
-
-            if incoming_endpoint_id.is_empty() {
-                return Err(AuthError::MissingVPCEndpointId);
-            }
-
-            let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?;
-            // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
-            if !allowed_vpc_endpoint_ids.is_empty()
-                && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id)
-            {
-                return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id));
-            }
-        } else if access_blocker_flags.public_access_blocked {
-            return Err(AuthError::NetworkNotAllowed);
-        }
-
-        if !self
-            .endpoint_rate_limiter
-            .check(user_info.endpoint.clone().into(), 1)
-        {
+        let ep = EndpointIdInt::from(&user_info.endpoint);
+        let rate_limit_config = None;
+        if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) {
             return Err(AuthError::too_many_connections());
         }

-        let cached_secret = backend.get_role_secret(ctx).await?;
-        let secret = match cached_secret.value.clone() {
-            Some(secret) => self.config.authentication_config.check_rate_limit(
-                ctx,
-                secret,
-                &user_info.endpoint,
-                true,
-            )?,
-            None => {
-                // If we don't have an authentication secret, for the http flow we can just return an error.
-                info!("authentication info not found");
-                return Err(AuthError::password_failed(&*user_info.user));
-            }
+        let role_access = backend.get_role_secret(ctx).await?;
+        let Some(secret) = role_access.secret else {
+            // If we don't have an authentication secret, for the http flow we can just return an error.
+            info!("authentication info not found");
+            return Err(AuthError::password_failed(&*user_info.user));
         };

-        let ep = EndpointIdInt::from(&user_info.endpoint);
         let auth_outcome = crate::auth::validate_password_and_exchange(
             &self.config.authentication_config.thread_pool,
             ep,