diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 84430dc812..5d661c9135 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -16,12 +16,15 @@ use super::{Cache, Cached}; use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; use crate::control_plane::AuthSecret; -use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt}; +use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); + fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec); + fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt); + fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); async fn decrement_active_listeners(&self); async fn increment_active_listeners(&self); @@ -51,6 +54,8 @@ impl From for Entry { struct EndpointInfo { secret: std::collections::HashMap>>, allowed_ips: Option>>>, + block_public_or_vpc_access: Option>, + allowed_vpc_endpoint_ids: Option>>>, } impl EndpointInfo { @@ -92,6 +97,51 @@ impl EndpointInfo { } None } + + pub(crate) fn get_allowed_vpc_endpoint_ids( + &self, + valid_since: Instant, + ignore_cache_since: Option, + ) -> Option<(Arc>, bool)> { + if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids { + if valid_since < allowed_vpc_endpoint_ids.created_at { + return Some(( + allowed_vpc_endpoint_ids.value.clone(), + Self::check_ignore_cache( + ignore_cache_since, + allowed_vpc_endpoint_ids.created_at, + ), + )); + } + } + None + } + + pub(crate) fn get_block_public_or_vpc_access( + &self, + valid_since: Instant, + ignore_cache_since: Option, + ) -> Option<((bool, bool), bool)> { + if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access { + if valid_since < block_public_or_vpc_access.created_at { + return Some(( + block_public_or_vpc_access.value.clone(), + Self::check_ignore_cache( + ignore_cache_since, + block_public_or_vpc_access.created_at, + ), + )); + } + } + None + } + + pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) { + self.block_public_or_vpc_access = None; + } + pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) { + self.allowed_vpc_endpoint_ids = None; + } pub(crate) fn invalidate_allowed_ips(&mut self) { self.allowed_ips = None; } @@ -111,6 +161,8 @@ pub struct ProjectInfoCacheImpl { cache: DashMap, project2ep: DashMap>, + // FIXME(stefan): we need a way to GC the account2ep map. + account2ep: DashMap>, config: ProjectInfoCacheOptions, start_time: Instant, @@ -120,6 +172,59 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { + fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec) { + info!( + "invalidating allowed vpc endpoint ids for projects `{}`", + project_ids.join(", ") + ); + for project_id in project_ids { + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + } + } + } + } + + fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) { + info!( + "invalidating allowed vpc endpoint ids for org `{}`", + account_id + ); + let endpoints = self + .account2ep + .get(&account_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_allowed_vpc_endpoint_ids(); + } + } + } + + fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) { + info!( + "invalidating block public or vpc access for project `{}`", + project_id + ); + let endpoints = self + .project2ep + .get(&project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_block_public_or_vpc_access(); + } + } + } + fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) { info!("invalidating allowed ips for project `{}`", project_id); let endpoints = self @@ -178,6 +283,7 @@ impl ProjectInfoCacheImpl { Self { cache: DashMap::new(), project2ep: DashMap::new(), + account2ep: DashMap::new(), config, ttl_disabled_since_us: AtomicU64::new(u64::MAX), start_time: Instant::now(), @@ -226,6 +332,7 @@ impl ProjectInfoCacheImpl { } Some(Cached::new_uncached(value)) } + pub(crate) fn insert_role_secret( &self, project_id: ProjectIdInt, @@ -256,6 +363,16 @@ impl ProjectInfoCacheImpl { self.insert_project2endpoint(project_id, endpoint_id); self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into()); } + pub(crate) fn insert_vpc_allowed_endpoint_ids(&self, account_id: AccountIdInt, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, vpc_allowed_endpoint_ids: HashSet) { + if self.cache.len() >= self.config.size { + // If there are too many entries, wait until the next gc cycle. + return; + } + self.insert_account2endpoint(account_id, endpoint_id); + self.insert_project2endpoint(project_id, endpoint_id); + self.cache.entry(endpoint_id).or_default().vpc_allowed_endpoint_ids = Some(vpc_allowed_endpoint_ids); + } + } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) { endpoints.insert(endpoint_id); @@ -264,6 +381,13 @@ impl ProjectInfoCacheImpl { .insert(project_id, HashSet::from([endpoint_id])); } } + fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) { + if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) { + endpoints.insert(endpoint_id); + } else { + self.account2ep.insert(account_id, HashSet::from([endpoint_id])); + } + } fn get_cache_times(&self) -> (Instant, Option) { let mut valid_since = Instant::now() - self.config.ttl; // Only ignore cache if ttl is disabled. diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index 2662ab85f9..22829efe32 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -238,6 +238,8 @@ pub(crate) struct GetEndpointAccessControl { pub(crate) allowed_ips: Option>, pub(crate) project_id: Option, pub(crate) allowed_vpc_endpoint_ids: Option>, + pub(crate) block_public_connections: Option, + pub(crate) block_vpc_connections: Option, } // Manually implement debug to omit sensitive info. diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index f56d92a6b3..704ad2323d 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -7,7 +7,7 @@ use std::sync::OnceLock; use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo}; use rustc_hash::FxHasher; -use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; +use crate::types::{AccountId, BranchId, EndpointId, ProjectId, RoleName}; pub trait InternId: Sized + 'static { fn get_interner() -> &'static StringInterner; @@ -206,6 +206,26 @@ impl From for ProjectIdInt { } } +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct AccountIdTag; +impl InternId for AccountIdTag { + fn get_interner() -> &'static StringInterner { + static ROLE_NAMES: OnceLock> = OnceLock::new(); + ROLE_NAMES.get_or_init(Default::default) + } +} +pub type AccountIdInt = InternedString; +impl From<&AccountId> for AccountIdInt { + fn from(value: &AccountId) -> Self { + AccountIdTag::get_interner().get_or_intern(value) + } +} +impl From for AccountIdInt { + fn from(value: AccountId) -> Self { + AccountIdTag::get_interner().get_or_intern(&value) + } +} + #[cfg(test)] mod tests { use std::sync::OnceLock; diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 659c57c865..dfbb4a1530 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -556,6 +556,9 @@ pub enum RedisEventsCount { CancelSession, PasswordUpdate, AllowedIpsUpdate, + AllowedVpcEndpointIdsUpdateForProjects, + AllowedVpcEndpointIdsUpdateForAllProjectsInOrg, + BlockPublicOrVpcAccessUpdate, } pub struct ThreadPoolWorkers(usize); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index f3aa97c032..2323e22ce6 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -11,7 +11,7 @@ use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; use crate::cancellation::{CancelMap, CancellationHandler}; -use crate::intern::{ProjectIdInt, RoleNameInt}; +use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; use tracing::Instrument; @@ -39,6 +39,27 @@ pub(crate) enum Notification { AllowedIpsUpdate { allowed_ips_update: AllowedIpsUpdate, }, + #[serde( + rename = "/allowed_vpc_endpoint_ids_updated_for_projects", + deserialize_with = "deserialize_json_string" + )] + AllowedVpcEndpointIdsUpdateForProjects { + allowed_vpc_endpoint_ids_update_for_projects: AllowedVpcEndpointIdsUpdateForProjects, + }, + #[serde( + rename = "/allowed_vpc_endpoint_ids_updated_for_org", + deserialize_with = "deserialize_json_string" + )] + AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { + allowed_vpc_endpoint_ids_update_for_org: AllowedVpcEndpointIdsUpdateForAllProjectsInOrg, + }, + #[serde( + rename = "/block_public_or_vpc_access_updated", + deserialize_with = "deserialize_json_string" + )] + BlockPublicOrVpcAccessUpdate { + block_public_or_vpc_access_update: BlockPublicOrVpcAccessUpdate, + }, #[serde( rename = "/password_updated", deserialize_with = "deserialize_json_string" @@ -51,6 +72,22 @@ pub(crate) enum Notification { pub(crate) struct AllowedIpsUpdate { project_id: ProjectIdInt, } + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct AllowedVpcEndpointIdsUpdateForProjects { + project_ids: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { + account_id: AccountIdInt, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct BlockPublicOrVpcAccessUpdate { + project_id: ProjectIdInt, +} + #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] pub(crate) struct PasswordUpdate { project_id: ProjectIdInt, @@ -164,7 +201,11 @@ impl MessageHandler { } } } - Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => { + Notification::AllowedIpsUpdate { .. } + | Notification::PasswordUpdate { .. } + | Notification::AllowedVpcEndpointIdsUpdateForProjects { .. } + | Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { .. } + | Notification::BlockPublicOrVpcAccessUpdate { .. } => { invalidate_cache(self.cache.clone(), msg.clone()); if matches!(msg, Notification::AllowedIpsUpdate { .. }) { Metrics::get() @@ -176,6 +217,27 @@ impl MessageHandler { .proxy .redis_events_count .inc(RedisEventsCount::PasswordUpdate); + } else if matches!( + msg, + Notification::AllowedVpcEndpointIdsUpdateForProjects { .. } + ) { + Metrics::get() + .proxy + .redis_events_count + .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects); + } else if matches!( + msg, + Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { .. } + ) { + Metrics::get() + .proxy + .redis_events_count + .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg); + } else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdate { .. }) { + Metrics::get() + .proxy + .redis_events_count + .inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate); } // It might happen that the invalid entry is on the way to be cached. // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. @@ -197,6 +259,21 @@ fn invalidate_cache(cache: Arc, msg: Notification) { Notification::AllowedIpsUpdate { allowed_ips_update } => { cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id); } + Notification::AllowedVpcEndpointIdsUpdateForProjects { + allowed_vpc_endpoint_ids_update_for_projects, + } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects( + allowed_vpc_endpoint_ids_update_for_projects.project_ids, + ), + Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { + allowed_vpc_endpoint_ids_update_for_org, + } => cache.invalidate_allowed_vpc_endpoint_ids_for_org( + allowed_vpc_endpoint_ids_update_for_org.account_id, + ), + Notification::BlockPublicOrVpcAccessUpdate { + block_public_or_vpc_access_update, + } => cache.invalidate_block_public_or_vpc_access_for_project( + block_public_or_vpc_access_update.project_id, + ), Notification::PasswordUpdate { password_update } => cache .invalidate_role_secret_for_project( password_update.project_id, diff --git a/proxy/src/types.rs b/proxy/src/types.rs index 6e0bd61c94..d5952d1d8b 100644 --- a/proxy/src/types.rs +++ b/proxy/src/types.rs @@ -97,6 +97,8 @@ smol_str_wrapper!(EndpointId); smol_str_wrapper!(BranchId); // 90% of project strings are 23 characters or less. smol_str_wrapper!(ProjectId); +// 90% of account strings are 23 characters or less. +smol_str_wrapper!(AccountId); // will usually equal endpoint ID smol_str_wrapper!(EndpointCacheKey);