From 84e2bc52c228e8f552127a9b4a6afdb8081973a8 Mon Sep 17 00:00:00 2001
From: fys <40801205+fengys1996@users.noreply.github.com>
Date: Fri, 11 Apr 2025 13:54:28 +0800
Subject: [PATCH] fix: gRPC connection pool leak (#5876)

* fix: gRPC connection pool leak

* use .config() instead of .inner.config

* cancel the bg task if it is running

* fix: cr

* add unit test for pool release

* Avoid potential data races

---
 Cargo.lock                             |   1 +
 src/common/grpc/Cargo.toml             |   1 +
 src/common/grpc/src/channel_manager.rs | 186 ++++++++++++++++++-------
 3 files changed, 140 insertions(+), 48 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index da4306f527..1b108a7546 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2099,6 +2099,7 @@ dependencies = [
  "rand 0.9.0",
  "snafu 0.8.5",
  "tokio",
+ "tokio-util",
  "tonic 0.12.3",
  "tower 0.5.2",
 ]
diff --git a/src/common/grpc/Cargo.toml b/src/common/grpc/Cargo.toml
index d20e751e41..4dadf0571b 100644
--- a/src/common/grpc/Cargo.toml
+++ b/src/common/grpc/Cargo.toml
@@ -25,6 +25,7 @@ lazy_static.workspace = true
 prost.workspace = true
 snafu.workspace = true
 tokio.workspace = true
+tokio-util.workspace = true
 tonic.workspace = true
 tower.workspace = true
 
diff --git a/src/common/grpc/src/channel_manager.rs b/src/common/grpc/src/channel_manager.rs
index 0127829567..713ad58d81 100644
--- a/src/common/grpc/src/channel_manager.rs
+++ b/src/common/grpc/src/channel_manager.rs
@@ -22,6 +22,7 @@ use dashmap::mapref::entry::Entry;
 use dashmap::DashMap;
 use lazy_static::lazy_static;
 use snafu::{OptionExt, ResultExt};
+use tokio_util::sync::CancellationToken;
 use tonic::transport::{
     Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
 };
@@ -39,18 +40,48 @@ lazy_static! {
     static ref ID: AtomicU64 = AtomicU64::new(0);
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Default)]
 pub struct ChannelManager {
+    inner: Arc<Inner>,
+}
+
+#[derive(Debug)]
+struct Inner {
     id: u64,
     config: ChannelConfig,
     client_tls_config: Option<ClientTlsConfig>,
     pool: Arc<Pool>,
-    channel_recycle_started: Arc<AtomicBool>,
+    channel_recycle_started: AtomicBool,
+    cancel: CancellationToken,
 }
 
-impl Default for ChannelManager {
+impl Default for Inner {
     fn default() -> Self {
-        ChannelManager::with_config(ChannelConfig::default())
+        Self::with_config(ChannelConfig::default())
+    }
+}
+
+impl Drop for Inner {
+    fn drop(&mut self) {
+        // Cancel the channel recycle task.
+        self.cancel.cancel();
+    }
+}
+
+impl Inner {
+    fn with_config(config: ChannelConfig) -> Self {
+        let id = ID.fetch_add(1, Ordering::Relaxed);
+        let pool = Arc::new(Pool::default());
+        let cancel = CancellationToken::new();
+
+        Self {
+            id,
+            config,
+            client_tls_config: None,
+            pool,
+            channel_recycle_started: AtomicBool::new(false),
+            cancel,
+        }
     }
 }
 
@@ -60,19 +91,14 @@ impl ChannelManager {
     }
 
     pub fn with_config(config: ChannelConfig) -> Self {
-        let id = ID.fetch_add(1, Ordering::Relaxed);
-        let pool = Arc::new(Pool::default());
+        let inner = Inner::with_config(config);
         Self {
-            id,
-            config,
-            client_tls_config: None,
-            pool,
-            channel_recycle_started: Arc::new(AtomicBool::new(false)),
+            inner: Arc::new(inner),
         }
     }
 
     pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
-        let mut cm = Self::with_config(config.clone());
+        let mut inner = Inner::with_config(config.clone());
 
         // setup tls
         let path_config = config.client_tls.context(InvalidTlsConfigSnafu {
@@ -88,17 +114,23 @@ impl ChannelManager {
             .context(InvalidConfigFilePathSnafu)?;
         let client_identity = Identity::from_pem(client_cert, client_key);
 
-        cm.client_tls_config = Some(
+        inner.client_tls_config = Some(
             ClientTlsConfig::new()
                 .ca_certificate(server_root_ca_cert)
                 .identity(client_identity),
         );
 
-        Ok(cm)
+        Ok(Self {
+            inner: Arc::new(inner),
+        })
     }
 
     pub fn config(&self) -> &ChannelConfig {
-        &self.config
+        &self.inner.config
+    }
+
+    fn pool(&self) -> &Arc<Pool> {
+        &self.inner.pool
     }
 
     pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
@@ -106,12 +138,12 @@ impl ChannelManager {
         let addr = addr.as_ref();
 
         // It will acquire the read lock.
-        if let Some(inner_ch) = self.pool.get(addr) {
+        if let Some(inner_ch) = self.pool().get(addr) {
             return Ok(inner_ch);
         }
 
         // It will acquire the write lock.
-        let entry = match self.pool.entry(addr.to_string()) {
+        let entry = match self.pool().entry(addr.to_string()) {
             Entry::Occupied(entry) => {
                 entry.get().increase_access();
                 entry.into_ref()
@@ -150,7 +182,7 @@ impl ChannelManager {
             access: AtomicUsize::new(1),
             use_default_connector: false,
         };
-        self.pool.put(addr, channel);
+        self.pool().put(addr, channel);
 
         Ok(inner_channel)
     }
@@ -159,11 +191,11 @@ impl ChannelManager {
     where
         F: FnMut(&String, &mut Channel) -> bool,
     {
-        self.pool.retain_channel(f);
+        self.pool().retain_channel(f);
     }
 
     fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
-        let http_prefix = if self.client_tls_config.is_some() {
+        let http_prefix = if self.inner.client_tls_config.is_some() {
             "https"
         } else {
             "http"
@@ -172,51 +204,52 @@ impl ChannelManager {
 
         let mut endpoint =
             Endpoint::new(format!("{http_prefix}://{addr}")).context(CreateChannelSnafu)?;
 
-        if let Some(dur) = self.config.timeout {
+        if let Some(dur) = self.config().timeout {
             endpoint = endpoint.timeout(dur);
         }
-        if let Some(dur) = self.config.connect_timeout {
+        if let Some(dur) = self.config().connect_timeout {
             endpoint = endpoint.connect_timeout(dur);
         }
-        if let Some(limit) = self.config.concurrency_limit {
+        if let Some(limit) = self.config().concurrency_limit {
             endpoint = endpoint.concurrency_limit(limit);
         }
-        if let Some((limit, dur)) = self.config.rate_limit {
+        if let Some((limit, dur)) = self.config().rate_limit {
             endpoint = endpoint.rate_limit(limit, dur);
         }
-        if let Some(size) = self.config.initial_stream_window_size {
+        if let Some(size) = self.config().initial_stream_window_size {
             endpoint = endpoint.initial_stream_window_size(size);
         }
-        if let Some(size) = self.config.initial_connection_window_size {
+        if let Some(size) = self.config().initial_connection_window_size {
             endpoint = endpoint.initial_connection_window_size(size);
         }
-        if let Some(dur) = self.config.http2_keep_alive_interval {
+        if let Some(dur) = self.config().http2_keep_alive_interval {
             endpoint = endpoint.http2_keep_alive_interval(dur);
         }
-        if let Some(dur) = self.config.http2_keep_alive_timeout {
+        if let Some(dur) = self.config().http2_keep_alive_timeout {
             endpoint = endpoint.keep_alive_timeout(dur);
         }
-        if let Some(enabled) = self.config.http2_keep_alive_while_idle {
+        if let Some(enabled) = self.config().http2_keep_alive_while_idle {
             endpoint = endpoint.keep_alive_while_idle(enabled);
         }
-        if let Some(enabled) = self.config.http2_adaptive_window {
+        if let Some(enabled) = self.config().http2_adaptive_window {
             endpoint = endpoint.http2_adaptive_window(enabled);
         }
-        if let Some(tls_config) = &self.client_tls_config {
+        if let Some(tls_config) = &self.inner.client_tls_config {
             endpoint = endpoint
                 .tls_config(tls_config.clone())
                 .context(CreateChannelSnafu)?;
         }
 
         endpoint = endpoint
-            .tcp_keepalive(self.config.tcp_keepalive)
-            .tcp_nodelay(self.config.tcp_nodelay);
+            .tcp_keepalive(self.config().tcp_keepalive)
+            .tcp_nodelay(self.config().tcp_nodelay);
 
         Ok(endpoint)
     }
 
     fn trigger_channel_recycling(&self) {
         if self
+            .inner
             .channel_recycle_started
             .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
             .is_err()
@@ -224,13 +257,15 @@ impl ChannelManager {
             return;
         }
 
-        let pool = self.pool.clone();
-        let _handle = common_runtime::spawn_global(async {
-            recycle_channel_in_loop(pool, RECYCLE_CHANNEL_INTERVAL_SECS).await;
+        let pool = self.pool().clone();
+        let cancel = self.inner.cancel.clone();
+        let id = self.inner.id;
+        let _handle = common_runtime::spawn_global(async move {
+            recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
         });
         info!(
             "ChannelManager: {}, channel recycle is started, running in the background!",
-            self.id
+            self.inner.id
         );
     }
 }
@@ -443,11 +478,23 @@ impl Pool {
     }
 }
 
-async fn recycle_channel_in_loop(pool: Arc<Pool>, interval_secs: u64) {
+async fn recycle_channel_in_loop(
+    pool: Arc<Pool>,
+    id: u64,
+    cancel: CancellationToken,
+    interval_secs: u64,
+) {
     let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
 
     loop {
-        let _ = interval.tick().await;
+        tokio::select! {
+            _ = cancel.cancelled() => {
+                info!("Stop channel recycle, ChannelManager id: {}", id);
+                break;
+            },
+            _ = interval.tick() => {}
+        }
+
         pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
     }
 }
@@ -461,11 +508,7 @@ mod tests {
     #[should_panic]
     #[test]
     fn test_invalid_addr() {
-        let pool = Arc::new(Pool::default());
-        let mgr = ChannelManager {
-            pool,
-            ..Default::default()
-        };
+        let mgr = ChannelManager::default();
 
         let addr = "http://test";
         let _ = mgr.get(addr).unwrap();
@@ -475,7 +518,9 @@ mod tests {
     async fn test_access_count() {
         let mgr = ChannelManager::new();
         // Do not start recycle
-        mgr.channel_recycle_started.store(true, Ordering::Relaxed);
+        mgr.inner
+            .channel_recycle_started
+            .store(true, Ordering::Relaxed);
         let mgr = Arc::new(mgr);
         let addr = "test_uri";
 
@@ -493,12 +538,12 @@ mod tests {
             join.await.unwrap();
         }
 
-        assert_eq!(1000, mgr.pool.get_access(addr).unwrap());
+        assert_eq!(1000, mgr.pool().get_access(addr).unwrap());
 
-        mgr.pool
+        mgr.pool()
             .retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0);
 
-        assert_eq!(0, mgr.pool.get_access(addr).unwrap());
+        assert_eq!(0, mgr.pool().get_access(addr).unwrap());
     }
 
     #[test]
@@ -624,4 +669,49 @@ mod tests {
             true
         });
     }
+
+    #[tokio::test]
+    async fn test_pool_release_with_channel_recycle() {
+        let mgr = ChannelManager::new();
+
+        let pool_holder = mgr.pool().clone();
+
+        // start channel recycle task
+        let addr = "test_addr";
+        let _ = mgr.get(addr);
+
+        let mgr_clone_1 = mgr.clone();
+        let mgr_clone_2 = mgr.clone();
+        assert_eq!(3, Arc::strong_count(mgr.pool()));
+
+        drop(mgr_clone_1);
+        drop(mgr_clone_2);
+        assert_eq!(3, Arc::strong_count(mgr.pool()));
+
+        drop(mgr);
+
+        // wait for the channel recycle task to finish
+        tokio::time::sleep(Duration::from_millis(10)).await;
+
+        assert_eq!(1, Arc::strong_count(&pool_holder));
+    }
+
+    #[tokio::test]
+    async fn test_pool_release_without_channel_recycle() {
+        let mgr = ChannelManager::new();
+
+        let pool_holder = mgr.pool().clone();
+
+        let mgr_clone_1 = mgr.clone();
+        let mgr_clone_2 = mgr.clone();
+        assert_eq!(2, Arc::strong_count(mgr.pool()));
+
+        drop(mgr_clone_1);
+        drop(mgr_clone_2);
+        assert_eq!(2, Arc::strong_count(mgr.pool()));
+
+        drop(mgr);
+
+        assert_eq!(1, Arc::strong_count(&pool_holder));
+    }
 }
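
Note (not part of the patch): the leak existed because the background recycle task held a clone of the pool Arc and ran forever, so the pool could never be freed even after every ChannelManager handle was gone. The fix shares all manager state through a single Arc<Inner> and ties the task's lifetime to it: Drop for Inner fires when the last clone goes away and cancels the token the task selects on. The following is a minimal, self-contained sketch of that lifecycle pattern, not GreptimeDB code; Manager and Pool are placeholder names, and it assumes tokio (with the "full" feature) and tokio-util as dependencies.

use std::sync::Arc;
use std::time::Duration;

use tokio_util::sync::CancellationToken;

struct Pool; // stand-in for the real connection pool

struct Inner {
    pool: Arc<Pool>,
    cancel: CancellationToken,
}

impl Drop for Inner {
    fn drop(&mut self) {
        // Runs when the last Manager clone is dropped: tell the
        // background loop to exit so it releases its Arc<Pool>.
        self.cancel.cancel();
    }
}

#[derive(Clone)]
struct Manager {
    inner: Arc<Inner>,
}

impl Manager {
    fn new() -> Self {
        let inner = Inner {
            pool: Arc::new(Pool),
            cancel: CancellationToken::new(),
        };
        // The task captures only the pool Arc and the token, never
        // `inner` itself, so it cannot keep the Manager alive.
        let pool = inner.pool.clone();
        let cancel = inner.cancel.clone();
        tokio::spawn(async move {
            let mut tick = tokio::time::interval(Duration::from_millis(10));
            loop {
                tokio::select! {
                    _ = cancel.cancelled() => break, // owner dropped
                    _ = tick.tick() => { /* recycle idle entries here */ }
                }
            }
            drop(pool); // release the task's pool reference on exit
        });
        Self { inner: Arc::new(inner) }
    }
}

#[tokio::main]
async fn main() {
    let manager = Manager::new();
    let pool_watch = manager.inner.pool.clone();

    drop(manager); // Inner::drop cancels the token
    tokio::time::sleep(Duration::from_millis(50)).await;

    // Only our local clone remains; before the fix the recycle task
    // would still be holding a reference at this point.
    assert_eq!(1, Arc::strong_count(&pool_watch));
}

This mirrors the ordering that the new test_pool_release_with_channel_recycle asserts: three strong references to the pool while a manager, a test-held clone, and the recycle task are alive, and exactly one left shortly after the last manager handle is dropped.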