diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index d37c107323..c812779e30 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,13 +1,11 @@ use std::collections::{HashMap, HashSet, hash_map}; use std::convert::Infallible; -use std::sync::atomic::AtomicU64; use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; use clashmap::mapref::one::Ref; use rand::{Rng, thread_rng}; -use tokio::sync::Mutex; use tokio::time::Instant; use tracing::{debug, info}; @@ -22,31 +20,23 @@ pub(crate) trait ProjectInfoCache { fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); - async fn decrement_active_listeners(&self); - async fn increment_active_listeners(&self); } struct Entry { - created_at: Instant, + expires_at: Instant, value: T, } impl Entry { - pub(crate) fn new(value: T) -> Self { + pub(crate) fn new(value: T, ttl: Duration) -> Self { Self { - created_at: Instant::now(), + expires_at: Instant::now() + ttl, value, } } - pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> { - (valid_since < self.created_at).then_some(&self.value) - } -} - -impl From for Entry { - fn from(value: T) -> Self { - Self::new(value) + pub(crate) fn get(&self) -> Option<&T> { + (self.expires_at > Instant::now()).then_some(&self.value) } } @@ -56,18 +46,12 @@ struct EndpointInfo { } impl EndpointInfo { - pub(crate) fn get_role_secret( - &self, - role_name: RoleNameInt, - valid_since: Instant, - ) -> Option { - let controls = self.role_controls.get(&role_name)?; - controls.get(valid_since).cloned() + pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option { + self.role_controls.get(&role_name)?.get().cloned() } - pub(crate) fn get_controls(&self, valid_since: Instant) -> Option { - let controls = self.controls.as_ref()?; - controls.get(valid_since).cloned() + pub(crate) fn get_controls(&self) -> Option { + self.controls.as_ref()?.get().cloned() } pub(crate) fn invalidate_endpoint(&mut self) { @@ -92,11 +76,8 @@ pub struct ProjectInfoCacheImpl { project2ep: ClashMap>, // FIXME(stefan): we need a way to GC the account2ep map. account2ep: ClashMap>, - config: ProjectInfoCacheOptions, - start_time: Instant, - ttl_disabled_since_us: AtomicU64, - active_listeners_lock: Mutex, + config: ProjectInfoCacheOptions, } #[async_trait] @@ -152,29 +133,6 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } } } - - async fn decrement_active_listeners(&self) { - let mut listeners_guard = self.active_listeners_lock.lock().await; - if *listeners_guard == 0 { - tracing::error!("active_listeners count is already 0, something is broken"); - return; - } - *listeners_guard -= 1; - if *listeners_guard == 0 { - self.ttl_disabled_since_us - .store(u64::MAX, std::sync::atomic::Ordering::SeqCst); - } - } - - async fn increment_active_listeners(&self) { - let mut listeners_guard = self.active_listeners_lock.lock().await; - *listeners_guard += 1; - if *listeners_guard == 1 { - let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64; - self.ttl_disabled_since_us - .store(new_ttl, std::sync::atomic::Ordering::SeqCst); - } - } } impl ProjectInfoCacheImpl { @@ -184,9 +142,6 @@ impl ProjectInfoCacheImpl { project2ep: ClashMap::new(), account2ep: ClashMap::new(), config, - ttl_disabled_since_us: AtomicU64::new(u64::MAX), - start_time: Instant::now(), - active_listeners_lock: Mutex::new(0), } } @@ -203,19 +158,17 @@ impl ProjectInfoCacheImpl { endpoint_id: &EndpointId, role_name: &RoleName, ) -> Option { - let valid_since = self.get_cache_times(); let role_name = RoleNameInt::get(role_name)?; let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_role_secret(role_name, valid_since) + endpoint_info.get_role_secret(role_name) } pub(crate) fn get_endpoint_access( &self, endpoint_id: &EndpointId, ) -> Option { - let valid_since = self.get_cache_times(); let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_controls(valid_since) + endpoint_info.get_controls() } pub(crate) fn insert_endpoint_access( @@ -237,8 +190,8 @@ impl ProjectInfoCacheImpl { return; } - let controls = Entry::from(controls); - let role_controls = Entry::from(role_controls); + let controls = Entry::new(controls, self.config.ttl); + let role_controls = Entry::new(role_controls, self.config.ttl); match self.cache.entry(endpoint_id) { clashmap::Entry::Vacant(e) => { @@ -275,27 +228,6 @@ impl ProjectInfoCacheImpl { } } - fn ignore_ttl_since(&self) -> Option { - let ttl_disabled_since_us = self - .ttl_disabled_since_us - .load(std::sync::atomic::Ordering::Relaxed); - - if ttl_disabled_since_us == u64::MAX { - return None; - } - - Some(self.start_time + Duration::from_micros(ttl_disabled_since_us)) - } - - fn get_cache_times(&self) -> Instant { - let mut valid_since = Instant::now() - self.config.ttl; - if let Some(ignore_ttl_since) = self.ignore_ttl_since() { - // We are fine if entry is not older than ttl or was added before we are getting notifications. - valid_since = valid_since.min(ignore_ttl_since); - } - valid_since - } - pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) { let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else { return; @@ -313,16 +245,7 @@ impl ProjectInfoCacheImpl { return; }; - let created_at = role_controls.get().created_at; - let expire = match self.ignore_ttl_since() { - // if ignoring TTL, we should still try and roll the password if it's old - // and we the client gave an incorrect password. There could be some lag on the redis channel. - Some(_) => created_at + self.config.ttl < Instant::now(), - // edge case: redis is down, let's be generous and invalidate the cache immediately. - None => true, - }; - - if expire { + if role_controls.get().expires_at <= Instant::now() { role_controls.remove(); } } diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 973a4c5b02..a6d376562b 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -265,10 +265,7 @@ async fn handle_messages( return Ok(()); } let mut conn = match try_connect(&redis).await { - Ok(conn) => { - handler.cache.increment_active_listeners().await; - conn - } + Ok(conn) => conn, Err(e) => { tracing::error!( "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}" @@ -287,11 +284,9 @@ async fn handle_messages( } } if cancellation_token.is_cancelled() { - handler.cache.decrement_active_listeners().await; return Ok(()); } } - handler.cache.decrement_active_listeners().await; } }