diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs
index d8a1d261ce..10cc4ceee1 100644
--- a/proxy/src/cache/project_info.rs
+++ b/proxy/src/cache/project_info.rs
@@ -5,9 +5,11 @@ use std::{
     time::Duration,
 };
 
+use async_trait::async_trait;
 use dashmap::DashMap;
 use rand::{thread_rng, Rng};
 use smol_str::SmolStr;
+use tokio::sync::Mutex;
 use tokio::time::Instant;
 use tracing::{debug, info};
 
@@ -21,11 +23,12 @@ use crate::{
 
 use super::{Cache, Cached};
 
+#[async_trait]
 pub trait ProjectInfoCache {
     fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
-    fn enable_ttl(&self);
-    fn disable_ttl(&self);
+    async fn decrement_active_listeners(&self);
+    async fn increment_active_listeners(&self);
 }
 
 struct Entry<T> {
@@ -116,8 +119,10 @@ pub struct ProjectInfoCacheImpl {
     start_time: Instant,
 
     ttl_disabled_since_us: AtomicU64,
+    active_listeners_lock: Mutex<usize>,
 }
 
+#[async_trait]
 impl ProjectInfoCache for ProjectInfoCacheImpl {
     fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
         info!("invalidating allowed ips for project `{}`", project_id);
@@ -148,15 +153,27 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
             }
         }
     }
-    fn enable_ttl(&self) {
-        self.ttl_disabled_since_us
-            .store(u64::MAX, std::sync::atomic::Ordering::Relaxed);
+    async fn decrement_active_listeners(&self) {
+        let mut listeners_guard = self.active_listeners_lock.lock().await;
+        if *listeners_guard == 0 {
+            tracing::error!("active_listeners count is already 0, something is broken");
+            return;
+        }
+        *listeners_guard -= 1;
+        if *listeners_guard == 0 {
+            self.ttl_disabled_since_us
+                .store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
+        }
     }
-    fn disable_ttl(&self) {
-        let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
-        self.ttl_disabled_since_us
-            .store(new_ttl, std::sync::atomic::Ordering::Relaxed);
+    async fn increment_active_listeners(&self) {
+        let mut listeners_guard = self.active_listeners_lock.lock().await;
+        *listeners_guard += 1;
+        if *listeners_guard == 1 {
+            let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
+            self.ttl_disabled_since_us
+                .store(new_ttl, std::sync::atomic::Ordering::SeqCst);
+        }
     }
 }
 
@@ -168,6 +185,7 @@ impl ProjectInfoCacheImpl {
             config,
             ttl_disabled_since_us: AtomicU64::new(u64::MAX),
             start_time: Instant::now(),
+            active_listeners_lock: Mutex::new(0),
         }
     }
 
@@ -432,7 +450,7 @@ mod tests {
             ttl: Duration::from_secs(1),
             gc_interval: Duration::from_secs(600),
         }));
-        cache.clone().disable_ttl();
+        cache.clone().increment_active_listeners().await;
         tokio::time::advance(Duration::from_secs(2)).await;
         let project_id: ProjectId = "project".into();
 
@@ -489,7 +507,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_disable_ttl_invalidate_added_before() {
+    async fn test_increment_active_listeners_invalidate_added_before() {
         tokio::time::pause();
         let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
             size: 2,
@@ -514,7 +532,7 @@ mod tests {
             (&user1).into(),
             secret1.clone(),
         );
-        cache.clone().disable_ttl();
+        cache.clone().increment_active_listeners().await;
         tokio::time::advance(Duration::from_millis(100)).await;
         cache.insert_role_secret(
             (&project_id).into(),
diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs
index ba4dfb755e..87d723d17e 100644
--- a/proxy/src/redis/notifications.rs
+++ b/proxy/src/redis/notifications.rs
@@ -4,6 +4,7 @@ use futures::StreamExt;
 use pq_proto::CancelKeyData;
 use redis::aio::PubSub;
 use serde::{Deserialize, Serialize};
+use tokio_util::sync::CancellationToken;
 use uuid::Uuid;
 
 use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
@@ -77,6 +78,16 @@ struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
     region_id: String,
 }
 
+impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
+    fn clone(&self) -> Self {
+        Self {
+            cache: self.cache.clone(),
+            cancellation_handler: self.cancellation_handler.clone(),
+            region_id: self.region_id.clone(),
+        }
+    }
+}
+
 impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
     pub fn new(
         cache: Arc<C>,
@@ -89,11 +100,11 @@
             region_id,
         }
     }
-    pub fn disable_ttl(&self) {
-        self.cache.disable_ttl();
+    pub async fn increment_active_listeners(&self) {
+        self.cache.increment_active_listeners().await;
     }
-    pub fn enable_ttl(&self) {
-        self.cache.enable_ttl();
+    pub async fn decrement_active_listeners(&self) {
+        self.cache.decrement_active_listeners().await;
     }
     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
     async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
@@ -182,37 +193,24 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     }
 }
 
-/// Handle console's invalidation messages.
-#[tracing::instrument(name = "console_notifications", skip_all)]
-pub async fn task_main<C>(
+async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
+    handler: MessageHandler<C>,
     redis: ConnectionWithCredentialsProvider,
-    cache: Arc<C>,
-    cancel_map: CancelMap,
-    region_id: String,
-) -> anyhow::Result<Infallible>
-where
-    C: ProjectInfoCache + Send + Sync + 'static,
-{
-    cache.enable_ttl();
-    let handler = MessageHandler::new(
-        cache,
-        Arc::new(CancellationHandler::<()>::new(
-            cancel_map,
-            crate::metrics::CancellationSource::FromRedis,
-        )),
-        region_id,
-    );
-
+    cancellation_token: CancellationToken,
+) -> anyhow::Result<()> {
     loop {
+        if cancellation_token.is_cancelled() {
+            return Ok(());
+        }
         let mut conn = match try_connect(&redis).await {
             Ok(conn) => {
-                handler.disable_ttl();
+                handler.increment_active_listeners().await;
                 conn
             }
             Err(e) => {
                 tracing::error!(
-                "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
-            );
+                    "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
+                );
                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
                 continue;
             }
@@ -226,8 +224,47 @@ where
                     break;
                 }
             }
+            if cancellation_token.is_cancelled() {
+                handler.decrement_active_listeners().await;
+                return Ok(());
+            }
         }
-        handler.enable_ttl();
+        handler.decrement_active_listeners().await;
+    }
+}
+
+/// Handle console's invalidation messages.
+#[tracing::instrument(name = "redis_notifications", skip_all)]
+pub async fn task_main<C>(
+    redis: ConnectionWithCredentialsProvider,
+    cache: Arc<C>,
+    cancel_map: CancelMap,
+    region_id: String,
+) -> anyhow::Result<Infallible>
+where
+    C: ProjectInfoCache + Send + Sync + 'static,
+{
+    let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
+        cancel_map,
+        crate::metrics::CancellationSource::FromRedis,
+    ));
+    let handler = MessageHandler::new(cache, cancellation_handler, region_id);
+    // 6h - 1m.
+    // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
+    let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
+    loop {
+        let cancellation_token = CancellationToken::new();
+        interval.tick().await;
+
+        tokio::spawn(handle_messages(
+            handler.clone(),
+            redis.clone(),
+            cancellation_token.clone(),
+        ));
+        tokio::spawn(async move {
+            tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
+            cancellation_token.cancel();
+        });
     }
 }
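
Note on the counting scheme above: task_main now rotates its Redis listener every six hours and deliberately overlaps the outgoing and incoming handle_messages tasks for one minute, so a plain enable/disable flag would re-enable the cache TTL the moment the old listener exits, even though the new one is already subscribed. The patch therefore counts active listeners and only toggles ttl_disabled_since_us on the 0 -> 1 and 1 -> 0 transitions. Below is a self-contained sketch of that gate (editor's illustration, not part of the patch; TtlGate and its field/method names are made up, while the transition logic mirrors the diff):

// Editor's sketch of the listener-counted TTL gate from the patch.
// Assumes tokio = { version = "1", features = ["full"] } in Cargo.toml.
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

use tokio::sync::Mutex;
use tokio::time::Instant;

struct TtlGate {
    start_time: Instant,
    ttl: Duration,
    // u64::MAX means "TTL enforcement is on"; any other value is the
    // watermark (micros since start_time) from which entries stop expiring.
    ttl_disabled_since_us: AtomicU64,
    active_listeners: Mutex<usize>,
}

impl TtlGate {
    fn new(ttl: Duration) -> Self {
        Self {
            start_time: Instant::now(),
            ttl,
            ttl_disabled_since_us: AtomicU64::new(u64::MAX),
            active_listeners: Mutex::new(0),
        }
    }

    async fn increment_active_listeners(&self) {
        let mut n = self.active_listeners.lock().await;
        *n += 1;
        if *n == 1 {
            // First listener subscribed: record the watermark, as in the
            // patched increment path.
            let since = (self.start_time.elapsed() + self.ttl).as_micros() as u64;
            self.ttl_disabled_since_us.store(since, Ordering::SeqCst);
        }
    }

    async fn decrement_active_listeners(&self) {
        let mut n = self.active_listeners.lock().await;
        assert!(*n > 0, "decrement without a matching increment");
        *n -= 1;
        if *n == 0 {
            // Last listener gone: fall back to plain TTL expiry.
            self.ttl_disabled_since_us.store(u64::MAX, Ordering::SeqCst);
        }
    }
}

#[tokio::main]
async fn main() {
    let gate = TtlGate::new(Duration::from_secs(1));
    // Rotation overlap: the new listener starts before the old one stops.
    gate.increment_active_listeners().await; // count 1: TTL suppressed
    gate.increment_active_listeners().await; // count 2: still suppressed
    gate.decrement_active_listeners().await; // old task exits; still suppressed
    gate.decrement_active_listeners().await; // count 0: TTL active again
    assert_eq!(gate.ttl_disabled_since_us.load(Ordering::SeqCst), u64::MAX);
}

The split between the two synchronization primitives follows the patch: the tokio Mutex serializes the rare listener-count transitions, while the AtomicU64 keeps the hot path (expiry checks on every cache lookup) lock-free.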