fix: gRPC connection pool leak (#5876)

* fix: gRPC connection pool leak

* use .config() instead of .inner.config

* cancel the bg task if it is running

* fix: address code review comments

* add unit test for pool release

* Avoid potential data races
Author: fys
Date: 2025-04-11 13:54:28 +08:00 (committed via GitHub)
Parent: 71255b3cbd
Commit: 84e2bc52c2
3 changed files with 140 additions and 48 deletions
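
The leak: every ChannelManager spawned a background channel-recycle task that captured a clone of the pool's Arc and looped forever, so the pool stayed alive even after all handles to the manager were dropped. The fix moves the shared state behind an Arc<Inner> whose Drop cancels the task through a CancellationToken. Below is a minimal standalone sketch of that pattern, assuming tokio and tokio-util; the Manager/Inner names and the one-second interval are illustrative, not this repository's code.

use std::sync::Arc;
use std::time::Duration;

use tokio_util::sync::CancellationToken;

#[derive(Clone, Default)]
struct Manager {
    inner: Arc<Inner>,
}

#[derive(Default)]
struct Inner {
    cancel: CancellationToken,
}

impl Drop for Inner {
    // Runs only when the last `Manager` clone is dropped,
    // because `Inner` sits behind an `Arc`.
    fn drop(&mut self) {
        self.cancel.cancel();
    }
}

impl Manager {
    fn start_background_task(&self) {
        let cancel = self.inner.cancel.clone();
        tokio::spawn(async move {
            // Illustrative interval; the real task sweeps idle channels.
            let mut interval = tokio::time::interval(Duration::from_secs(1));
            loop {
                tokio::select! {
                    // Exit instead of looping (and holding captures) forever.
                    _ = cancel.cancelled() => break,
                    _ = interval.tick() => { /* periodic work */ }
                }
            }
        });
    }
}

#[tokio::main]
async fn main() {
    let mgr = Manager::default();
    mgr.start_background_task();
    drop(mgr); // last handle gone: Drop cancels, the task exits promptly
    tokio::time::sleep(Duration::from_millis(50)).await;
}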

Cargo.lock (generated)

@@ -2099,6 +2099,7 @@ dependencies = [
  "rand 0.9.0",
  "snafu 0.8.5",
  "tokio",
+ "tokio-util",
  "tonic 0.12.3",
  "tower 0.5.2",
 ]


@@ -25,6 +25,7 @@ lazy_static.workspace = true
 prost.workspace = true
 snafu.workspace = true
 tokio.workspace = true
+tokio-util.workspace = true
 tonic.workspace = true
 tower.workspace = true


@@ -22,6 +22,7 @@ use dashmap::mapref::entry::Entry;
 use dashmap::DashMap;
 use lazy_static::lazy_static;
 use snafu::{OptionExt, ResultExt};
+use tokio_util::sync::CancellationToken;
 use tonic::transport::{
     Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
 };
@@ -39,18 +40,48 @@ lazy_static! {
     static ref ID: AtomicU64 = AtomicU64::new(0);
 }

-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Default)]
 pub struct ChannelManager {
+    inner: Arc<Inner>,
+}
+
+#[derive(Debug)]
+struct Inner {
     id: u64,
     config: ChannelConfig,
     client_tls_config: Option<ClientTlsConfig>,
     pool: Arc<Pool>,
-    channel_recycle_started: Arc<AtomicBool>,
+    channel_recycle_started: AtomicBool,
+    cancel: CancellationToken,
 }

-impl Default for ChannelManager {
+impl Default for Inner {
     fn default() -> Self {
-        ChannelManager::with_config(ChannelConfig::default())
+        Self::with_config(ChannelConfig::default())
     }
 }
+
+impl Drop for Inner {
+    fn drop(&mut self) {
+        // Cancel the channel recycle task.
+        self.cancel.cancel();
+    }
+}
+
+impl Inner {
+    fn with_config(config: ChannelConfig) -> Self {
+        let id = ID.fetch_add(1, Ordering::Relaxed);
+        let pool = Arc::new(Pool::default());
+        let cancel = CancellationToken::new();
+        Self {
+            id,
+            config,
+            client_tls_config: None,
+            pool,
+            channel_recycle_started: AtomicBool::new(false),
+            cancel,
+        }
+    }
+}
@@ -60,19 +91,14 @@ impl ChannelManager {
     }

     pub fn with_config(config: ChannelConfig) -> Self {
-        let id = ID.fetch_add(1, Ordering::Relaxed);
-        let pool = Arc::new(Pool::default());
+        let inner = Inner::with_config(config);
         Self {
-            id,
-            config,
-            client_tls_config: None,
-            pool,
-            channel_recycle_started: Arc::new(AtomicBool::new(false)),
+            inner: Arc::new(inner),
         }
     }

     pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
-        let mut cm = Self::with_config(config.clone());
+        let mut inner = Inner::with_config(config.clone());

         // setup tls
         let path_config = config.client_tls.context(InvalidTlsConfigSnafu {
@@ -88,17 +114,23 @@ impl ChannelManager {
             .context(InvalidConfigFilePathSnafu)?;
         let client_identity = Identity::from_pem(client_cert, client_key);

-        cm.client_tls_config = Some(
+        inner.client_tls_config = Some(
             ClientTlsConfig::new()
                 .ca_certificate(server_root_ca_cert)
                 .identity(client_identity),
         );

-        Ok(cm)
+        Ok(Self {
+            inner: Arc::new(inner),
+        })
     }

     pub fn config(&self) -> &ChannelConfig {
-        &self.config
+        &self.inner.config
     }

+    fn pool(&self) -> &Arc<Pool> {
+        &self.inner.pool
+    }
+
     pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
@@ -106,12 +138,12 @@ impl ChannelManager {
         let addr = addr.as_ref();

         // It will acquire the read lock.
-        if let Some(inner_ch) = self.pool.get(addr) {
+        if let Some(inner_ch) = self.pool().get(addr) {
             return Ok(inner_ch);
         }

         // It will acquire the write lock.
-        let entry = match self.pool.entry(addr.to_string()) {
+        let entry = match self.pool().entry(addr.to_string()) {
             Entry::Occupied(entry) => {
                 entry.get().increase_access();
                 entry.into_ref()
@@ -150,7 +182,7 @@ impl ChannelManager {
             access: AtomicUsize::new(1),
             use_default_connector: false,
         };
-        self.pool.put(addr, channel);
+        self.pool().put(addr, channel);

         Ok(inner_channel)
     }
@@ -159,11 +191,11 @@ impl ChannelManager {
     where
         F: FnMut(&String, &mut Channel) -> bool,
     {
-        self.pool.retain_channel(f);
+        self.pool().retain_channel(f);
     }

     fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
-        let http_prefix = if self.client_tls_config.is_some() {
+        let http_prefix = if self.inner.client_tls_config.is_some() {
             "https"
         } else {
             "http"
@@ -172,51 +204,52 @@ impl ChannelManager {
         let mut endpoint =
             Endpoint::new(format!("{http_prefix}://{addr}")).context(CreateChannelSnafu)?;

-        if let Some(dur) = self.config.timeout {
+        if let Some(dur) = self.config().timeout {
             endpoint = endpoint.timeout(dur);
         }
-        if let Some(dur) = self.config.connect_timeout {
+        if let Some(dur) = self.config().connect_timeout {
             endpoint = endpoint.connect_timeout(dur);
         }
-        if let Some(limit) = self.config.concurrency_limit {
+        if let Some(limit) = self.config().concurrency_limit {
             endpoint = endpoint.concurrency_limit(limit);
         }
-        if let Some((limit, dur)) = self.config.rate_limit {
+        if let Some((limit, dur)) = self.config().rate_limit {
             endpoint = endpoint.rate_limit(limit, dur);
         }
-        if let Some(size) = self.config.initial_stream_window_size {
+        if let Some(size) = self.config().initial_stream_window_size {
             endpoint = endpoint.initial_stream_window_size(size);
         }
-        if let Some(size) = self.config.initial_connection_window_size {
+        if let Some(size) = self.config().initial_connection_window_size {
             endpoint = endpoint.initial_connection_window_size(size);
         }
-        if let Some(dur) = self.config.http2_keep_alive_interval {
+        if let Some(dur) = self.config().http2_keep_alive_interval {
             endpoint = endpoint.http2_keep_alive_interval(dur);
         }
-        if let Some(dur) = self.config.http2_keep_alive_timeout {
+        if let Some(dur) = self.config().http2_keep_alive_timeout {
             endpoint = endpoint.keep_alive_timeout(dur);
         }
-        if let Some(enabled) = self.config.http2_keep_alive_while_idle {
+        if let Some(enabled) = self.config().http2_keep_alive_while_idle {
             endpoint = endpoint.keep_alive_while_idle(enabled);
         }
-        if let Some(enabled) = self.config.http2_adaptive_window {
+        if let Some(enabled) = self.config().http2_adaptive_window {
             endpoint = endpoint.http2_adaptive_window(enabled);
         }
-        if let Some(tls_config) = &self.client_tls_config {
+        if let Some(tls_config) = &self.inner.client_tls_config {
             endpoint = endpoint
                 .tls_config(tls_config.clone())
                 .context(CreateChannelSnafu)?;
         }

         endpoint = endpoint
-            .tcp_keepalive(self.config.tcp_keepalive)
-            .tcp_nodelay(self.config.tcp_nodelay);
+            .tcp_keepalive(self.config().tcp_keepalive)
+            .tcp_nodelay(self.config().tcp_nodelay);

         Ok(endpoint)
     }

     fn trigger_channel_recycling(&self) {
         if self
+            .inner
             .channel_recycle_started
             .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
             .is_err()
@@ -224,13 +257,15 @@ impl ChannelManager {
             return;
         }

-        let pool = self.pool.clone();
-        let _handle = common_runtime::spawn_global(async {
-            recycle_channel_in_loop(pool, RECYCLE_CHANNEL_INTERVAL_SECS).await;
+        let pool = self.pool().clone();
+        let cancel = self.inner.cancel.clone();
+        let id = self.inner.id;
+        let _handle = common_runtime::spawn_global(async move {
+            recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
         });

         info!(
             "ChannelManager: {}, channel recycle is started, running in the background!",
-            self.id
+            self.inner.id
         );
     }
 }
@@ -443,11 +478,23 @@ impl Pool {
     }
 }

-async fn recycle_channel_in_loop(pool: Arc<Pool>, interval_secs: u64) {
+async fn recycle_channel_in_loop(
+    pool: Arc<Pool>,
+    id: u64,
+    cancel: CancellationToken,
+    interval_secs: u64,
+) {
     let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

     loop {
-        let _ = interval.tick().await;
+        tokio::select! {
+            _ = cancel.cancelled() => {
+                info!("Stop channel recycle, ChannelManager id: {}", id);
+                break;
+            },
+            _ = interval.tick() => {}
+        }
+
         pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
     }
 }
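
One detail worth noting in the loop above: tokio::select! wakes the task the moment the token is cancelled, even while it is parked on interval.tick(), so shutdown does not have to wait out the sweep interval. A standalone sketch of that behavior, assuming tokio and tokio-util (the hour-long interval is deliberately exaggerated and is not a value from this codebase):

use std::time::Duration;

use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    let task_token = cancel.clone();

    let handle = tokio::spawn(async move {
        // Exaggerated interval: without cancellation, the next tick
        // would be an hour away.
        let mut interval = tokio::time::interval(Duration::from_secs(3600));
        loop {
            tokio::select! {
                _ = task_token.cancelled() => break,
                _ = interval.tick() => { /* periodic sweep */ }
            }
        }
    });

    cancel.cancel();
    // Completes immediately instead of blocking until the next tick.
    handle.await.unwrap();
}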
@@ -461,11 +508,7 @@ mod tests {
     #[should_panic]
     #[test]
     fn test_invalid_addr() {
-        let pool = Arc::new(Pool::default());
-        let mgr = ChannelManager {
-            pool,
-            ..Default::default()
-        };
+        let mgr = ChannelManager::default();
         let addr = "http://test";

         let _ = mgr.get(addr).unwrap();
@@ -475,7 +518,9 @@ mod tests {
     async fn test_access_count() {
         let mgr = ChannelManager::new();
         // Do not start recycle
-        mgr.channel_recycle_started.store(true, Ordering::Relaxed);
+        mgr.inner
+            .channel_recycle_started
+            .store(true, Ordering::Relaxed);
         let mgr = Arc::new(mgr);
         let addr = "test_uri";
@@ -493,12 +538,12 @@ mod tests {
             join.await.unwrap();
         }

-        assert_eq!(1000, mgr.pool.get_access(addr).unwrap());
+        assert_eq!(1000, mgr.pool().get_access(addr).unwrap());

-        mgr.pool
+        mgr.pool()
             .retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0);

-        assert_eq!(0, mgr.pool.get_access(addr).unwrap());
+        assert_eq!(0, mgr.pool().get_access(addr).unwrap());
     }

     #[test]
@@ -624,4 +669,49 @@ mod tests {
             true
         });
     }
+
+    #[tokio::test]
+    async fn test_pool_release_with_channel_recycle() {
+        let mgr = ChannelManager::new();
+        let pool_holder = mgr.pool().clone();
+
+        // start channel recycle task
+        let addr = "test_addr";
+        let _ = mgr.get(addr);
+
+        let mgr_clone_1 = mgr.clone();
+        let mgr_clone_2 = mgr.clone();
+        assert_eq!(3, Arc::strong_count(mgr.pool()));
+
+        drop(mgr_clone_1);
+        drop(mgr_clone_2);
+        assert_eq!(3, Arc::strong_count(mgr.pool()));
+
+        drop(mgr);
+
+        // wait for the channel recycle task to finish
+        tokio::time::sleep(Duration::from_millis(10)).await;
+        assert_eq!(1, Arc::strong_count(&pool_holder));
+    }
+
+    #[tokio::test]
+    async fn test_pool_release_without_channel_recycle() {
+        let mgr = ChannelManager::new();
+        let pool_holder = mgr.pool().clone();
+
+        let mgr_clone_1 = mgr.clone();
+        let mgr_clone_2 = mgr.clone();
+        assert_eq!(2, Arc::strong_count(mgr.pool()));
+
+        drop(mgr_clone_1);
+        drop(mgr_clone_2);
+        assert_eq!(2, Arc::strong_count(mgr.pool()));
+
+        drop(mgr);
+        assert_eq!(1, Arc::strong_count(&pool_holder));
+    }
 }
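
The two new tests above detect the leak with Arc::strong_count: hold one clone of the pool's Arc outside the manager, drop every manager handle, and assert the count falls back to 1 once the recycle task has exited. A standalone sketch of the same idiom, assuming tokio and tokio-util (state stands in for the pool; the 10 ms sleep mirrors the test's wait for the task to finish):

use std::sync::Arc;
use std::time::Duration;

use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let state = Arc::new(Vec::<u8>::new()); // stands in for the pool
    let cancel = CancellationToken::new();

    let task_state = state.clone(); // strong count is now 2
    let task_token = cancel.clone();
    tokio::spawn(async move {
        task_token.cancelled().await;
        // The task's reference is released when it exits.
        drop(task_state);
    });

    assert_eq!(2, Arc::strong_count(&state));

    cancel.cancel();
    // Give the task a moment to observe cancellation and exit.
    tokio::time::sleep(Duration::from_millis(10)).await;
    assert_eq!(1, Arc::strong_count(&state));
}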