From 9f4902b10a63499bec9afcadb0c267aa1d107e76 Mon Sep 17 00:00:00 2001 From: shuiyisong <113876041+shuiyisong@users.noreply.github.com> Date: Mon, 24 Nov 2025 19:52:11 +0800 Subject: [PATCH] feat: reloadable tls client config (#7230) * feat: add ReloadableClientTlsConfig Signed-off-by: shuiyisong * refactor: merge tls option with the reloadable Signed-off-by: shuiyisong * chore: rename function Signed-off-by: shuiyisong * chore: update comment Signed-off-by: shuiyisong * chore: extract tls loader Signed-off-by: shuiyisong * chore: minor comment update Signed-off-by: shuiyisong * chore: add serde default to watch field Signed-off-by: shuiyisong * chore: minor update Signed-off-by: shuiyisong * chore: add log Signed-off-by: shuiyisong * fix: add error log Signed-off-by: shuiyisong --------- Signed-off-by: shuiyisong --- Cargo.lock | 2 + src/client/src/client.rs | 8 +- src/common/grpc/Cargo.toml | 2 + src/common/grpc/src/channel_manager.rs | 103 +++++++++-- src/common/grpc/src/error.rs | 10 ++ src/common/grpc/src/lib.rs | 1 + src/common/grpc/src/reloadable_tls.rs | 163 ++++++++++++++++++ src/common/grpc/tests/mod.rs | 115 +++++++++++- src/flow/src/batching_mode/frontend_client.rs | 4 +- src/frontend/src/server.rs | 6 +- src/servers/src/mysql/server.rs | 2 +- src/servers/src/postgres/server.rs | 2 +- src/servers/src/tls.rs | 115 +++--------- tests-integration/tests/grpc.rs | 1 + 14 files changed, 416 insertions(+), 118 deletions(-) create mode 100644 src/common/grpc/src/reloadable_tls.rs diff --git a/Cargo.lock b/Cargo.lock index d32a94c3a2..440ee75945 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2265,11 +2265,13 @@ dependencies = [ "hyper 1.6.0", "hyper-util", "lazy_static", + "notify", "prost 0.13.5", "rand 0.9.1", "serde", "serde_json", "snafu 0.8.6", + "tempfile", "tokio", "tokio-util", "tonic 0.13.1", diff --git a/src/client/src/client.rs b/src/client/src/client.rs index 611cce954d..39cb5c30aa 100644 --- a/src/client/src/client.rs +++ b/src/client/src/client.rs @@ -21,7 +21,7 @@ use api::v1::prometheus_gateway_client::PrometheusGatewayClient; use api::v1::region::region_client::RegionClient as PbRegionClient; use arrow_flight::flight_service_client::FlightServiceClient; use common_grpc::channel_manager::{ - ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config, + ChannelConfig, ChannelManager, ClientTlsOption, load_client_tls_config, }; use parking_lot::RwLock; use snafu::{OptionExt, ResultExt}; @@ -95,9 +95,9 @@ impl Client { U: AsRef, A: AsRef<[U]>, { - let channel_config = ChannelConfig::default().client_tls_config(client_tls); - let tls_config = load_tls_config(channel_config.client_tls.as_ref()) - .context(error::CreateTlsChannelSnafu)?; + let channel_config = ChannelConfig::default().client_tls_config(client_tls.clone()); + let tls_config = + load_client_tls_config(Some(client_tls)).context(error::CreateTlsChannelSnafu)?; let channel_manager = ChannelManager::with_config(channel_config, tls_config); Ok(Self::with_manager_and_urls(channel_manager, urls)) } diff --git a/src/common/grpc/Cargo.toml b/src/common/grpc/Cargo.toml index 1684d0b297..9978791a7a 100644 --- a/src/common/grpc/Cargo.toml +++ b/src/common/grpc/Cargo.toml @@ -23,6 +23,7 @@ datatypes.workspace = true flatbuffers = "25.2" hyper.workspace = true lazy_static.workspace = true +notify.workspace = true prost.workspace = true serde.workspace = true serde_json.workspace = true @@ -37,6 +38,7 @@ vec1 = "1.12" criterion = "0.4" hyper-util = { workspace = true, features = ["tokio"] } rand.workspace = true +tempfile.workspace = true [[bench]] name = "bench_main" diff --git a/src/common/grpc/src/channel_manager.rs b/src/common/grpc/src/channel_manager.rs index 667b73f5f3..a60604da94 100644 --- a/src/common/grpc/src/channel_manager.rs +++ b/src/common/grpc/src/channel_manager.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::path::Path; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::time::Duration; @@ -30,6 +31,7 @@ use tonic::transport::{ use tower::Service; use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result}; +use crate::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader, maybe_watch_tls_config}; const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60; pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10; @@ -50,7 +52,7 @@ pub struct ChannelManager { struct Inner { id: u64, config: ChannelConfig, - client_tls_config: Option, + reloadable_client_tls_config: Option>, pool: Arc, channel_recycle_started: AtomicBool, cancel: CancellationToken, @@ -78,7 +80,7 @@ impl Inner { Self { id, config, - client_tls_config: None, + reloadable_client_tls_config: None, pool, channel_recycle_started: AtomicBool::new(false), cancel, @@ -91,13 +93,17 @@ impl ChannelManager { Default::default() } - /// unified with config function that support tls config - /// use [`load_tls_config`] to load tls config from file system - pub fn with_config(config: ChannelConfig, tls_config: Option) -> Self { + /// Create a ChannelManager with configuration and optional TLS config + /// + /// Use [`load_client_tls_config`] to create TLS configuration from `ClientTlsOption`. + /// The TLS config supports both static (watch disabled) and dynamic reloading (watch enabled). + /// If you want to use dynamic reloading, please **manually** invoke [`maybe_watch_client_tls_config`] after this method. + pub fn with_config( + config: ChannelConfig, + reloadable_tls_config: Option>, + ) -> Self { let mut inner = Inner::with_config(config.clone()); - if let Some(tls_config) = tls_config { - inner.client_tls_config = Some(tls_config); - } + inner.reloadable_client_tls_config = reloadable_tls_config; Self { inner: Arc::new(inner), } @@ -172,8 +178,21 @@ impl ChannelManager { self.pool().retain_channel(f); } + /// Clear all channels to force reconnection. + /// This should be called when TLS configuration changes to ensure new connections use updated certificates. + pub fn clear_all_channels(&self) { + self.pool().retain_channel(|_, _| false); + } + fn build_endpoint(&self, addr: &str) -> Result { - let http_prefix = if self.inner.client_tls_config.is_some() { + // Get the latest TLS config from reloadable config (which handles both static and dynamic cases) + let tls_config = self + .inner + .reloadable_client_tls_config + .as_ref() + .and_then(|c| c.get_config()); + + let http_prefix = if tls_config.is_some() { "https" } else { "http" @@ -212,9 +231,9 @@ impl ChannelManager { if let Some(enabled) = self.config().http2_adaptive_window { endpoint = endpoint.http2_adaptive_window(enabled); } - if let Some(tls_config) = &self.inner.client_tls_config { + if let Some(tls_config) = tls_config { endpoint = endpoint - .tls_config(tls_config.clone()) + .tls_config(tls_config) .context(CreateChannelSnafu { addr })?; } @@ -248,7 +267,7 @@ impl ChannelManager { } } -pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result> { +fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result> { let path_config = match tls_option { Some(path_config) if path_config.enabled => path_config, _ => return Ok(None), @@ -276,13 +295,69 @@ pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result for ClientTlsOption { + type Error = crate::error::Error; + + fn load(&self) -> Result> { + load_tls_config(Some(self)) + } + + fn watch_paths(&self) -> Vec<&Path> { + let mut paths = Vec::new(); + if let Some(cert_path) = &self.client_cert_path { + paths.push(Path::new(cert_path.as_str())); + } + if let Some(key_path) = &self.client_key_path { + paths.push(Path::new(key_path.as_str())); + } + if let Some(ca_path) = &self.server_ca_cert_path { + paths.push(Path::new(ca_path.as_str())); + } + paths + } + + fn watch_enabled(&self) -> bool { + self.enabled && self.watch + } +} + +/// Type alias for client-side reloadable TLS config +pub type ReloadableClientTlsConfig = ReloadableTlsConfig; + +/// Load client TLS configuration from `ClientTlsOption` and return a `ReloadableClientTlsConfig`. +/// This is the primary way to create TLS configuration for the ChannelManager. +pub fn load_client_tls_config( + tls_option: Option, +) -> Result>> { + match tls_option { + Some(option) if option.enabled => { + let reloadable = ReloadableClientTlsConfig::try_new(option)?; + Ok(Some(Arc::new(reloadable))) + } + _ => Ok(None), + } +} + +pub fn maybe_watch_client_tls_config( + client_tls_config: Arc, + channel_manager: ChannelManager, +) -> Result<()> { + maybe_watch_tls_config(client_tls_config, move || { + // Clear all existing channels to force reconnection with new certificates + channel_manager.clear_all_channels(); + info!("Cleared all existing channels to use new TLS certificates."); + }) +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct ClientTlsOption { /// Whether to enable TLS for client. pub enabled: bool, pub server_ca_cert_path: Option, pub client_cert_path: Option, pub client_key_path: Option, + #[serde(default)] + pub watch: bool, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -602,6 +677,7 @@ mod tests { server_ca_cert_path: Some("some_server_path".to_string()), client_cert_path: Some("some_cert_path".to_string()), client_key_path: Some("some_key_path".to_string()), + watch: false, }); assert_eq!( @@ -623,6 +699,7 @@ mod tests { server_ca_cert_path: Some("some_server_path".to_string()), client_cert_path: Some("some_cert_path".to_string()), client_key_path: Some("some_key_path".to_string()), + watch: false, }), max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE, max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE, diff --git a/src/common/grpc/src/error.rs b/src/common/grpc/src/error.rs index 147ff70c07..4f9b8e92dd 100644 --- a/src/common/grpc/src/error.rs +++ b/src/common/grpc/src/error.rs @@ -38,6 +38,15 @@ pub enum Error { location: Location, }, + #[snafu(display("Failed to watch config file path: {}", path))] + FileWatch { + path: String, + #[snafu(source)] + error: notify::Error, + #[snafu(implicit)] + location: Location, + }, + #[snafu(display( "Write type mismatch, column name: {}, expected: {}, actual: {}", column_name, @@ -108,6 +117,7 @@ impl ErrorExt for Error { match self { Error::InvalidTlsConfig { .. } | Error::InvalidConfigFilePath { .. } + | Error::FileWatch { .. } | Error::TypeMismatch { .. } | Error::InvalidFlightData { .. } | Error::NotSupported { .. } => StatusCode::InvalidArguments, diff --git a/src/common/grpc/src/lib.rs b/src/common/grpc/src/lib.rs index 287644b529..8527dd079b 100644 --- a/src/common/grpc/src/lib.rs +++ b/src/common/grpc/src/lib.rs @@ -16,6 +16,7 @@ pub mod channel_manager; pub mod error; pub mod flight; pub mod precision; +pub mod reloadable_tls; pub mod select; pub use arrow_flight::FlightData; diff --git a/src/common/grpc/src/reloadable_tls.rs b/src/common/grpc/src/reloadable_tls.rs new file mode 100644 index 0000000000..c1bd3aca52 --- /dev/null +++ b/src/common/grpc/src/reloadable_tls.rs @@ -0,0 +1,163 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::path::Path; +use std::result::Result as StdResult; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::mpsc::channel; +use std::sync::{Arc, RwLock}; + +use common_telemetry::{error, info}; +use notify::{EventKind, RecursiveMode, Watcher}; +use snafu::ResultExt; + +use crate::error::{FileWatchSnafu, Result}; + +/// A trait for loading TLS configuration from an option type +pub trait TlsConfigLoader { + type Error; + + /// Load the TLS configuration + fn load(&self) -> StdResult, Self::Error>; + + /// Get paths to certificate files for watching + fn watch_paths(&self) -> Vec<&Path>; + + /// Check if watching is enabled + fn watch_enabled(&self) -> bool; +} + +/// A mutable container for TLS config +/// +/// This struct allows dynamic reloading of certificates and keys. +/// It's generic over the config type (e.g., ServerConfig, ClientTlsConfig) +/// and the option type (e.g., TlsOption, ClientTlsOption). +#[derive(Debug)] +pub struct ReloadableTlsConfig +where + O: TlsConfigLoader, +{ + tls_option: O, + config: RwLock>, + version: AtomicUsize, +} + +impl ReloadableTlsConfig +where + O: TlsConfigLoader, +{ + /// Create config by loading configuration from the option type + pub fn try_new(tls_option: O) -> StdResult { + let config = tls_option.load()?; + Ok(Self { + tls_option, + config: RwLock::new(config), + version: AtomicUsize::new(0), + }) + } + + /// Reread certificates and keys from file system. + pub fn reload(&self) -> StdResult<(), O::Error> { + let config = self.tls_option.load()?; + *self.config.write().unwrap() = config; + self.version.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + + /// Get the config held by this container + pub fn get_config(&self) -> Option + where + T: Clone, + { + self.config.read().unwrap().clone() + } + + /// Get associated option + pub fn get_tls_option(&self) -> &O { + &self.tls_option + } + + /// Get version of current config + /// + /// this version will auto increase when config get reloaded. + pub fn get_version(&self) -> usize { + self.version.load(Ordering::Relaxed) + } +} + +/// Watch TLS configuration files for changes and reload automatically +/// +/// This is a generic function that works with any ReloadableTlsConfig. +/// When changes are detected, it calls the provided callback after reloading. +/// +/// T: the original TLS config +/// O: the compiled TLS option +/// F: the hook function to be called after reloading +/// E: the error type for the loading operation +pub fn maybe_watch_tls_config( + tls_config: Arc>, + on_reload: F, +) -> Result<()> +where + T: Send + Sync + 'static, + O: TlsConfigLoader + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, + F: Fn() + Send + 'static, +{ + if !tls_config.get_tls_option().watch_enabled() { + return Ok(()); + } + + let tls_config_for_watcher = tls_config.clone(); + + let (tx, rx) = channel::>(); + let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "" })?; + + // Watch all paths returned by the TlsConfigLoader + for path in tls_config.get_tls_option().watch_paths() { + watcher + .watch(path, RecursiveMode::NonRecursive) + .with_context(|_| FileWatchSnafu { + path: path.display().to_string(), + })?; + } + + info!("Spawning background task for watching TLS cert/key file changes"); + std::thread::spawn(move || { + let _watcher = watcher; + loop { + match rx.recv() { + Ok(Ok(event)) => { + if let EventKind::Modify(_) | EventKind::Create(_) = event.kind { + info!("Detected TLS cert/key file change: {:?}", event); + if let Err(err) = tls_config_for_watcher.reload() { + error!("Failed to reload TLS config: {}", err); + } else { + info!("Reloaded TLS cert/key file successfully."); + on_reload(); + } + } + } + Ok(Err(err)) => { + error!("Failed to watch TLS cert/key file: {}", err); + } + Err(err) => { + error!("TLS cert/key file watcher channel closed: {}", err); + } + } + } + }); + + Ok(()) +} diff --git a/src/common/grpc/tests/mod.rs b/src/common/grpc/tests/mod.rs index a437d21cd9..93188e35fc 100644 --- a/src/common/grpc/tests/mod.rs +++ b/src/common/grpc/tests/mod.rs @@ -13,14 +13,15 @@ // limitations under the License. use common_grpc::channel_manager::{ - ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config, + ChannelConfig, ChannelManager, ClientTlsOption, load_client_tls_config, + maybe_watch_client_tls_config, }; #[tokio::test] async fn test_mtls_config() { // test no config let config = ChannelConfig::new(); - let re = load_tls_config(config.client_tls.as_ref()); + let re = load_client_tls_config(config.client_tls.clone()); assert!(re.is_ok()); assert!(re.unwrap().is_none()); @@ -30,9 +31,10 @@ async fn test_mtls_config() { server_ca_cert_path: Some("tests/tls/wrong_ca.pem".to_string()), client_cert_path: Some("tests/tls/wrong_client.pem".to_string()), client_key_path: Some("tests/tls/wrong_client.key".to_string()), + watch: false, }); - let re = load_tls_config(config.client_tls.as_ref()); + let re = load_client_tls_config(config.client_tls.clone()); assert!(re.is_err()); // test corrupted file content @@ -41,9 +43,10 @@ async fn test_mtls_config() { server_ca_cert_path: Some("tests/tls/ca.pem".to_string()), client_cert_path: Some("tests/tls/client.pem".to_string()), client_key_path: Some("tests/tls/corrupted".to_string()), + watch: false, }); - let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap(); + let tls_config = load_client_tls_config(config.client_tls.clone()).unwrap(); let re = ChannelManager::with_config(config, tls_config); let re = re.get("127.0.0.1:0"); @@ -55,10 +58,112 @@ async fn test_mtls_config() { server_ca_cert_path: Some("tests/tls/ca.pem".to_string()), client_cert_path: Some("tests/tls/client.pem".to_string()), client_key_path: Some("tests/tls/client.key".to_string()), + watch: false, }); - let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap(); + let tls_config = load_client_tls_config(config.client_tls.clone()).unwrap(); let re = ChannelManager::with_config(config, tls_config); let re = re.get("127.0.0.1:0"); let _ = re.unwrap(); } + +#[tokio::test] +async fn test_reloadable_client_tls_config() { + common_telemetry::init_default_ut_logging(); + + let dir = tempfile::tempdir().unwrap(); + let cert_path = dir.path().join("client.pem"); + let key_path = dir.path().join("client.key"); + + std::fs::copy("tests/tls/client.pem", &cert_path).expect("failed to copy cert to tmpdir"); + std::fs::copy("tests/tls/client.key", &key_path).expect("failed to copy key to tmpdir"); + + assert!(std::fs::exists(&cert_path).unwrap()); + assert!(std::fs::exists(&key_path).unwrap()); + + let client_tls_option = ClientTlsOption { + enabled: true, + server_ca_cert_path: Some("tests/tls/ca.pem".to_string()), + client_cert_path: Some( + cert_path + .clone() + .into_os_string() + .into_string() + .expect("failed to convert path to string"), + ), + client_key_path: Some( + key_path + .clone() + .into_os_string() + .into_string() + .expect("failed to convert path to string"), + ), + watch: true, + }; + + let reloadable_config = load_client_tls_config(Some(client_tls_option)) + .expect("failed to load tls config") + .expect("tls config should be present"); + + let config = ChannelConfig::new(); + let manager = ChannelManager::with_config(config, Some(reloadable_config.clone())); + + maybe_watch_client_tls_config(reloadable_config.clone(), manager.clone()) + .expect("failed to watch client config"); + + assert_eq!(0, reloadable_config.get_version()); + assert!(reloadable_config.get_config().is_some()); + + // Create a channel to verify it gets cleared on reload + let _ = manager.get("127.0.0.1:0").expect("failed to get channel"); + + // Simulate file change by copying a different key file + let tmp_file = key_path.with_extension("tmp"); + std::fs::copy("tests/tls/server.key", &tmp_file).expect("Failed to copy temp key file"); + std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file"); + + const MAX_RETRIES: usize = 30; + let mut retries = 0; + let mut version_updated = false; + + while retries < MAX_RETRIES { + if reloadable_config.get_version() > 0 { + version_updated = true; + break; + } + std::thread::sleep(std::time::Duration::from_millis(100)); + retries += 1; + } + + assert!(version_updated, "TLS config did not reload in time"); + assert!(reloadable_config.get_version() > 0); + assert!(reloadable_config.get_config().is_some()); +} + +#[tokio::test] +async fn test_channel_manager_with_reloadable_tls() { + common_telemetry::init_default_ut_logging(); + + let client_tls_option = ClientTlsOption { + enabled: true, + server_ca_cert_path: Some("tests/tls/ca.pem".to_string()), + client_cert_path: Some("tests/tls/client.pem".to_string()), + client_key_path: Some("tests/tls/client.key".to_string()), + watch: false, + }; + + let reloadable_config = load_client_tls_config(Some(client_tls_option)) + .expect("failed to load tls config") + .expect("tls config should be present"); + + let config = ChannelConfig::new(); + let manager = ChannelManager::with_config(config, Some(reloadable_config.clone())); + + // Test that we can get a channel + let channel = manager.get("127.0.0.1:0"); + assert!(channel.is_ok()); + + // Test that config is properly set + assert_eq!(0, reloadable_config.get_version()); + assert!(reloadable_config.get_config().is_some()); +} diff --git a/src/flow/src/batching_mode/frontend_client.rs b/src/flow/src/batching_mode/frontend_client.rs index e9994b5b14..d79c3033e3 100644 --- a/src/flow/src/batching_mode/frontend_client.rs +++ b/src/flow/src/batching_mode/frontend_client.rs @@ -23,7 +23,7 @@ use api::v1::query_request::Query; use api::v1::{CreateTableExpr, QueryRequest}; use client::{Client, Database}; use common_error::ext::{BoxedError, ErrorExt}; -use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_tls_config}; +use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_client_tls_config}; use common_meta::cluster::{NodeInfo, NodeInfoKey, Role}; use common_meta::peer::Peer; use common_meta::rpc::store::RangeRequest; @@ -124,7 +124,7 @@ impl FrontendClient { .connect_timeout(batch_opts.grpc_conn_timeout) .timeout(batch_opts.query_timeout); - let tls_config = load_tls_config(batch_opts.frontend_tls.as_ref()) + let tls_config = load_client_tls_config(batch_opts.frontend_tls.clone()) .context(InvalidClientConfigSnafu)?; ChannelManager::with_config(cfg, tls_config) }, diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 6c19109ab2..d70aa3dd49 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -36,7 +36,7 @@ use servers::postgres::PostgresServer; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter; use servers::query_handler::sql::ServerSqlQueryHandlerAdapter; use servers::server::{Server, ServerHandlers}; -use servers::tls::{ReloadableTlsServerConfig, maybe_watch_tls_config}; +use servers::tls::{ReloadableTlsServerConfig, maybe_watch_server_tls_config}; use snafu::ResultExt; use crate::error::{self, Result, StartServerSnafu, TomlFormatSnafu}; @@ -258,7 +258,7 @@ where ); // will not watch if watch is disabled in tls option - maybe_watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?; + maybe_watch_server_tls_config(tls_server_config.clone()).context(StartServerSnafu)?; let mysql_server = MysqlServer::create_server( common_runtime::global_runtime(), @@ -287,7 +287,7 @@ where ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?, ); - maybe_watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?; + maybe_watch_server_tls_config(tls_server_config.clone()).context(StartServerSnafu)?; let pg_server = Box::new(PostgresServer::new( ServerSqlQueryHandlerAdapter::arc(instance.clone()), diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index c27d3ebbda..bda027ca55 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -99,7 +99,7 @@ impl MysqlSpawnConfig { } fn tls(&self) -> Option> { - self.tls.get_server_config() + self.tls.get_config() } } diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 3c7a711780..3478a6da78 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -80,7 +80,7 @@ impl PostgresServer { let process_manager = self.process_manager.clone(); accepting_stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); - let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from); + let tls_acceptor = tls_server_config.get_config().map(TlsAcceptor::from); let handler_maker = handler_maker.clone(); let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0); diff --git a/src/servers/src/tls.rs b/src/servers/src/tls.rs index 115e3d39c6..f4afc477b7 100644 --- a/src/servers/src/tls.rs +++ b/src/servers/src/tls.rs @@ -15,12 +15,10 @@ use std::fs::File; use std::io::{BufReader, Error as IoError, ErrorKind}; use std::path::Path; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::mpsc::channel; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; -use common_telemetry::{error, info}; -use notify::{EventKind, RecursiveMode, Watcher}; +use common_grpc::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader}; +use common_telemetry::error; use rustls::ServerConfig; use rustls_pemfile::{Item, certs, read_one}; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; @@ -28,7 +26,7 @@ use serde::{Deserialize, Serialize}; use snafu::ResultExt; use strum::EnumString; -use crate::error::{FileWatchSnafu, InternalIoSnafu, Result}; +use crate::error::{InternalIoSnafu, Result}; /// TlsMode is used for Mysql and Postgres server start up. #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)] @@ -149,96 +147,34 @@ impl TlsOption { } } -/// A mutable container for TLS server config -/// -/// This struct allows dynamic reloading of server certificates and keys -pub struct ReloadableTlsServerConfig { - tls_option: TlsOption, - config: RwLock>>, - version: AtomicUsize, -} +impl TlsConfigLoader> for TlsOption { + type Error = crate::error::Error; -impl ReloadableTlsServerConfig { - /// Create server config by loading configuration from `TlsOption` - pub fn try_new(tls_option: TlsOption) -> Result { - let server_config = tls_option.setup()?; - Ok(Self { - tls_option, - config: RwLock::new(server_config.map(Arc::new)), - version: AtomicUsize::new(0), - }) + fn load(&self) -> Result>> { + Ok(self.setup()?.map(Arc::new)) } - /// Reread server certificates and keys from file system. - pub fn reload(&self) -> Result<()> { - let server_config = self.tls_option.setup()?; - *self.config.write().unwrap() = server_config.map(Arc::new); - self.version.fetch_add(1, Ordering::Relaxed); - Ok(()) + fn watch_paths(&self) -> Vec<&Path> { + vec![self.cert_path(), self.key_path()] } - /// Get the server config hold by this container - pub fn get_server_config(&self) -> Option> { - self.config.read().unwrap().clone() - } - - /// Get associated `TlsOption` - pub fn get_tls_option(&self) -> &TlsOption { - &self.tls_option - } - - /// Get version of current config - /// - /// this version will auto increase when server config get reloaded. - pub fn get_version(&self) -> usize { - self.version.load(Ordering::Relaxed) + fn watch_enabled(&self) -> bool { + self.mode != TlsMode::Disable && self.watch } } -pub fn maybe_watch_tls_config(tls_server_config: Arc) -> Result<()> { - if !tls_server_config.get_tls_option().watch_enabled() { - return Ok(()); - } +/// Type alias for server-side reloadable TLS config +pub type ReloadableTlsServerConfig = ReloadableTlsConfig, TlsOption>; - let tls_server_config_for_watcher = tls_server_config.clone(); - - let (tx, rx) = channel::>(); - let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "" })?; - - let cert_path = tls_server_config.get_tls_option().cert_path(); - watcher - .watch(cert_path, RecursiveMode::NonRecursive) - .with_context(|_| FileWatchSnafu { - path: cert_path.display().to_string(), - })?; - - let key_path = tls_server_config.get_tls_option().key_path(); - watcher - .watch(key_path, RecursiveMode::NonRecursive) - .with_context(|_| FileWatchSnafu { - path: key_path.display().to_string(), - })?; - - std::thread::spawn(move || { - let _watcher = watcher; - while let Ok(res) = rx.recv() { - if let Ok(event) = res { - match event.kind { - EventKind::Modify(_) | EventKind::Create(_) => { - info!("Detected TLS cert/key file change: {:?}", event); - if let Err(err) = tls_server_config_for_watcher.reload() { - error!(err; "Failed to reload TLS server config"); - } else { - info!("Reloaded TLS cert/key file successfully."); - } - } - _ => {} - } - } +/// Convenience function for watching server TLS configuration +pub fn maybe_watch_server_tls_config( + tls_server_config: Arc, +) -> Result<()> { + common_grpc::reloadable_tls::maybe_watch_tls_config(tls_server_config, || {}).map_err(|e| { + crate::error::Error::Internal { + err_msg: format!("Failed to watch TLS config: {}", e), } - }); - - Ok(()) + }) } #[cfg(test)] @@ -434,10 +370,11 @@ mod tests { let server_config = Arc::new( ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"), ); - maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config"); + maybe_watch_server_tls_config(server_config.clone()) + .expect("failed to watch server config"); assert_eq!(0, server_config.get_version()); - assert!(server_config.get_server_config().is_some()); + assert!(server_config.get_config().is_some()); let tmp_file = key_path.with_extension("tmp"); std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file) @@ -459,6 +396,6 @@ mod tests { assert!(version_updated, "TLS config did not reload in time"); assert!(server_config.get_version() > 0); - assert!(server_config.get_server_config().is_some()); + assert!(server_config.get_config().is_some()); } } diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index f02d88f45b..0b849f5c97 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -971,6 +971,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { server_ca_cert_path: Some(ca_path), client_cert_path: Some(client_cert_path), client_key_path: Some(client_key_path), + watch: false, }; { let grpc_client =