chore: unify initialization of channel manager (#7159)

* chore: unify initialization of channel manager and extract loading tls

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: fix cr issue

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

---------

Signed-off-by: shuiyisong <xixing.sys@gmail.com>
This commit is contained in:
shuiyisong
2025-10-30 12:26:02 +08:00
committed by GitHub
parent 5d0ef376de
commit ee5b7ff3c8
9 changed files with 67 additions and 71 deletions

View File

@@ -20,7 +20,9 @@ use api::v1::health_check_client::HealthCheckClient;
use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
use api::v1::region::region_client::RegionClient as PbRegionClient;
use arrow_flight::flight_service_client::FlightServiceClient;
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption};
use common_grpc::channel_manager::{
ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
};
use parking_lot::RwLock;
use snafu::{OptionExt, ResultExt};
use tonic::codec::CompressionEncoding;
@@ -94,8 +96,9 @@ impl Client {
A: AsRef<[U]>,
{
let channel_config = ChannelConfig::default().client_tls_config(client_tls);
let channel_manager = ChannelManager::with_tls_config(channel_config)
let tls_config = load_tls_config(channel_config.client_tls.as_ref())
.context(error::CreateTlsChannelSnafu)?;
let channel_manager = ChannelManager::with_config(channel_config, tls_config);
Ok(Self::with_manager_and_urls(channel_manager, urls))
}

View File

@@ -74,7 +74,7 @@ impl FlownodeManager for NodeClients {
impl NodeClients {
pub fn new(config: ChannelConfig) -> Self {
Self {
channel_manager: ChannelManager::with_config(config),
channel_manager: ChannelManager::with_config(config, None),
clients: CacheBuilder::new(1024)
.time_to_live(Duration::from_secs(30 * 60))
.time_to_idle(Duration::from_secs(5 * 60))

View File

@@ -104,7 +104,7 @@ impl MetaClientSelector {
let cfg = ChannelConfig::new()
.connect_timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(30));
let channel_manager = ChannelManager::with_config(cfg);
let channel_manager = ChannelManager::with_config(cfg, None);
Self {
meta_client,
channel_manager,

View File

@@ -22,14 +22,14 @@ use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use snafu::{OptionExt, ResultExt};
use snafu::ResultExt;
use tokio_util::sync::CancellationToken;
use tonic::transport::{
Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
};
use tower::Service;
use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, InvalidTlsConfigSnafu, Result};
use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result};
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
@@ -91,57 +91,18 @@ impl ChannelManager {
Default::default()
}
pub fn with_config(config: ChannelConfig) -> Self {
let inner = Inner::with_config(config);
/// unified with config function that support tls config
/// use [`load_tls_config`] to load tls config from file system
pub fn with_config(config: ChannelConfig, tls_config: Option<ClientTlsConfig>) -> Self {
let mut inner = Inner::with_config(config.clone());
if let Some(tls_config) = tls_config {
inner.client_tls_config = Some(tls_config);
}
Self {
inner: Arc::new(inner),
}
}
/// Read tls cert and key files and create a ChannelManager with TLS config.
pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
let mut inner = Inner::with_config(config.clone());
// setup tls
let path_config = config.client_tls.context(InvalidTlsConfigSnafu {
msg: "no config input",
})?;
if !path_config.enabled {
// if TLS not enabled, just ignore other tls config
// and not set `client_tls_config` hence not use TLS
return Ok(Self {
inner: Arc::new(inner),
});
}
let mut tls_config = ClientTlsConfig::new();
if let Some(server_ca) = path_config.server_ca_cert_path {
let server_root_ca_cert =
std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?;
let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
tls_config = tls_config.ca_certificate(server_root_ca_cert);
}
if let (Some(client_cert_path), Some(client_key_path)) =
(&path_config.client_cert_path, &path_config.client_key_path)
{
let client_cert =
std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?;
let client_key =
std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?;
let client_identity = Identity::from_pem(client_cert, client_key);
tls_config = tls_config.identity(client_identity);
}
inner.client_tls_config = Some(tls_config);
Ok(Self {
inner: Arc::new(inner),
})
}
pub fn config(&self) -> &ChannelConfig {
&self.inner.config
}
@@ -287,6 +248,34 @@ impl ChannelManager {
}
}
pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
let path_config = match tls_option {
Some(path_config) if path_config.enabled => path_config,
_ => return Ok(None),
};
let mut tls_config = ClientTlsConfig::new();
if let Some(server_ca) = &path_config.server_ca_cert_path {
let server_root_ca_cert =
std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?;
let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
tls_config = tls_config.ca_certificate(server_root_ca_cert);
}
if let (Some(client_cert_path), Some(client_key_path)) =
(&path_config.client_cert_path, &path_config.client_key_path)
{
let client_cert =
std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?;
let client_key =
std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?;
let client_identity = Identity::from_pem(client_cert, client_key);
tls_config = tls_config.identity(client_identity);
}
Ok(Some(tls_config))
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientTlsOption {
/// Whether to enable TLS for client.
@@ -659,7 +648,7 @@ mod tests {
.http2_adaptive_window(true)
.tcp_keepalive(Duration::from_secs(2))
.tcp_nodelay(true);
let mgr = ChannelManager::with_config(config);
let mgr = ChannelManager::with_config(config, None);
let res = mgr.build_endpoint("test_addr");

View File

@@ -12,14 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, ClientTlsOption};
use common_grpc::channel_manager::{
ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
};
#[tokio::test]
async fn test_mtls_config() {
// test no config
let config = ChannelConfig::new();
let re = ChannelManager::with_tls_config(config);
assert!(re.is_err());
let re = load_tls_config(config.client_tls.as_ref());
assert!(re.is_ok());
assert!(re.unwrap().is_none());
// test wrong file
let config = ChannelConfig::new().client_tls_config(ClientTlsOption {
@@ -29,7 +32,7 @@ async fn test_mtls_config() {
client_key_path: Some("tests/tls/wrong_client.key".to_string()),
});
let re = ChannelManager::with_tls_config(config);
let re = load_tls_config(config.client_tls.as_ref());
assert!(re.is_err());
// test corrupted file content
@@ -40,7 +43,9 @@ async fn test_mtls_config() {
client_key_path: Some("tests/tls/corrupted".to_string()),
});
let re = ChannelManager::with_tls_config(config).unwrap();
let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap();
let re = ChannelManager::with_config(config, tls_config);
let re = re.get("127.0.0.1:0");
assert!(re.is_err());
@@ -52,7 +57,8 @@ async fn test_mtls_config() {
client_key_path: Some("tests/tls/client.key".to_string()),
});
let re = ChannelManager::with_tls_config(config).unwrap();
let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap();
let re = ChannelManager::with_config(config, tls_config);
let re = re.get("127.0.0.1:0");
let _ = re.unwrap();
}

View File

@@ -23,7 +23,7 @@ use api::v1::query_request::Query;
use api::v1::{CreateTableExpr, QueryRequest};
use client::{Client, Database};
use common_error::ext::{BoxedError, ErrorExt};
use common_grpc::channel_manager::{ChannelConfig, ChannelManager};
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_tls_config};
use common_meta::cluster::{NodeInfo, NodeInfoKey, Role};
use common_meta::peer::Peer;
use common_meta::rpc::store::RangeRequest;
@@ -123,12 +123,10 @@ impl FrontendClient {
let cfg = ChannelConfig::new()
.connect_timeout(batch_opts.grpc_conn_timeout)
.timeout(batch_opts.query_timeout);
if let Some(tls) = &batch_opts.frontend_tls {
let cfg = cfg.client_tls_config(tls.clone());
ChannelManager::with_tls_config(cfg).context(InvalidClientConfigSnafu)?
} else {
ChannelManager::with_config(cfg)
}
let tls_config = load_tls_config(batch_opts.frontend_tls.as_ref())
.context(InvalidClientConfigSnafu)?;
ChannelManager::with_config(cfg, tls_config)
},
auth,
query,

View File

@@ -36,7 +36,7 @@ async fn run() {
.timeout(Duration::from_secs(3))
.connect_timeout(Duration::from_secs(5))
.tcp_nodelay(true);
let channel_manager = ChannelManager::with_config(config);
let channel_manager = ChannelManager::with_config(config, None);
let mut meta_client = MetaClientBuilder::datanode_default_options(id)
.channel_manager(channel_manager)
.build();

View File

@@ -101,7 +101,7 @@ pub async fn create_meta_client(
if let MetaClientType::Frontend = client_type {
let ddl_config = base_config.clone().timeout(meta_client_options.ddl_timeout);
builder = builder.ddl_channel_manager(ChannelManager::with_config(ddl_config));
builder = builder.ddl_channel_manager(ChannelManager::with_config(ddl_config, None));
if let Some(plugins) = plugins {
let region_follower = plugins.get::<RegionFollowerClientRef>();
if let Some(region_follower) = region_follower {
@@ -112,8 +112,8 @@ pub async fn create_meta_client(
}
builder = builder
.channel_manager(ChannelManager::with_config(base_config))
.heartbeat_channel_manager(ChannelManager::with_config(heartbeat_config));
.channel_manager(ChannelManager::with_config(base_config, None))
.heartbeat_channel_manager(ChannelManager::with_config(heartbeat_config, None));
let mut meta_client = builder.build();

View File

@@ -134,7 +134,7 @@ pub async fn mock(
.timeout(Duration::from_secs(10))
.connect_timeout(Duration::from_secs(10))
.tcp_nodelay(true);
let channel_manager = ChannelManager::with_config(config);
let channel_manager = ChannelManager::with_config(config, None);
// Move client to an option so we can _move_ the inner value
// on the first attempt to connect. All other attempts will fail.