diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 81c88e3ddd..9a4be2f904 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -18,6 +18,7 @@ use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { + fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt); fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); @@ -100,6 +101,13 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { + fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { + info!("invalidating endpoint access for `{endpoint_id}`"); + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_endpoint(); + } + } + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { info!("invalidating endpoint access for project `{project_id}`"); let endpoints = self diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 4b22c912eb..4c340edfd5 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -610,11 +610,11 @@ pub enum RedisEventsCount { BranchCreated, ProjectCreated, CancelSession, - PasswordUpdate, - AllowedIpsUpdate, - AllowedVpcEndpointIdsUpdateForProjects, - AllowedVpcEndpointIdsUpdateForAllProjectsInOrg, - BlockPublicOrVpcAccessUpdate, + InvalidateRole, + InvalidateEndpoint, + InvalidateProject, + InvalidateProjects, + InvalidateOrg, } pub struct ThreadPoolWorkers(usize); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index a9d6b40603..6c8260027f 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -3,12 +3,12 @@ use std::sync::Arc; use futures::StreamExt; use redis::aio::PubSub; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use tokio_util::sync::CancellationToken; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; -use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt}; +use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -27,42 +27,37 @@ struct NotificationHeader<'a> { topic: &'a str, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] #[serde(tag = "topic", content = "data")] -pub(crate) enum Notification { +enum Notification { #[serde( - rename = "/allowed_ips_updated", + rename = "/account_settings_update", + alias = "/allowed_vpc_endpoints_updated_for_org", deserialize_with = "deserialize_json_string" )] - AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate, - }, + AccountSettingsUpdate(InvalidateAccount), + #[serde( - rename = "/block_public_or_vpc_access_updated", + rename = "/endpoint_settings_update", deserialize_with = "deserialize_json_string" )] - BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated, - }, + EndpointSettingsUpdate(InvalidateEndpoint), + #[serde( - rename = "/allowed_vpc_endpoints_updated_for_org", + rename = "/project_settings_update", + alias = "/allowed_ips_updated", + alias = "/block_public_or_vpc_access_updated", + alias = "/allowed_vpc_endpoints_updated_for_projects", deserialize_with = "deserialize_json_string" )] - AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg, - }, + ProjectSettingsUpdate(InvalidateProject), + #[serde( - rename = "/allowed_vpc_endpoints_updated_for_projects", + rename = "/role_setting_update", + alias = "/password_updated", deserialize_with = "deserialize_json_string" )] - AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects, - }, - #[serde( - rename = "/password_updated", - deserialize_with = "deserialize_json_string" - )] - PasswordUpdate { password_update: PasswordUpdate }, + RoleSettingUpdate(InvalidateRole), #[serde( other, @@ -72,28 +67,56 @@ pub(crate) enum Notification { UnknownTopic, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct AllowedIpsUpdate { - project_id: ProjectIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateEndpoint { + EndpointId(EndpointIdInt), + EndpointIds(Vec), +} +impl std::ops::Deref for InvalidateEndpoint { + type Target = [EndpointIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::EndpointId(id) => std::slice::from_ref(id), + Self::EndpointIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct BlockPublicOrVpcAccessUpdated { - project_id: ProjectIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateProject { + ProjectId(ProjectIdInt), + ProjectIds(Vec), +} +impl std::ops::Deref for InvalidateProject { + type Target = [ProjectIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::ProjectId(id) => std::slice::from_ref(id), + Self::ProjectIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct AllowedVpcEndpointsUpdatedForOrg { - account_id: AccountIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateAccount { + AccountId(AccountIdInt), + AccountIds(Vec), +} +impl std::ops::Deref for InvalidateAccount { + type Target = [AccountIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::AccountId(id) => std::slice::from_ref(id), + Self::AccountIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct AllowedVpcEndpointsUpdatedForProjects { - project_ids: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct PasswordUpdate { +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +struct InvalidateRole { project_id: ProjectIdInt, role_name: RoleNameInt, } @@ -177,41 +200,29 @@ impl MessageHandler { tracing::debug!(?msg, "received a message"); match msg { - Notification::AllowedIpsUpdate { .. } - | Notification::PasswordUpdate { .. } - | Notification::BlockPublicOrVpcAccessUpdated { .. } - | Notification::AllowedVpcEndpointsUpdatedForOrg { .. } - | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => { + Notification::RoleSettingUpdate { .. } + | Notification::EndpointSettingsUpdate { .. } + | Notification::ProjectSettingsUpdate { .. } + | Notification::AccountSettingsUpdate { .. } => { invalidate_cache(self.cache.clone(), msg.clone()); - if matches!(msg, Notification::AllowedIpsUpdate { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedIpsUpdate); - } else if matches!(msg, Notification::PasswordUpdate { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::PasswordUpdate); - } else if matches!( - msg, - Notification::AllowedVpcEndpointsUpdatedForProjects { .. } - ) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects); - } else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg); - } else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate); + + let m = &Metrics::get().proxy.redis_events_count; + match msg { + Notification::RoleSettingUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateRole); + } + Notification::EndpointSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateEndpoint); + } + Notification::ProjectSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateProject); + } + Notification::AccountSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateOrg); + } + Notification::UnknownTopic => {} } + // TODO: add additional metrics for the other event types. // It might happen that the invalid entry is on the way to be cached. @@ -233,30 +244,23 @@ impl MessageHandler { fn invalidate_cache(cache: Arc, msg: Notification) { match msg { - Notification::AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate { project_id }, - } - | Notification::BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id }, - } => cache.invalidate_endpoint_access_for_project(project_id), - Notification::AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id }, - } => cache.invalidate_endpoint_access_for_org(account_id), - Notification::AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects: - AllowedVpcEndpointsUpdatedForProjects { project_ids }, - } => { - for project in project_ids { - cache.invalidate_endpoint_access_for_project(project); - } - } - Notification::PasswordUpdate { - password_update: - PasswordUpdate { - project_id, - role_name, - }, - } => cache.invalidate_role_secret_for_project(project_id, role_name), + Notification::EndpointSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access(id)), + + Notification::AccountSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)), + + Notification::ProjectSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)), + + Notification::RoleSettingUpdate(InvalidateRole { + project_id, + role_name, + }) => cache.invalidate_role_secret_for_project(project_id, role_name), + Notification::UnknownTopic => unreachable!(), } } @@ -353,11 +357,32 @@ mod tests { let result: Notification = serde_json::from_str(&text)?; assert_eq!( result, - Notification::AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate { - project_id: (&project_id).into() - } - } + Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into())) + ); + + Ok(()) + } + + #[test] + fn parse_multiple_projects() -> anyhow::Result<()> { + let project_id1: ProjectId = "new_project1".into(); + let project_id2: ProjectId = "new_project2".into(); + let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}"); + let text = json!({ + "type": "message", + "topic": "/allowed_vpc_endpoints_updated_for_projects", + "data": data, + "extre_fields": "something" + }) + .to_string(); + + let result: Notification = serde_json::from_str(&text)?; + assert_eq!( + result, + Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![ + (&project_id1).into(), + (&project_id2).into() + ])) ); Ok(()) @@ -379,12 +404,10 @@ mod tests { let result: Notification = serde_json::from_str(&text)?; assert_eq!( result, - Notification::PasswordUpdate { - password_update: PasswordUpdate { - project_id: (&project_id).into(), - role_name: (&role_name).into(), - } - } + Notification::RoleSettingUpdate(InvalidateRole { + project_id: (&project_id).into(), + role_name: (&role_name).into(), + }) ); Ok(())