Compare commits

..

1 Commits

Author SHA1 Message Date
Raphael 'kena' Poss
2e79c47380 proxy: stub the VPC config cache and invalidation code
This commit defines cache entries for projects:

- the two "block public or VPC connections" booleans.
  This is implemented (for now) as a pair of `bools` because
  they are always delivered together from the project
  settings in cplane. The corresponding fields in the
  `/get_endpoint_access_control` response are
  `block_public_connections` and `block_vpc_connections`.

  We will use the redis broadcast channel
  `/block_public_or_vpc_access_updated`` for this. It will be notified
  when the **project settings** are changed in the control plane.

- the VPC endpoint ID list. The corresponding field
  in the `/get_endpoint_access_control` response is
  `allowed_vpc_endpoint_ids`.

  We will use the 2 redis broadcast channels
  `/allowed_vpc_endpoint_ids_updated_for_projects` and
  `/allowed_vpc_endpoint_ids_updated_for_org` for this. It will be
  notified when the **VPC config service** detects a configuration
  change for the project or all projects under an org.
2024-12-13 10:07:15 +01:00
7 changed files with 233 additions and 5 deletions

View File

@@ -68,7 +68,7 @@ CARGO_CMD_PREFIX += $(if $(filter n,$(MAKEFLAGS)),,+)
# Force cargo not to print progress bar
CARGO_CMD_PREFIX += CARGO_TERM_PROGRESS_WHEN=never CI=1
# Set PQ_LIB_DIR to make sure `storage_controller` get linked with bundled libpq (through diesel)
CARGO_CMD_PREFIX += PQ_LIB_DIR=$(POSTGRES_INSTALL_DIR)/v17/lib
CARGO_CMD_PREFIX += PQ_LIB_DIR=$(POSTGRES_INSTALL_DIR)/v16/lib
CACHEDIR_TAG_CONTENTS := "Signature: 8a477f597d28d172789f06886806bc55"

View File

@@ -16,12 +16,15 @@ use super::{Cache, Cached};
use crate::auth::IpPattern;
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::AuthSecret;
use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
async fn decrement_active_listeners(&self);
async fn increment_active_listeners(&self);
@@ -51,6 +54,8 @@ impl<T> From<T> for Entry<T> {
struct EndpointInfo {
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
block_public_or_vpc_access: Option<Entry<(bool, bool)>>,
allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
}
impl EndpointInfo {
@@ -92,6 +97,51 @@ impl EndpointInfo {
}
None
}
pub(crate) fn get_allowed_vpc_endpoint_ids(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<String>>, bool)> {
if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
if valid_since < allowed_vpc_endpoint_ids.created_at {
return Some((
allowed_vpc_endpoint_ids.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
allowed_vpc_endpoint_ids.created_at,
),
));
}
}
None
}
pub(crate) fn get_block_public_or_vpc_access(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<((bool, bool), bool)> {
if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
if valid_since < block_public_or_vpc_access.created_at {
return Some((
block_public_or_vpc_access.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
block_public_or_vpc_access.created_at,
),
));
}
}
None
}
pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
self.block_public_or_vpc_access = None;
}
pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
self.allowed_vpc_endpoint_ids = None;
}
pub(crate) fn invalidate_allowed_ips(&mut self) {
self.allowed_ips = None;
}
@@ -111,6 +161,8 @@ pub struct ProjectInfoCacheImpl {
cache: DashMap<EndpointIdInt, EndpointInfo>,
project2ep: DashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
// FIXME(stefan): we need a way to GC the account2ep map.
account2ep: DashMap<AccountIdInt, HashSet<EndpointIdInt>>,
config: ProjectInfoCacheOptions,
start_time: Instant,
@@ -120,6 +172,59 @@ pub struct ProjectInfoCacheImpl {
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
info!(
"invalidating allowed vpc endpoint ids for projects `{}`",
project_ids.join(", ")
);
for project_id in project_ids {
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
}
}
}
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
info!(
"invalidating allowed vpc endpoint ids for org `{}`",
account_id
);
let endpoints = self
.account2ep
.get(&account_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
}
}
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
info!(
"invalidating block public or vpc access for project `{}`",
project_id
);
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_block_public_or_vpc_access();
}
}
}
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating allowed ips for project `{}`", project_id);
let endpoints = self
@@ -178,6 +283,7 @@ impl ProjectInfoCacheImpl {
Self {
cache: DashMap::new(),
project2ep: DashMap::new(),
account2ep: DashMap::new(),
config,
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
start_time: Instant::now(),
@@ -226,6 +332,7 @@ impl ProjectInfoCacheImpl {
}
Some(Cached::new_uncached(value))
}
pub(crate) fn insert_role_secret(
&self,
project_id: ProjectIdInt,
@@ -256,6 +363,16 @@ impl ProjectInfoCacheImpl {
self.insert_project2endpoint(project_id, endpoint_id);
self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
}
pub(crate) fn insert_vpc_allowed_endpoint_ids(&self, account_id: AccountIdInt, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, vpc_allowed_endpoint_ids: HashSet<EndpointIdInt>) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_account2endpoint(account_id, endpoint_id);
self.insert_project2endpoint(project_id, endpoint_id);
self.cache.entry(endpoint_id).or_default().vpc_allowed_endpoint_ids = Some(vpc_allowed_endpoint_ids);
}
}
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
endpoints.insert(endpoint_id);
@@ -264,6 +381,13 @@ impl ProjectInfoCacheImpl {
.insert(project_id, HashSet::from([endpoint_id]));
}
}
fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
endpoints.insert(endpoint_id);
} else {
self.account2ep.insert(account_id, HashSet::from([endpoint_id]));
}
}
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
let mut valid_since = Instant::now() - self.config.ttl;
// Only ignore cache if ttl is disabled.

View File

@@ -238,6 +238,8 @@ pub(crate) struct GetEndpointAccessControl {
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
pub(crate) project_id: Option<ProjectIdInt>,
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<EndpointIdInt>>,
pub(crate) block_public_connections: Option<bool>,
pub(crate) block_vpc_connections: Option<bool>,
}
// Manually implement debug to omit sensitive info.

View File

@@ -7,7 +7,7 @@ use std::sync::OnceLock;
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
use rustc_hash::FxHasher;
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
use crate::types::{AccountId, BranchId, EndpointId, ProjectId, RoleName};
pub trait InternId: Sized + 'static {
fn get_interner() -> &'static StringInterner<Self>;
@@ -206,6 +206,26 @@ impl From<ProjectId> for ProjectIdInt {
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct AccountIdTag;
impl InternId for AccountIdTag {
fn get_interner() -> &'static StringInterner<Self> {
static ROLE_NAMES: OnceLock<StringInterner<AccountIdTag>> = OnceLock::new();
ROLE_NAMES.get_or_init(Default::default)
}
}
pub type AccountIdInt = InternedString<AccountIdTag>;
impl From<&AccountId> for AccountIdInt {
fn from(value: &AccountId) -> Self {
AccountIdTag::get_interner().get_or_intern(value)
}
}
impl From<AccountId> for AccountIdInt {
fn from(value: AccountId) -> Self {
AccountIdTag::get_interner().get_or_intern(&value)
}
}
#[cfg(test)]
mod tests {
use std::sync::OnceLock;

View File

@@ -556,6 +556,9 @@ pub enum RedisEventsCount {
CancelSession,
PasswordUpdate,
AllowedIpsUpdate,
AllowedVpcEndpointIdsUpdateForProjects,
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
BlockPublicOrVpcAccessUpdate,
}
pub struct ThreadPoolWorkers(usize);

View File

@@ -11,7 +11,7 @@ use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
use tracing::Instrument;
@@ -39,6 +39,27 @@ pub(crate) enum Notification {
AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate,
},
#[serde(
rename = "/allowed_vpc_endpoint_ids_updated_for_projects",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointIdsUpdateForProjects {
allowed_vpc_endpoint_ids_update_for_projects: AllowedVpcEndpointIdsUpdateForProjects,
},
#[serde(
rename = "/allowed_vpc_endpoint_ids_updated_for_org",
deserialize_with = "deserialize_json_string"
)]
AllowedVpcEndpointIdsUpdateForAllProjectsInOrg {
allowed_vpc_endpoint_ids_update_for_org: AllowedVpcEndpointIdsUpdateForAllProjectsInOrg,
},
#[serde(
rename = "/block_public_or_vpc_access_updated",
deserialize_with = "deserialize_json_string"
)]
BlockPublicOrVpcAccessUpdate {
block_public_or_vpc_access_update: BlockPublicOrVpcAccessUpdate,
},
#[serde(
rename = "/password_updated",
deserialize_with = "deserialize_json_string"
@@ -51,6 +72,22 @@ pub(crate) enum Notification {
pub(crate) struct AllowedIpsUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointIdsUpdateForProjects {
project_ids: Vec<ProjectIdInt>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointIdsUpdateForAllProjectsInOrg {
account_id: AccountIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdate {
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
project_id: ProjectIdInt,
@@ -164,7 +201,11 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
Notification::AllowedIpsUpdate { .. }
| Notification::PasswordUpdate { .. }
| Notification::AllowedVpcEndpointIdsUpdateForProjects { .. }
| Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { .. }
| Notification::BlockPublicOrVpcAccessUpdate { .. } => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
@@ -176,6 +217,27 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
.proxy
.redis_events_count
.inc(RedisEventsCount::PasswordUpdate);
} else if matches!(
msg,
Notification::AllowedVpcEndpointIdsUpdateForProjects { .. }
) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
} else if matches!(
msg,
Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg { .. }
) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
} else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
}
// It might happen that the invalid entry is on the way to be cached.
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
@@ -197,6 +259,21 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
Notification::AllowedIpsUpdate { allowed_ips_update } => {
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
}
Notification::AllowedVpcEndpointIdsUpdateForProjects {
allowed_vpc_endpoint_ids_update_for_projects,
} => cache.invalidate_allowed_vpc_endpoint_ids_for_projects(
allowed_vpc_endpoint_ids_update_for_projects.project_ids,
),
Notification::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg {
allowed_vpc_endpoint_ids_update_for_org,
} => cache.invalidate_allowed_vpc_endpoint_ids_for_org(
allowed_vpc_endpoint_ids_update_for_org.account_id,
),
Notification::BlockPublicOrVpcAccessUpdate {
block_public_or_vpc_access_update,
} => cache.invalidate_block_public_or_vpc_access_for_project(
block_public_or_vpc_access_update.project_id,
),
Notification::PasswordUpdate { password_update } => cache
.invalidate_role_secret_for_project(
password_update.project_id,

View File

@@ -97,6 +97,8 @@ smol_str_wrapper!(EndpointId);
smol_str_wrapper!(BranchId);
// 90% of project strings are 23 characters or less.
smol_str_wrapper!(ProjectId);
// 90% of account strings are 23 characters or less.
smol_str_wrapper!(AccountId);
// will usually equal endpoint ID
smol_str_wrapper!(EndpointCacheKey);