From 8549b42bca720d24c422725f040bd1efc90a86da Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sat, 26 Jul 2025 18:32:27 +0100 Subject: [PATCH] implement removal --- proxy/src/cache/project_info.rs | 115 +++++++++++++++++++++++++++----- 1 file changed, 100 insertions(+), 15 deletions(-) diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 54780ae468..949972ff6e 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -36,10 +36,11 @@ pub struct ProjectInfoCache { } type RefCount = Mutex; + // This is rather hacky. // We use an ordered map of (K, V) -> RefCount. // We use range queries over `(K, _)..(K+1, _)` to do the invalidation. -// We use the RefCount to know when to remove the mappings. +// We use the RefCount to know when to remove entries. type RefCountMultiSet = SkipMap, RefCount>; #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)] @@ -70,6 +71,37 @@ struct Entry { value: T, } +impl Entry { + fn dec_ref_counts( + self, + project2ep: &RefCountMultiSet, + account2ep: &RefCountMultiSet, + endpoint_id: EndpointIdInt, + ) { + if let Some(project_id) = self.project_id { + dec_ref_count(project2ep, project_id, endpoint_id); + } + if let Some(account_id) = self.account_id { + dec_ref_count(account2ep, account_id, endpoint_id); + } + } +} + +fn dec_ref_count( + id2ep: &RefCountMultiSet, + id: Id, + endpoint_id: EndpointIdInt, +) { + if let Some(entry) = id2ep.get(&KeyValue(id, endpoint_id)) { + let mut count = entry.value().lock_propagate_poison(); + *count -= 1; + if *count == 0 { + // remove the entry while holding the lock + entry.remove(); + } + } +} + impl ProjectInfoCache { pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { info!("invalidating endpoint access for `{endpoint_id}`"); @@ -119,26 +151,44 @@ impl ProjectInfoCache { .capacity .set(CacheKind::ProjectInfoEndpoints, config.size as i64); - let project2ep = Arc::new(RefCountMultiSet::new()); - let account2ep = Arc::new(RefCountMultiSet::new()); + let project2ep = Arc::new(RefCountMultiSet::::new()); + let account2ep = Arc::new(RefCountMultiSet::::new()); + let project2ep1 = Arc::clone(&project2ep); + let project2ep2 = Arc::clone(&project2ep); + let account2ep1 = Arc::clone(&account2ep); + let account2ep2 = Arc::clone(&account2ep); // we cache errors for 30 seconds, unless retry_at is set. let expiry = CplaneExpiry::default(); Self { role_controls: Cache::builder() .name("role_access_controls") - .eviction_listener(|_k, _v, cause| { - eviction_listener(CacheKind::ProjectInfoRoles, cause); - }) + .eviction_listener( + move |k, v: ControlPlaneResult>, cause| { + eviction_listener(CacheKind::ProjectInfoRoles, cause); + + let (endpoint_id, _): (EndpointIdInt, RoleNameInt) = *k; + if let Ok(v) = v { + v.dec_ref_counts(&project2ep1, &account2ep1, endpoint_id); + } + }, + ) .max_capacity(config.size * config.max_roles) .time_to_live(config.ttl) .expire_after(expiry) .build(), ep_controls: Cache::builder() .name("endpoint_access_controls") - .eviction_listener(|_k, _v, cause| { - eviction_listener(CacheKind::ProjectInfoEndpoints, cause); - }) + .eviction_listener( + move |k, v: ControlPlaneResult>, cause| { + eviction_listener(CacheKind::ProjectInfoEndpoints, cause); + + let endpoint_id: EndpointIdInt = *k; + if let Ok(v) = v { + v.dec_ref_counts(&project2ep2, &account2ep2, endpoint_id); + } + }, + ) .max_capacity(config.size) .time_to_live(config.ttl) .expire_after(expiry) @@ -188,11 +238,12 @@ impl ProjectInfoCache { controls: EndpointAccessControl, role_controls: RoleAccessControl, ) { + // 2 corresponds to how many cache inserts we do. if let Some(account_id) = account_id { - self.insert_account2endpoint(account_id, endpoint_id); + self.inc_account2ep_ref(account_id, endpoint_id, 2); } if let Some(project_id) = project_id { - self.insert_project2endpoint(project_id, endpoint_id); + self.inc_project2ep_ref(project_id, endpoint_id, 2); } debug!( @@ -256,18 +307,18 @@ impl ProjectInfoCache { .insert((endpoint_id, role_name), Err(msg)); } - fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { + fn inc_project2ep_ref(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, x: usize) { let entry = self .project2ep .get_or_insert(KeyValue(project_id, endpoint_id), Mutex::new(0)); - *entry.value().lock_propagate_poison() += 1; + *entry.value().lock_propagate_poison() += x; } - fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) { + fn inc_account2ep_ref(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt, x: usize) { let entry = self .account2ep .get_or_insert(KeyValue(account_id, endpoint_id), Mutex::new(0)); - *entry.value().lock_propagate_poison() += 1; + *entry.value().lock_propagate_poison() += x; } pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) { @@ -332,6 +383,16 @@ mod tests { }, ); + cache.ep_controls.run_pending_tasks(); + cache.role_controls.run_pending_tasks(); + + // check the project mappings are there + assert_eq!(cache.project2ep.len(), 1); + + // check the ref counts + let entry = cache.project2ep.front().unwrap(); + assert_eq!(*entry.value().lock_propagate_poison(), 2); + cache.insert_endpoint_access( account_id, project_id, @@ -348,6 +409,17 @@ mod tests { }, ); + cache.ep_controls.run_pending_tasks(); + cache.role_controls.run_pending_tasks(); + + // check the project mappings are still there + assert_eq!(cache.project2ep.len(), 1); + + // check the ref counts + let entry = cache.project2ep.front().unwrap(); + assert_eq!(*entry.value().lock_propagate_poison(), 3); + + // check both entries exist let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); assert_eq!(cached.unwrap().secret, secret1); @@ -375,13 +447,26 @@ mod tests { }, ); + cache.ep_controls.run_pending_tasks(); cache.role_controls.run_pending_tasks(); + assert_eq!(cache.role_controls.entry_count(), 2); + // check the project mappings are still there + assert_eq!(cache.project2ep.len(), 1); + + // check the ref counts are unchanged. + let entry = cache.project2ep.front().unwrap(); + assert_eq!(*entry.value().lock_propagate_poison(), 3); + tokio::time::sleep(Duration::from_secs(2)).await; + cache.ep_controls.run_pending_tasks(); cache.role_controls.run_pending_tasks(); assert_eq!(cache.role_controls.entry_count(), 0); + + // check the project/account mappings are no longer there + assert!(cache.project2ep.is_empty()); } #[tokio::test]