From ee5b7ff3c80afdae4da601d2b1914ce58b0ce853 Mon Sep 17 00:00:00 2001 From: shuiyisong <113876041+shuiyisong@users.noreply.github.com> Date: Thu, 30 Oct 2025 12:26:02 +0800 Subject: [PATCH] chore: unify initialization of channel manager (#7159) * chore: unify initialization of channel manager and extract loading tls Signed-off-by: shuiyisong * chore: fix cr issue Signed-off-by: shuiyisong --------- Signed-off-by: shuiyisong --- src/client/src/client.rs | 7 +- src/client/src/client_manager.rs | 2 +- src/common/frontend/src/selector.rs | 2 +- src/common/grpc/src/channel_manager.rs | 87 ++++++++----------- src/common/grpc/tests/mod.rs | 18 ++-- src/flow/src/batching_mode/frontend_client.rs | 12 ++- src/meta-client/examples/meta_client.rs | 2 +- src/meta-client/src/lib.rs | 6 +- src/meta-srv/src/mocks.rs | 2 +- 9 files changed, 67 insertions(+), 71 deletions(-) diff --git a/src/client/src/client.rs b/src/client/src/client.rs index 1506ac5208..611cce954d 100644 --- a/src/client/src/client.rs +++ b/src/client/src/client.rs @@ -20,7 +20,9 @@ use api::v1::health_check_client::HealthCheckClient; use api::v1::prometheus_gateway_client::PrometheusGatewayClient; use api::v1::region::region_client::RegionClient as PbRegionClient; use arrow_flight::flight_service_client::FlightServiceClient; -use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption}; +use common_grpc::channel_manager::{ + ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config, +}; use parking_lot::RwLock; use snafu::{OptionExt, ResultExt}; use tonic::codec::CompressionEncoding; @@ -94,8 +96,9 @@ impl Client { A: AsRef<[U]>, { let channel_config = ChannelConfig::default().client_tls_config(client_tls); - let channel_manager = ChannelManager::with_tls_config(channel_config) + let tls_config = load_tls_config(channel_config.client_tls.as_ref()) .context(error::CreateTlsChannelSnafu)?; + let channel_manager = ChannelManager::with_config(channel_config, tls_config); Ok(Self::with_manager_and_urls(channel_manager, urls)) } diff --git a/src/client/src/client_manager.rs b/src/client/src/client_manager.rs index 80afd2fb32..edac45a9fe 100644 --- a/src/client/src/client_manager.rs +++ b/src/client/src/client_manager.rs @@ -74,7 +74,7 @@ impl FlownodeManager for NodeClients { impl NodeClients { pub fn new(config: ChannelConfig) -> Self { Self { - channel_manager: ChannelManager::with_config(config), + channel_manager: ChannelManager::with_config(config, None), clients: CacheBuilder::new(1024) .time_to_live(Duration::from_secs(30 * 60)) .time_to_idle(Duration::from_secs(5 * 60)) diff --git a/src/common/frontend/src/selector.rs b/src/common/frontend/src/selector.rs index 4e6cc9566c..f2dc337cc2 100644 --- a/src/common/frontend/src/selector.rs +++ b/src/common/frontend/src/selector.rs @@ -104,7 +104,7 @@ impl MetaClientSelector { let cfg = ChannelConfig::new() .connect_timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(30)); - let channel_manager = ChannelManager::with_config(cfg); + let channel_manager = ChannelManager::with_config(cfg, None); Self { meta_client, channel_manager, diff --git a/src/common/grpc/src/channel_manager.rs b/src/common/grpc/src/channel_manager.rs index cdea89cb86..667b73f5f3 100644 --- a/src/common/grpc/src/channel_manager.rs +++ b/src/common/grpc/src/channel_manager.rs @@ -22,14 +22,14 @@ use dashmap::DashMap; use dashmap::mapref::entry::Entry; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; -use snafu::{OptionExt, ResultExt}; +use snafu::ResultExt; use tokio_util::sync::CancellationToken; use tonic::transport::{ Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri, }; use tower::Service; -use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, InvalidTlsConfigSnafu, Result}; +use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result}; const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60; pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10; @@ -91,57 +91,18 @@ impl ChannelManager { Default::default() } - pub fn with_config(config: ChannelConfig) -> Self { - let inner = Inner::with_config(config); + /// unified with config function that support tls config + /// use [`load_tls_config`] to load tls config from file system + pub fn with_config(config: ChannelConfig, tls_config: Option) -> Self { + let mut inner = Inner::with_config(config.clone()); + if let Some(tls_config) = tls_config { + inner.client_tls_config = Some(tls_config); + } Self { inner: Arc::new(inner), } } - /// Read tls cert and key files and create a ChannelManager with TLS config. - pub fn with_tls_config(config: ChannelConfig) -> Result { - let mut inner = Inner::with_config(config.clone()); - - // setup tls - let path_config = config.client_tls.context(InvalidTlsConfigSnafu { - msg: "no config input", - })?; - - if !path_config.enabled { - // if TLS not enabled, just ignore other tls config - // and not set `client_tls_config` hence not use TLS - return Ok(Self { - inner: Arc::new(inner), - }); - } - - let mut tls_config = ClientTlsConfig::new(); - - if let Some(server_ca) = path_config.server_ca_cert_path { - let server_root_ca_cert = - std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?; - let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert); - tls_config = tls_config.ca_certificate(server_root_ca_cert); - } - - if let (Some(client_cert_path), Some(client_key_path)) = - (&path_config.client_cert_path, &path_config.client_key_path) - { - let client_cert = - std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?; - let client_key = - std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?; - let client_identity = Identity::from_pem(client_cert, client_key); - tls_config = tls_config.identity(client_identity); - } - - inner.client_tls_config = Some(tls_config); - - Ok(Self { - inner: Arc::new(inner), - }) - } - pub fn config(&self) -> &ChannelConfig { &self.inner.config } @@ -287,6 +248,34 @@ impl ChannelManager { } } +pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result> { + let path_config = match tls_option { + Some(path_config) if path_config.enabled => path_config, + _ => return Ok(None), + }; + + let mut tls_config = ClientTlsConfig::new(); + + if let Some(server_ca) = &path_config.server_ca_cert_path { + let server_root_ca_cert = + std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?; + let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert); + tls_config = tls_config.ca_certificate(server_root_ca_cert); + } + + if let (Some(client_cert_path), Some(client_key_path)) = + (&path_config.client_cert_path, &path_config.client_key_path) + { + let client_cert = + std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?; + let client_key = + std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?; + let client_identity = Identity::from_pem(client_cert, client_key); + tls_config = tls_config.identity(client_identity); + } + Ok(Some(tls_config)) +} + #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct ClientTlsOption { /// Whether to enable TLS for client. @@ -659,7 +648,7 @@ mod tests { .http2_adaptive_window(true) .tcp_keepalive(Duration::from_secs(2)) .tcp_nodelay(true); - let mgr = ChannelManager::with_config(config); + let mgr = ChannelManager::with_config(config, None); let res = mgr.build_endpoint("test_addr"); diff --git a/src/common/grpc/tests/mod.rs b/src/common/grpc/tests/mod.rs index d119f22836..a437d21cd9 100644 --- a/src/common/grpc/tests/mod.rs +++ b/src/common/grpc/tests/mod.rs @@ -12,14 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption}; +use common_grpc::channel_manager::{ + ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config, +}; #[tokio::test] async fn test_mtls_config() { // test no config let config = ChannelConfig::new(); - let re = ChannelManager::with_tls_config(config); - assert!(re.is_err()); + let re = load_tls_config(config.client_tls.as_ref()); + assert!(re.is_ok()); + assert!(re.unwrap().is_none()); // test wrong file let config = ChannelConfig::new().client_tls_config(ClientTlsOption { @@ -29,7 +32,7 @@ async fn test_mtls_config() { client_key_path: Some("tests/tls/wrong_client.key".to_string()), }); - let re = ChannelManager::with_tls_config(config); + let re = load_tls_config(config.client_tls.as_ref()); assert!(re.is_err()); // test corrupted file content @@ -40,7 +43,9 @@ async fn test_mtls_config() { client_key_path: Some("tests/tls/corrupted".to_string()), }); - let re = ChannelManager::with_tls_config(config).unwrap(); + let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap(); + let re = ChannelManager::with_config(config, tls_config); + let re = re.get("127.0.0.1:0"); assert!(re.is_err()); @@ -52,7 +57,8 @@ async fn test_mtls_config() { client_key_path: Some("tests/tls/client.key".to_string()), }); - let re = ChannelManager::with_tls_config(config).unwrap(); + let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap(); + let re = ChannelManager::with_config(config, tls_config); let re = re.get("127.0.0.1:0"); let _ = re.unwrap(); } diff --git a/src/flow/src/batching_mode/frontend_client.rs b/src/flow/src/batching_mode/frontend_client.rs index cba8f896d5..e9994b5b14 100644 --- a/src/flow/src/batching_mode/frontend_client.rs +++ b/src/flow/src/batching_mode/frontend_client.rs @@ -23,7 +23,7 @@ use api::v1::query_request::Query; use api::v1::{CreateTableExpr, QueryRequest}; use client::{Client, Database}; use common_error::ext::{BoxedError, ErrorExt}; -use common_grpc::channel_manager::{ChannelConfig, ChannelManager}; +use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_tls_config}; use common_meta::cluster::{NodeInfo, NodeInfoKey, Role}; use common_meta::peer::Peer; use common_meta::rpc::store::RangeRequest; @@ -123,12 +123,10 @@ impl FrontendClient { let cfg = ChannelConfig::new() .connect_timeout(batch_opts.grpc_conn_timeout) .timeout(batch_opts.query_timeout); - if let Some(tls) = &batch_opts.frontend_tls { - let cfg = cfg.client_tls_config(tls.clone()); - ChannelManager::with_tls_config(cfg).context(InvalidClientConfigSnafu)? - } else { - ChannelManager::with_config(cfg) - } + + let tls_config = load_tls_config(batch_opts.frontend_tls.as_ref()) + .context(InvalidClientConfigSnafu)?; + ChannelManager::with_config(cfg, tls_config) }, auth, query, diff --git a/src/meta-client/examples/meta_client.rs b/src/meta-client/examples/meta_client.rs index fb5125224c..175888f170 100644 --- a/src/meta-client/examples/meta_client.rs +++ b/src/meta-client/examples/meta_client.rs @@ -36,7 +36,7 @@ async fn run() { .timeout(Duration::from_secs(3)) .connect_timeout(Duration::from_secs(5)) .tcp_nodelay(true); - let channel_manager = ChannelManager::with_config(config); + let channel_manager = ChannelManager::with_config(config, None); let mut meta_client = MetaClientBuilder::datanode_default_options(id) .channel_manager(channel_manager) .build(); diff --git a/src/meta-client/src/lib.rs b/src/meta-client/src/lib.rs index 47384785e2..5b56b8e181 100644 --- a/src/meta-client/src/lib.rs +++ b/src/meta-client/src/lib.rs @@ -101,7 +101,7 @@ pub async fn create_meta_client( if let MetaClientType::Frontend = client_type { let ddl_config = base_config.clone().timeout(meta_client_options.ddl_timeout); - builder = builder.ddl_channel_manager(ChannelManager::with_config(ddl_config)); + builder = builder.ddl_channel_manager(ChannelManager::with_config(ddl_config, None)); if let Some(plugins) = plugins { let region_follower = plugins.get::(); if let Some(region_follower) = region_follower { @@ -112,8 +112,8 @@ pub async fn create_meta_client( } builder = builder - .channel_manager(ChannelManager::with_config(base_config)) - .heartbeat_channel_manager(ChannelManager::with_config(heartbeat_config)); + .channel_manager(ChannelManager::with_config(base_config, None)) + .heartbeat_channel_manager(ChannelManager::with_config(heartbeat_config, None)); let mut meta_client = builder.build(); diff --git a/src/meta-srv/src/mocks.rs b/src/meta-srv/src/mocks.rs index 6c2f0d3892..c805f8ea1b 100644 --- a/src/meta-srv/src/mocks.rs +++ b/src/meta-srv/src/mocks.rs @@ -134,7 +134,7 @@ pub async fn mock( .timeout(Duration::from_secs(10)) .connect_timeout(Duration::from_secs(10)) .tcp_nodelay(true); - let channel_manager = ChannelManager::with_config(config); + let channel_manager = ChannelManager::with_config(config, None); // Move client to an option so we can _move_ the inner value // on the first attempt to connect. All other attempts will fail.