feat: reloadable tls client config (#7230)

* feat: add ReloadableClientTlsConfig

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* refactor: merge tls option with the reloadable

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: rename function

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: update comment

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: extract tls loader

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: minor comment update

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: add serde default to watch field

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: minor update

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* chore: add log

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* fix: add error log

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

---------

Signed-off-by: shuiyisong <xixing.sys@gmail.com>
This commit is contained in:
shuiyisong
2025-11-24 19:52:11 +08:00
committed by GitHub
parent b32ca3ad86
commit 9f4902b10a
14 changed files with 416 additions and 118 deletions

2
Cargo.lock generated
View File

@@ -2265,11 +2265,13 @@ dependencies = [
"hyper 1.6.0",
"hyper-util",
"lazy_static",
"notify",
"prost 0.13.5",
"rand 0.9.1",
"serde",
"serde_json",
"snafu 0.8.6",
"tempfile",
"tokio",
"tokio-util",
"tonic 0.13.1",

View File

@@ -21,7 +21,7 @@ use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
use api::v1::region::region_client::RegionClient as PbRegionClient;
use arrow_flight::flight_service_client::FlightServiceClient;
use common_grpc::channel_manager::{
ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
ChannelConfig, ChannelManager, ClientTlsOption, load_client_tls_config,
};
use parking_lot::RwLock;
use snafu::{OptionExt, ResultExt};
@@ -95,9 +95,9 @@ impl Client {
U: AsRef<str>,
A: AsRef<[U]>,
{
let channel_config = ChannelConfig::default().client_tls_config(client_tls);
let tls_config = load_tls_config(channel_config.client_tls.as_ref())
.context(error::CreateTlsChannelSnafu)?;
let channel_config = ChannelConfig::default().client_tls_config(client_tls.clone());
let tls_config =
load_client_tls_config(Some(client_tls)).context(error::CreateTlsChannelSnafu)?;
let channel_manager = ChannelManager::with_config(channel_config, tls_config);
Ok(Self::with_manager_and_urls(channel_manager, urls))
}

View File

@@ -23,6 +23,7 @@ datatypes.workspace = true
flatbuffers = "25.2"
hyper.workspace = true
lazy_static.workspace = true
notify.workspace = true
prost.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -37,6 +38,7 @@ vec1 = "1.12"
criterion = "0.4"
hyper-util = { workspace = true, features = ["tokio"] }
rand.workspace = true
tempfile.workspace = true
[[bench]]
name = "bench_main"

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
@@ -30,6 +31,7 @@ use tonic::transport::{
use tower::Service;
use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result};
use crate::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader, maybe_watch_tls_config};
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
@@ -50,7 +52,7 @@ pub struct ChannelManager {
struct Inner {
id: u64,
config: ChannelConfig,
client_tls_config: Option<ClientTlsConfig>,
reloadable_client_tls_config: Option<Arc<ReloadableClientTlsConfig>>,
pool: Arc<Pool>,
channel_recycle_started: AtomicBool,
cancel: CancellationToken,
@@ -78,7 +80,7 @@ impl Inner {
Self {
id,
config,
client_tls_config: None,
reloadable_client_tls_config: None,
pool,
channel_recycle_started: AtomicBool::new(false),
cancel,
@@ -91,13 +93,17 @@ impl ChannelManager {
Default::default()
}
/// unified with config function that support tls config
/// use [`load_tls_config`] to load tls config from file system
pub fn with_config(config: ChannelConfig, tls_config: Option<ClientTlsConfig>) -> Self {
/// Create a ChannelManager with configuration and optional TLS config
///
/// Use [`load_client_tls_config`] to create TLS configuration from `ClientTlsOption`.
/// The TLS config supports both static (watch disabled) and dynamic reloading (watch enabled).
/// If you want to use dynamic reloading, please **manually** invoke [`maybe_watch_client_tls_config`] after this method.
pub fn with_config(
config: ChannelConfig,
reloadable_tls_config: Option<Arc<ReloadableClientTlsConfig>>,
) -> Self {
let mut inner = Inner::with_config(config.clone());
if let Some(tls_config) = tls_config {
inner.client_tls_config = Some(tls_config);
}
inner.reloadable_client_tls_config = reloadable_tls_config;
Self {
inner: Arc::new(inner),
}
@@ -172,8 +178,21 @@ impl ChannelManager {
self.pool().retain_channel(f);
}
/// Clear all channels to force reconnection.
/// This should be called when TLS configuration changes to ensure new connections use updated certificates.
pub fn clear_all_channels(&self) {
self.pool().retain_channel(|_, _| false);
}
fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
let http_prefix = if self.inner.client_tls_config.is_some() {
// Get the latest TLS config from reloadable config (which handles both static and dynamic cases)
let tls_config = self
.inner
.reloadable_client_tls_config
.as_ref()
.and_then(|c| c.get_config());
let http_prefix = if tls_config.is_some() {
"https"
} else {
"http"
@@ -212,9 +231,9 @@ impl ChannelManager {
if let Some(enabled) = self.config().http2_adaptive_window {
endpoint = endpoint.http2_adaptive_window(enabled);
}
if let Some(tls_config) = &self.inner.client_tls_config {
if let Some(tls_config) = tls_config {
endpoint = endpoint
.tls_config(tls_config.clone())
.tls_config(tls_config)
.context(CreateChannelSnafu { addr })?;
}
@@ -248,7 +267,7 @@ impl ChannelManager {
}
}
pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
let path_config = match tls_option {
Some(path_config) if path_config.enabled => path_config,
_ => return Ok(None),
@@ -276,13 +295,69 @@ pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<Cl
Ok(Some(tls_config))
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
impl TlsConfigLoader<ClientTlsConfig> for ClientTlsOption {
type Error = crate::error::Error;
fn load(&self) -> Result<Option<ClientTlsConfig>> {
load_tls_config(Some(self))
}
fn watch_paths(&self) -> Vec<&Path> {
let mut paths = Vec::new();
if let Some(cert_path) = &self.client_cert_path {
paths.push(Path::new(cert_path.as_str()));
}
if let Some(key_path) = &self.client_key_path {
paths.push(Path::new(key_path.as_str()));
}
if let Some(ca_path) = &self.server_ca_cert_path {
paths.push(Path::new(ca_path.as_str()));
}
paths
}
fn watch_enabled(&self) -> bool {
self.enabled && self.watch
}
}
/// Type alias for client-side reloadable TLS config
pub type ReloadableClientTlsConfig = ReloadableTlsConfig<ClientTlsConfig, ClientTlsOption>;
/// Load client TLS configuration from `ClientTlsOption` and return a `ReloadableClientTlsConfig`.
/// This is the primary way to create TLS configuration for the ChannelManager.
pub fn load_client_tls_config(
tls_option: Option<ClientTlsOption>,
) -> Result<Option<Arc<ReloadableClientTlsConfig>>> {
match tls_option {
Some(option) if option.enabled => {
let reloadable = ReloadableClientTlsConfig::try_new(option)?;
Ok(Some(Arc::new(reloadable)))
}
_ => Ok(None),
}
}
pub fn maybe_watch_client_tls_config(
client_tls_config: Arc<ReloadableClientTlsConfig>,
channel_manager: ChannelManager,
) -> Result<()> {
maybe_watch_tls_config(client_tls_config, move || {
// Clear all existing channels to force reconnection with new certificates
channel_manager.clear_all_channels();
info!("Cleared all existing channels to use new TLS certificates.");
})
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ClientTlsOption {
/// Whether to enable TLS for client.
pub enabled: bool,
pub server_ca_cert_path: Option<String>,
pub client_cert_path: Option<String>,
pub client_key_path: Option<String>,
#[serde(default)]
pub watch: bool,
}
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -602,6 +677,7 @@ mod tests {
server_ca_cert_path: Some("some_server_path".to_string()),
client_cert_path: Some("some_cert_path".to_string()),
client_key_path: Some("some_key_path".to_string()),
watch: false,
});
assert_eq!(
@@ -623,6 +699,7 @@ mod tests {
server_ca_cert_path: Some("some_server_path".to_string()),
client_cert_path: Some("some_cert_path".to_string()),
client_key_path: Some("some_key_path".to_string()),
watch: false,
}),
max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,

View File

@@ -38,6 +38,15 @@ pub enum Error {
location: Location,
},
#[snafu(display("Failed to watch config file path: {}", path))]
FileWatch {
path: String,
#[snafu(source)]
error: notify::Error,
#[snafu(implicit)]
location: Location,
},
#[snafu(display(
"Write type mismatch, column name: {}, expected: {}, actual: {}",
column_name,
@@ -108,6 +117,7 @@ impl ErrorExt for Error {
match self {
Error::InvalidTlsConfig { .. }
| Error::InvalidConfigFilePath { .. }
| Error::FileWatch { .. }
| Error::TypeMismatch { .. }
| Error::InvalidFlightData { .. }
| Error::NotSupported { .. } => StatusCode::InvalidArguments,

View File

@@ -16,6 +16,7 @@ pub mod channel_manager;
pub mod error;
pub mod flight;
pub mod precision;
pub mod reloadable_tls;
pub mod select;
pub use arrow_flight::FlightData;

View File

@@ -0,0 +1,163 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, RwLock};
use common_telemetry::{error, info};
use notify::{EventKind, RecursiveMode, Watcher};
use snafu::ResultExt;
use crate::error::{FileWatchSnafu, Result};
/// A trait for loading TLS configuration from an option type
pub trait TlsConfigLoader<T> {
type Error;
/// Load the TLS configuration
fn load(&self) -> StdResult<Option<T>, Self::Error>;
/// Get paths to certificate files for watching
fn watch_paths(&self) -> Vec<&Path>;
/// Check if watching is enabled
fn watch_enabled(&self) -> bool;
}
/// A mutable container for TLS config
///
/// This struct allows dynamic reloading of certificates and keys.
/// It's generic over the config type (e.g., ServerConfig, ClientTlsConfig)
/// and the option type (e.g., TlsOption, ClientTlsOption).
#[derive(Debug)]
pub struct ReloadableTlsConfig<T, O>
where
O: TlsConfigLoader<T>,
{
tls_option: O,
config: RwLock<Option<T>>,
version: AtomicUsize,
}
impl<T, O> ReloadableTlsConfig<T, O>
where
O: TlsConfigLoader<T>,
{
/// Create config by loading configuration from the option type
pub fn try_new(tls_option: O) -> StdResult<Self, O::Error> {
let config = tls_option.load()?;
Ok(Self {
tls_option,
config: RwLock::new(config),
version: AtomicUsize::new(0),
})
}
/// Reread certificates and keys from file system.
pub fn reload(&self) -> StdResult<(), O::Error> {
let config = self.tls_option.load()?;
*self.config.write().unwrap() = config;
self.version.fetch_add(1, Ordering::Relaxed);
Ok(())
}
/// Get the config held by this container
pub fn get_config(&self) -> Option<T>
where
T: Clone,
{
self.config.read().unwrap().clone()
}
/// Get associated option
pub fn get_tls_option(&self) -> &O {
&self.tls_option
}
/// Get version of current config
///
/// this version will auto increase when config get reloaded.
pub fn get_version(&self) -> usize {
self.version.load(Ordering::Relaxed)
}
}
/// Watch TLS configuration files for changes and reload automatically
///
/// This is a generic function that works with any ReloadableTlsConfig.
/// When changes are detected, it calls the provided callback after reloading.
///
/// T: the original TLS config
/// O: the compiled TLS option
/// F: the hook function to be called after reloading
/// E: the error type for the loading operation
pub fn maybe_watch_tls_config<T, O, F, E>(
tls_config: Arc<ReloadableTlsConfig<T, O>>,
on_reload: F,
) -> Result<()>
where
T: Send + Sync + 'static,
O: TlsConfigLoader<T, Error = E> + Send + Sync + 'static,
E: std::error::Error + Send + Sync + 'static,
F: Fn() + Send + 'static,
{
if !tls_config.get_tls_option().watch_enabled() {
return Ok(());
}
let tls_config_for_watcher = tls_config.clone();
let (tx, rx) = channel::<notify::Result<notify::Event>>();
let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
// Watch all paths returned by the TlsConfigLoader
for path in tls_config.get_tls_option().watch_paths() {
watcher
.watch(path, RecursiveMode::NonRecursive)
.with_context(|_| FileWatchSnafu {
path: path.display().to_string(),
})?;
}
info!("Spawning background task for watching TLS cert/key file changes");
std::thread::spawn(move || {
let _watcher = watcher;
loop {
match rx.recv() {
Ok(Ok(event)) => {
if let EventKind::Modify(_) | EventKind::Create(_) = event.kind {
info!("Detected TLS cert/key file change: {:?}", event);
if let Err(err) = tls_config_for_watcher.reload() {
error!("Failed to reload TLS config: {}", err);
} else {
info!("Reloaded TLS cert/key file successfully.");
on_reload();
}
}
}
Ok(Err(err)) => {
error!("Failed to watch TLS cert/key file: {}", err);
}
Err(err) => {
error!("TLS cert/key file watcher channel closed: {}", err);
}
}
}
});
Ok(())
}

View File

@@ -13,14 +13,15 @@
// limitations under the License.
use common_grpc::channel_manager::{
ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
ChannelConfig, ChannelManager, ClientTlsOption, load_client_tls_config,
maybe_watch_client_tls_config,
};
#[tokio::test]
async fn test_mtls_config() {
// test no config
let config = ChannelConfig::new();
let re = load_tls_config(config.client_tls.as_ref());
let re = load_client_tls_config(config.client_tls.clone());
assert!(re.is_ok());
assert!(re.unwrap().is_none());
@@ -30,9 +31,10 @@ async fn test_mtls_config() {
server_ca_cert_path: Some("tests/tls/wrong_ca.pem".to_string()),
client_cert_path: Some("tests/tls/wrong_client.pem".to_string()),
client_key_path: Some("tests/tls/wrong_client.key".to_string()),
watch: false,
});
let re = load_tls_config(config.client_tls.as_ref());
let re = load_client_tls_config(config.client_tls.clone());
assert!(re.is_err());
// test corrupted file content
@@ -41,9 +43,10 @@ async fn test_mtls_config() {
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
client_cert_path: Some("tests/tls/client.pem".to_string()),
client_key_path: Some("tests/tls/corrupted".to_string()),
watch: false,
});
let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap();
let tls_config = load_client_tls_config(config.client_tls.clone()).unwrap();
let re = ChannelManager::with_config(config, tls_config);
let re = re.get("127.0.0.1:0");
@@ -55,10 +58,112 @@ async fn test_mtls_config() {
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
client_cert_path: Some("tests/tls/client.pem".to_string()),
client_key_path: Some("tests/tls/client.key".to_string()),
watch: false,
});
let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap();
let tls_config = load_client_tls_config(config.client_tls.clone()).unwrap();
let re = ChannelManager::with_config(config, tls_config);
let re = re.get("127.0.0.1:0");
let _ = re.unwrap();
}
#[tokio::test]
async fn test_reloadable_client_tls_config() {
common_telemetry::init_default_ut_logging();
let dir = tempfile::tempdir().unwrap();
let cert_path = dir.path().join("client.pem");
let key_path = dir.path().join("client.key");
std::fs::copy("tests/tls/client.pem", &cert_path).expect("failed to copy cert to tmpdir");
std::fs::copy("tests/tls/client.key", &key_path).expect("failed to copy key to tmpdir");
assert!(std::fs::exists(&cert_path).unwrap());
assert!(std::fs::exists(&key_path).unwrap());
let client_tls_option = ClientTlsOption {
enabled: true,
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
client_cert_path: Some(
cert_path
.clone()
.into_os_string()
.into_string()
.expect("failed to convert path to string"),
),
client_key_path: Some(
key_path
.clone()
.into_os_string()
.into_string()
.expect("failed to convert path to string"),
),
watch: true,
};
let reloadable_config = load_client_tls_config(Some(client_tls_option))
.expect("failed to load tls config")
.expect("tls config should be present");
let config = ChannelConfig::new();
let manager = ChannelManager::with_config(config, Some(reloadable_config.clone()));
maybe_watch_client_tls_config(reloadable_config.clone(), manager.clone())
.expect("failed to watch client config");
assert_eq!(0, reloadable_config.get_version());
assert!(reloadable_config.get_config().is_some());
// Create a channel to verify it gets cleared on reload
let _ = manager.get("127.0.0.1:0").expect("failed to get channel");
// Simulate file change by copying a different key file
let tmp_file = key_path.with_extension("tmp");
std::fs::copy("tests/tls/server.key", &tmp_file).expect("Failed to copy temp key file");
std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");
const MAX_RETRIES: usize = 30;
let mut retries = 0;
let mut version_updated = false;
while retries < MAX_RETRIES {
if reloadable_config.get_version() > 0 {
version_updated = true;
break;
}
std::thread::sleep(std::time::Duration::from_millis(100));
retries += 1;
}
assert!(version_updated, "TLS config did not reload in time");
assert!(reloadable_config.get_version() > 0);
assert!(reloadable_config.get_config().is_some());
}
#[tokio::test]
async fn test_channel_manager_with_reloadable_tls() {
common_telemetry::init_default_ut_logging();
let client_tls_option = ClientTlsOption {
enabled: true,
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
client_cert_path: Some("tests/tls/client.pem".to_string()),
client_key_path: Some("tests/tls/client.key".to_string()),
watch: false,
};
let reloadable_config = load_client_tls_config(Some(client_tls_option))
.expect("failed to load tls config")
.expect("tls config should be present");
let config = ChannelConfig::new();
let manager = ChannelManager::with_config(config, Some(reloadable_config.clone()));
// Test that we can get a channel
let channel = manager.get("127.0.0.1:0");
assert!(channel.is_ok());
// Test that config is properly set
assert_eq!(0, reloadable_config.get_version());
assert!(reloadable_config.get_config().is_some());
}

View File

@@ -23,7 +23,7 @@ use api::v1::query_request::Query;
use api::v1::{CreateTableExpr, QueryRequest};
use client::{Client, Database};
use common_error::ext::{BoxedError, ErrorExt};
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_tls_config};
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_client_tls_config};
use common_meta::cluster::{NodeInfo, NodeInfoKey, Role};
use common_meta::peer::Peer;
use common_meta::rpc::store::RangeRequest;
@@ -124,7 +124,7 @@ impl FrontendClient {
.connect_timeout(batch_opts.grpc_conn_timeout)
.timeout(batch_opts.query_timeout);
let tls_config = load_tls_config(batch_opts.frontend_tls.as_ref())
let tls_config = load_client_tls_config(batch_opts.frontend_tls.clone())
.context(InvalidClientConfigSnafu)?;
ChannelManager::with_config(cfg, tls_config)
},

View File

@@ -36,7 +36,7 @@ use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
use servers::query_handler::sql::ServerSqlQueryHandlerAdapter;
use servers::server::{Server, ServerHandlers};
use servers::tls::{ReloadableTlsServerConfig, maybe_watch_tls_config};
use servers::tls::{ReloadableTlsServerConfig, maybe_watch_server_tls_config};
use snafu::ResultExt;
use crate::error::{self, Result, StartServerSnafu, TomlFormatSnafu};
@@ -258,7 +258,7 @@ where
);
// will not watch if watch is disabled in tls option
maybe_watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
maybe_watch_server_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
let mysql_server = MysqlServer::create_server(
common_runtime::global_runtime(),
@@ -287,7 +287,7 @@ where
ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?,
);
maybe_watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
maybe_watch_server_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
let pg_server = Box::new(PostgresServer::new(
ServerSqlQueryHandlerAdapter::arc(instance.clone()),

View File

@@ -99,7 +99,7 @@ impl MysqlSpawnConfig {
}
fn tls(&self) -> Option<Arc<ServerConfig>> {
self.tls.get_server_config()
self.tls.get_config()
}
}

View File

@@ -80,7 +80,7 @@ impl PostgresServer {
let process_manager = self.process_manager.clone();
accepting_stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
let tls_acceptor = tls_server_config.get_config().map(TlsAcceptor::from);
let handler_maker = handler_maker.clone();
let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);

View File

@@ -15,12 +15,10 @@
use std::fs::File;
use std::io::{BufReader, Error as IoError, ErrorKind};
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
use common_telemetry::{error, info};
use notify::{EventKind, RecursiveMode, Watcher};
use common_grpc::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader};
use common_telemetry::error;
use rustls::ServerConfig;
use rustls_pemfile::{Item, certs, read_one};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
@@ -28,7 +26,7 @@ use serde::{Deserialize, Serialize};
use snafu::ResultExt;
use strum::EnumString;
use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
use crate::error::{InternalIoSnafu, Result};
/// TlsMode is used for Mysql and Postgres server start up.
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
@@ -149,96 +147,34 @@ impl TlsOption {
}
}
/// A mutable container for TLS server config
///
/// This struct allows dynamic reloading of server certificates and keys
pub struct ReloadableTlsServerConfig {
tls_option: TlsOption,
config: RwLock<Option<Arc<ServerConfig>>>,
version: AtomicUsize,
}
impl TlsConfigLoader<Arc<ServerConfig>> for TlsOption {
type Error = crate::error::Error;
impl ReloadableTlsServerConfig {
/// Create server config by loading configuration from `TlsOption`
pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
let server_config = tls_option.setup()?;
Ok(Self {
tls_option,
config: RwLock::new(server_config.map(Arc::new)),
version: AtomicUsize::new(0),
})
fn load(&self) -> Result<Option<Arc<ServerConfig>>> {
Ok(self.setup()?.map(Arc::new))
}
/// Reread server certificates and keys from file system.
pub fn reload(&self) -> Result<()> {
let server_config = self.tls_option.setup()?;
*self.config.write().unwrap() = server_config.map(Arc::new);
self.version.fetch_add(1, Ordering::Relaxed);
Ok(())
fn watch_paths(&self) -> Vec<&Path> {
vec![self.cert_path(), self.key_path()]
}
/// Get the server config hold by this container
pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
self.config.read().unwrap().clone()
}
/// Get associated `TlsOption`
pub fn get_tls_option(&self) -> &TlsOption {
&self.tls_option
}
/// Get version of current config
///
/// this version will auto increase when server config get reloaded.
pub fn get_version(&self) -> usize {
self.version.load(Ordering::Relaxed)
fn watch_enabled(&self) -> bool {
self.mode != TlsMode::Disable && self.watch
}
}
pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
if !tls_server_config.get_tls_option().watch_enabled() {
return Ok(());
}
/// Type alias for server-side reloadable TLS config
pub type ReloadableTlsServerConfig = ReloadableTlsConfig<Arc<ServerConfig>, TlsOption>;
let tls_server_config_for_watcher = tls_server_config.clone();
let (tx, rx) = channel::<notify::Result<notify::Event>>();
let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
let cert_path = tls_server_config.get_tls_option().cert_path();
watcher
.watch(cert_path, RecursiveMode::NonRecursive)
.with_context(|_| FileWatchSnafu {
path: cert_path.display().to_string(),
})?;
let key_path = tls_server_config.get_tls_option().key_path();
watcher
.watch(key_path, RecursiveMode::NonRecursive)
.with_context(|_| FileWatchSnafu {
path: key_path.display().to_string(),
})?;
std::thread::spawn(move || {
let _watcher = watcher;
while let Ok(res) = rx.recv() {
if let Ok(event) = res {
match event.kind {
EventKind::Modify(_) | EventKind::Create(_) => {
info!("Detected TLS cert/key file change: {:?}", event);
if let Err(err) = tls_server_config_for_watcher.reload() {
error!(err; "Failed to reload TLS server config");
} else {
info!("Reloaded TLS cert/key file successfully.");
}
}
_ => {}
}
}
/// Convenience function for watching server TLS configuration
pub fn maybe_watch_server_tls_config(
tls_server_config: Arc<ReloadableTlsServerConfig>,
) -> Result<()> {
common_grpc::reloadable_tls::maybe_watch_tls_config(tls_server_config, || {}).map_err(|e| {
crate::error::Error::Internal {
err_msg: format!("Failed to watch TLS config: {}", e),
}
});
Ok(())
})
}
#[cfg(test)]
@@ -434,10 +370,11 @@ mod tests {
let server_config = Arc::new(
ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
);
maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");
maybe_watch_server_tls_config(server_config.clone())
.expect("failed to watch server config");
assert_eq!(0, server_config.get_version());
assert!(server_config.get_server_config().is_some());
assert!(server_config.get_config().is_some());
let tmp_file = key_path.with_extension("tmp");
std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
@@ -459,6 +396,6 @@ mod tests {
assert!(version_updated, "TLS config did not reload in time");
assert!(server_config.get_version() > 0);
assert!(server_config.get_server_config().is_some());
assert!(server_config.get_config().is_some());
}
}

View File

@@ -971,6 +971,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) {
server_ca_cert_path: Some(ca_path),
client_cert_path: Some(client_cert_path),
client_key_path: Some(client_key_path),
watch: false,
};
{
let grpc_client =