use std::convert::Infallible; use std::sync::{Arc, Mutex}; use crossbeam_skiplist::SkipMap; use crossbeam_skiplist::equivalent::{Comparable, Equivalent}; use moka::sync::Cache; use tracing::{debug, info}; use crate::cache::common::{ ControlPlaneResult, CplaneExpiry, count_cache_insert, count_cache_outcome, eviction_listener, }; use crate::config::ProjectInfoCacheOptions; use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason}; use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::ext::LockExt; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{CacheKind, Metrics}; use crate::types::{EndpointId, RoleName}; /// Cache for project info. /// This is used to cache auth data for endpoints. /// Invalidation is done by console notifications or by TTL (if console notifications are disabled). /// /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data. /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available? /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache. pub struct ProjectInfoCache { role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult>>, ep_controls: Cache>>, project2ep: Arc>, account2ep: Arc>, config: ProjectInfoCacheOptions, } type RefCount = Mutex; // This is rather hacky. // We use an ordered map of (K, V) -> RefCount. // We use range queries over `(K, _)..(K+1, _)` to do the invalidation. // We use the RefCount to know when to remove entries. type RefCountMultiSet = SkipMap, RefCount>; #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)] struct KeyValue(K, V); struct Key<'a, K>(&'a K, bool); impl<'a, K> Key<'a, K> { fn prefix(key: &'a K) -> std::ops::Range { Self(key, false)..Self(key, true) } } impl<'a, K: Ord, V> Equivalent> for KeyValue { fn equivalent(&self, key: &Key<'a, K>) -> bool { self.0 == *key.0 && !key.1 } } impl<'a, K: Ord, V> Comparable> for KeyValue { fn compare(&self, key: &Key<'a, K>) -> std::cmp::Ordering { self.0.cmp(key.0).then(false.cmp(&key.1)) } } #[derive(Clone)] struct Entry { project_id: Option, account_id: Option, value: T, } impl Entry { fn dec_ref_counts( self, project2ep: &RefCountMultiSet, account2ep: &RefCountMultiSet, endpoint_id: EndpointIdInt, ) { if let Some(project_id) = self.project_id { dec_ref_count(project2ep, project_id, endpoint_id); } if let Some(account_id) = self.account_id { dec_ref_count(account2ep, account_id, endpoint_id); } } } fn dec_ref_count( id2ep: &RefCountMultiSet, id: Id, endpoint_id: EndpointIdInt, ) { if let Some(entry) = id2ep.get(&KeyValue(id, endpoint_id)) { let mut count = entry.value().lock_propagate_poison(); *count -= 1; if *count == 0 { // remove the entry while holding the lock entry.remove(); } } } impl ProjectInfoCache { pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { info!("invalidating endpoint access for `{endpoint_id}`"); self.ep_controls.invalidate(&endpoint_id); } pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { info!("invalidating endpoint access for project `{project_id}`"); for entry in self.project2ep.range(Key::prefix(&project_id)) { self.ep_controls.invalidate(&entry.key().1); } } pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { info!("invalidating endpoint access for org `{account_id}`"); for entry in self.account2ep.range(Key::prefix(&account_id)) { self.ep_controls.invalidate(&entry.key().1); } } pub fn invalidate_role_secret_for_project( &self, project_id: ProjectIdInt, role_name: RoleNameInt, ) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", project_id, role_name, ); for entry in self.project2ep.range(Key::prefix(&project_id)) { self.role_controls.invalidate(&(entry.key().1, role_name)); } } } impl ProjectInfoCache { pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self { Metrics::get().cache.capacity.set( CacheKind::ProjectInfoRoles, (config.size * config.max_roles) as i64, ); Metrics::get() .cache .capacity .set(CacheKind::ProjectInfoEndpoints, config.size as i64); let project2ep = Arc::new(RefCountMultiSet::::new()); let account2ep = Arc::new(RefCountMultiSet::::new()); let project2ep1 = Arc::clone(&project2ep); let project2ep2 = Arc::clone(&project2ep); let account2ep1 = Arc::clone(&account2ep); let account2ep2 = Arc::clone(&account2ep); // we cache errors for 30 seconds, unless retry_at is set. let expiry = CplaneExpiry::default(); Self { role_controls: Cache::builder() .name("role_access_controls") .eviction_listener( move |k, v: ControlPlaneResult>, cause| { eviction_listener(CacheKind::ProjectInfoRoles, cause); let (endpoint_id, _): (EndpointIdInt, RoleNameInt) = *k; if let Ok(v) = v { v.dec_ref_counts(&project2ep1, &account2ep1, endpoint_id); } }, ) .max_capacity(config.size * config.max_roles) .time_to_live(config.ttl) .expire_after(expiry) .build(), ep_controls: Cache::builder() .name("endpoint_access_controls") .eviction_listener( move |k, v: ControlPlaneResult>, cause| { eviction_listener(CacheKind::ProjectInfoEndpoints, cause); let endpoint_id: EndpointIdInt = *k; if let Ok(v) = v { v.dec_ref_counts(&project2ep2, &account2ep2, endpoint_id); } }, ) .max_capacity(config.size) .time_to_live(config.ttl) .expire_after(expiry) .build(), project2ep, account2ep, config, } } pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, ) -> Option> { let endpoint_id = EndpointIdInt::get(endpoint_id)?; let role_name = RoleNameInt::get(role_name)?; count_cache_outcome( CacheKind::ProjectInfoRoles, self.role_controls .get(&(endpoint_id, role_name)) .map(|e| e.map(|e| e.value)), ) } pub(crate) fn get_endpoint_access( &self, endpoint_id: &EndpointId, ) -> Option> { let endpoint_id = EndpointIdInt::get(endpoint_id)?; count_cache_outcome( CacheKind::ProjectInfoEndpoints, self.ep_controls .get(&endpoint_id) .map(|e| e.map(|e| e.value)), ) } pub(crate) fn insert_endpoint_access( &self, account_id: Option, project_id: Option, endpoint_id: EndpointIdInt, role_name: RoleNameInt, controls: EndpointAccessControl, role_controls: RoleAccessControl, ) { // 2 corresponds to how many cache inserts we do. if let Some(account_id) = account_id { self.inc_account2ep_ref(account_id, endpoint_id, 2); } if let Some(project_id) = project_id { self.inc_project2ep_ref(project_id, endpoint_id, 2); } debug!( key = &*endpoint_id, "created a cache entry for endpoint access" ); count_cache_insert(CacheKind::ProjectInfoEndpoints); count_cache_insert(CacheKind::ProjectInfoRoles); self.ep_controls.insert( endpoint_id, Ok(Entry { account_id, project_id, value: controls, }), ); self.role_controls.insert( (endpoint_id, role_name), Ok(Entry { account_id, project_id, value: role_controls, }), ); } pub(crate) fn insert_endpoint_access_err( &self, endpoint_id: EndpointIdInt, role_name: RoleNameInt, msg: Box, ) { debug!( key = &*endpoint_id, "created a cache entry for an endpoint access error" ); // RoleProtected is the only role-specific error that control plane can give us. // If a given role name does not exist, it still returns a successful response, // just with an empty secret. if msg.get_reason() != Reason::RoleProtected { // We can cache all the other errors in ep_controls because they don't // depend on what role name we pass to control plane. self.ep_controls .entry(endpoint_id) .and_compute_with(|entry| match entry { // leave the entry alone if it's already Ok Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop, // replace the entry _ => { count_cache_insert(CacheKind::ProjectInfoEndpoints); moka::ops::compute::Op::Put(Err(msg.clone())) } }); } count_cache_insert(CacheKind::ProjectInfoRoles); self.role_controls .insert((endpoint_id, role_name), Err(msg)); } fn inc_project2ep_ref(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, x: usize) { let entry = self .project2ep .get_or_insert(KeyValue(project_id, endpoint_id), Mutex::new(0)); *entry.value().lock_propagate_poison() += x; } fn inc_account2ep_ref(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt, x: usize) { let entry = self .account2ep .get_or_insert(KeyValue(account_id, endpoint_id), Mutex::new(0)); *entry.value().lock_propagate_poison() += x; } pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) { // TODO: Expire the value early if the key is idle. // Currently not an issue as we would just use the TTL to decide, which is what already happens. } pub async fn gc_worker(&self) -> anyhow::Result { let mut interval = tokio::time::interval(self.config.gc_interval); loop { interval.tick().await; self.ep_controls.run_pending_tasks(); self.role_controls.run_pending_tasks(); } } } #[cfg(test)] mod tests { use std::sync::Arc; use std::time::Duration; use super::*; use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status}; use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; #[tokio::test] async fn test_project_info_cache_settings() { let cache = ProjectInfoCache::new(ProjectInfoCacheOptions { size: 1, max_roles: 2, ttl: Duration::from_secs(1), gc_interval: Duration::from_secs(600), }); let project_id: Option = Some(ProjectIdInt::from(&"project".into())); let endpoint_id: EndpointId = "endpoint".into(); let account_id = None; let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); let secret2 = None; let allowed_ips = Arc::new(vec![ "127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap(), ]); cache.insert_endpoint_access( account_id, project_id, (&endpoint_id).into(), (&user1).into(), EndpointAccessControl { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret1.clone(), }, ); cache.ep_controls.run_pending_tasks(); cache.role_controls.run_pending_tasks(); // check the project mappings are there assert_eq!(cache.project2ep.len(), 1); // check the ref counts let entry = cache.project2ep.front().unwrap(); assert_eq!(*entry.value().lock_propagate_poison(), 2); cache.insert_endpoint_access( account_id, project_id, (&endpoint_id).into(), (&user2).into(), EndpointAccessControl { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret2.clone(), }, ); cache.ep_controls.run_pending_tasks(); cache.role_controls.run_pending_tasks(); // check the project mappings are still there assert_eq!(cache.project2ep.len(), 1); // check the ref counts let entry = cache.project2ep.front().unwrap(); assert_eq!(*entry.value().lock_propagate_poison(), 3); // check both entries exist let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); assert_eq!(cached.unwrap().secret, secret1); let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); assert_eq!(cached.unwrap().secret, secret2); // Shouldn't add more than 2 roles. let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); cache.role_controls.run_pending_tasks(); cache.insert_endpoint_access( account_id, project_id, (&endpoint_id).into(), (&user3).into(), EndpointAccessControl { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret3.clone(), }, ); cache.ep_controls.run_pending_tasks(); cache.role_controls.run_pending_tasks(); assert_eq!(cache.role_controls.entry_count(), 2); // check the project mappings are still there assert_eq!(cache.project2ep.len(), 1); // check the ref counts are unchanged. let entry = cache.project2ep.front().unwrap(); assert_eq!(*entry.value().lock_propagate_poison(), 3); tokio::time::sleep(Duration::from_secs(2)).await; cache.ep_controls.run_pending_tasks(); cache.role_controls.run_pending_tasks(); assert_eq!(cache.role_controls.entry_count(), 0); // check the project/account mappings are no longer there assert!(cache.project2ep.is_empty()); } #[tokio::test] async fn test_caching_project_info_errors() { let cache = ProjectInfoCache::new(ProjectInfoCacheOptions { size: 10, max_roles: 10, ttl: Duration::from_secs(1), gc_interval: Duration::from_secs(600), }); let project_id = Some(ProjectIdInt::from(&"project".into())); let endpoint_id: EndpointId = "endpoint".into(); let account_id = None; let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); let role_msg = Box::new(ControlPlaneErrorMessage { error: "role is protected and cannot be used for password-based authentication" .to_owned() .into_boxed_str(), http_status_code: http::StatusCode::NOT_FOUND, status: Some(Status { code: "PERMISSION_DENIED".to_owned().into_boxed_str(), message: "role is protected and cannot be used for password-based authentication" .to_owned() .into_boxed_str(), details: Details { error_info: Some(ErrorInfo { reason: Reason::RoleProtected, }), retry_info: None, user_facing_message: None, }, }), }); let generic_msg = Box::new(ControlPlaneErrorMessage { error: "oh noes".to_owned().into_boxed_str(), http_status_code: http::StatusCode::NOT_FOUND, status: None, }); let get_role_secret = |endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap(); let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap(); // stores role-specific errors only for get_role_secret cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone()); assert_eq!( get_role_secret(&endpoint_id, &user1).unwrap_err().error, role_msg.error ); assert!(cache.get_endpoint_access(&endpoint_id).is_none()); // stores non-role specific errors for both get_role_secret and get_endpoint_access cache.insert_endpoint_access_err( (&endpoint_id).into(), (&user1).into(), generic_msg.clone(), ); assert_eq!( get_role_secret(&endpoint_id, &user1).unwrap_err().error, generic_msg.error ); assert_eq!( get_endpoint_access(&endpoint_id).unwrap_err().error, generic_msg.error ); // error isn't returned for other roles in the same endpoint assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); // success for a role does not overwrite errors for other roles cache.insert_endpoint_access( account_id, project_id, (&endpoint_id).into(), (&user2).into(), EndpointAccessControl { allowed_ips: Arc::new(vec![]), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret.clone(), }, ); assert!(get_role_secret(&endpoint_id, &user1).is_err()); assert!(get_role_secret(&endpoint_id, &user2).is_ok()); // ...but does clear the access control error assert!(get_endpoint_access(&endpoint_id).is_ok()); // storing an error does not overwrite successful access control response cache.insert_endpoint_access_err( (&endpoint_id).into(), (&user2).into(), generic_msg.clone(), ); assert!(get_role_secret(&endpoint_id, &user2).is_err()); assert!(get_endpoint_access(&endpoint_id).is_ok()); } }