mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-22 22:20:02 +00:00
feat: reloadable tls client config (#7230)
* feat: add ReloadableClientTlsConfig Signed-off-by: shuiyisong <xixing.sys@gmail.com> * refactor: merge tls option with the reloadable Signed-off-by: shuiyisong <xixing.sys@gmail.com> * chore: rename function Signed-off-by: shuiyisong <xixing.sys@gmail.com> * chore: update comment Signed-off-by: shuiyisong <xixing.sys@gmail.com> * chore: extract tls loader Signed-off-by: shuiyisong <xixing.sys@gmail.com> * chore: minor comment update Signed-off-by: shuiyisong <xixing.sys@gmail.com> * chore: add serde default to watch field Signed-off-by: shuiyisong <xixing.sys@gmail.com> * chore: minor update Signed-off-by: shuiyisong <xixing.sys@gmail.com> * chore: add log Signed-off-by: shuiyisong <xixing.sys@gmail.com> * fix: add error log Signed-off-by: shuiyisong <xixing.sys@gmail.com> --------- Signed-off-by: shuiyisong <xixing.sys@gmail.com>
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2265,11 +2265,13 @@ dependencies = [
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"lazy_static",
|
||||
"notify",
|
||||
"prost 0.13.5",
|
||||
"rand 0.9.1",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"snafu 0.8.6",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tonic 0.13.1",
|
||||
|
||||
@@ -21,7 +21,7 @@ use api::v1::prometheus_gateway_client::PrometheusGatewayClient;
|
||||
use api::v1::region::region_client::RegionClient as PbRegionClient;
|
||||
use arrow_flight::flight_service_client::FlightServiceClient;
|
||||
use common_grpc::channel_manager::{
|
||||
ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
|
||||
ChannelConfig, ChannelManager, ClientTlsOption, load_client_tls_config,
|
||||
};
|
||||
use parking_lot::RwLock;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
@@ -95,9 +95,9 @@ impl Client {
|
||||
U: AsRef<str>,
|
||||
A: AsRef<[U]>,
|
||||
{
|
||||
let channel_config = ChannelConfig::default().client_tls_config(client_tls);
|
||||
let tls_config = load_tls_config(channel_config.client_tls.as_ref())
|
||||
.context(error::CreateTlsChannelSnafu)?;
|
||||
let channel_config = ChannelConfig::default().client_tls_config(client_tls.clone());
|
||||
let tls_config =
|
||||
load_client_tls_config(Some(client_tls)).context(error::CreateTlsChannelSnafu)?;
|
||||
let channel_manager = ChannelManager::with_config(channel_config, tls_config);
|
||||
Ok(Self::with_manager_and_urls(channel_manager, urls))
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ datatypes.workspace = true
|
||||
flatbuffers = "25.2"
|
||||
hyper.workspace = true
|
||||
lazy_static.workspace = true
|
||||
notify.workspace = true
|
||||
prost.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
@@ -37,6 +38,7 @@ vec1 = "1.12"
|
||||
criterion = "0.4"
|
||||
hyper-util = { workspace = true, features = ["tokio"] }
|
||||
rand.workspace = true
|
||||
tempfile.workspace = true
|
||||
|
||||
[[bench]]
|
||||
name = "bench_main"
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
|
||||
use std::time::Duration;
|
||||
@@ -30,6 +31,7 @@ use tonic::transport::{
|
||||
use tower::Service;
|
||||
|
||||
use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result};
|
||||
use crate::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader, maybe_watch_tls_config};
|
||||
|
||||
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
|
||||
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
|
||||
@@ -50,7 +52,7 @@ pub struct ChannelManager {
|
||||
struct Inner {
|
||||
id: u64,
|
||||
config: ChannelConfig,
|
||||
client_tls_config: Option<ClientTlsConfig>,
|
||||
reloadable_client_tls_config: Option<Arc<ReloadableClientTlsConfig>>,
|
||||
pool: Arc<Pool>,
|
||||
channel_recycle_started: AtomicBool,
|
||||
cancel: CancellationToken,
|
||||
@@ -78,7 +80,7 @@ impl Inner {
|
||||
Self {
|
||||
id,
|
||||
config,
|
||||
client_tls_config: None,
|
||||
reloadable_client_tls_config: None,
|
||||
pool,
|
||||
channel_recycle_started: AtomicBool::new(false),
|
||||
cancel,
|
||||
@@ -91,13 +93,17 @@ impl ChannelManager {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
/// unified with config function that support tls config
|
||||
/// use [`load_tls_config`] to load tls config from file system
|
||||
pub fn with_config(config: ChannelConfig, tls_config: Option<ClientTlsConfig>) -> Self {
|
||||
/// Create a ChannelManager with configuration and optional TLS config
|
||||
///
|
||||
/// Use [`load_client_tls_config`] to create TLS configuration from `ClientTlsOption`.
|
||||
/// The TLS config supports both static (watch disabled) and dynamic reloading (watch enabled).
|
||||
/// If you want to use dynamic reloading, please **manually** invoke [`maybe_watch_client_tls_config`] after this method.
|
||||
pub fn with_config(
|
||||
config: ChannelConfig,
|
||||
reloadable_tls_config: Option<Arc<ReloadableClientTlsConfig>>,
|
||||
) -> Self {
|
||||
let mut inner = Inner::with_config(config.clone());
|
||||
if let Some(tls_config) = tls_config {
|
||||
inner.client_tls_config = Some(tls_config);
|
||||
}
|
||||
inner.reloadable_client_tls_config = reloadable_tls_config;
|
||||
Self {
|
||||
inner: Arc::new(inner),
|
||||
}
|
||||
@@ -172,8 +178,21 @@ impl ChannelManager {
|
||||
self.pool().retain_channel(f);
|
||||
}
|
||||
|
||||
/// Clear all channels to force reconnection.
|
||||
/// This should be called when TLS configuration changes to ensure new connections use updated certificates.
|
||||
pub fn clear_all_channels(&self) {
|
||||
self.pool().retain_channel(|_, _| false);
|
||||
}
|
||||
|
||||
fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
|
||||
let http_prefix = if self.inner.client_tls_config.is_some() {
|
||||
// Get the latest TLS config from reloadable config (which handles both static and dynamic cases)
|
||||
let tls_config = self
|
||||
.inner
|
||||
.reloadable_client_tls_config
|
||||
.as_ref()
|
||||
.and_then(|c| c.get_config());
|
||||
|
||||
let http_prefix = if tls_config.is_some() {
|
||||
"https"
|
||||
} else {
|
||||
"http"
|
||||
@@ -212,9 +231,9 @@ impl ChannelManager {
|
||||
if let Some(enabled) = self.config().http2_adaptive_window {
|
||||
endpoint = endpoint.http2_adaptive_window(enabled);
|
||||
}
|
||||
if let Some(tls_config) = &self.inner.client_tls_config {
|
||||
if let Some(tls_config) = tls_config {
|
||||
endpoint = endpoint
|
||||
.tls_config(tls_config.clone())
|
||||
.tls_config(tls_config)
|
||||
.context(CreateChannelSnafu { addr })?;
|
||||
}
|
||||
|
||||
@@ -248,7 +267,7 @@ impl ChannelManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
|
||||
fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
|
||||
let path_config = match tls_option {
|
||||
Some(path_config) if path_config.enabled => path_config,
|
||||
_ => return Ok(None),
|
||||
@@ -276,13 +295,69 @@ pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<Cl
|
||||
Ok(Some(tls_config))
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
impl TlsConfigLoader<ClientTlsConfig> for ClientTlsOption {
|
||||
type Error = crate::error::Error;
|
||||
|
||||
fn load(&self) -> Result<Option<ClientTlsConfig>> {
|
||||
load_tls_config(Some(self))
|
||||
}
|
||||
|
||||
fn watch_paths(&self) -> Vec<&Path> {
|
||||
let mut paths = Vec::new();
|
||||
if let Some(cert_path) = &self.client_cert_path {
|
||||
paths.push(Path::new(cert_path.as_str()));
|
||||
}
|
||||
if let Some(key_path) = &self.client_key_path {
|
||||
paths.push(Path::new(key_path.as_str()));
|
||||
}
|
||||
if let Some(ca_path) = &self.server_ca_cert_path {
|
||||
paths.push(Path::new(ca_path.as_str()));
|
||||
}
|
||||
paths
|
||||
}
|
||||
|
||||
fn watch_enabled(&self) -> bool {
|
||||
self.enabled && self.watch
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for client-side reloadable TLS config
|
||||
pub type ReloadableClientTlsConfig = ReloadableTlsConfig<ClientTlsConfig, ClientTlsOption>;
|
||||
|
||||
/// Load client TLS configuration from `ClientTlsOption` and return a `ReloadableClientTlsConfig`.
|
||||
/// This is the primary way to create TLS configuration for the ChannelManager.
|
||||
pub fn load_client_tls_config(
|
||||
tls_option: Option<ClientTlsOption>,
|
||||
) -> Result<Option<Arc<ReloadableClientTlsConfig>>> {
|
||||
match tls_option {
|
||||
Some(option) if option.enabled => {
|
||||
let reloadable = ReloadableClientTlsConfig::try_new(option)?;
|
||||
Ok(Some(Arc::new(reloadable)))
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn maybe_watch_client_tls_config(
|
||||
client_tls_config: Arc<ReloadableClientTlsConfig>,
|
||||
channel_manager: ChannelManager,
|
||||
) -> Result<()> {
|
||||
maybe_watch_tls_config(client_tls_config, move || {
|
||||
// Clear all existing channels to force reconnection with new certificates
|
||||
channel_manager.clear_all_channels();
|
||||
info!("Cleared all existing channels to use new TLS certificates.");
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
pub struct ClientTlsOption {
|
||||
/// Whether to enable TLS for client.
|
||||
pub enabled: bool,
|
||||
pub server_ca_cert_path: Option<String>,
|
||||
pub client_cert_path: Option<String>,
|
||||
pub client_key_path: Option<String>,
|
||||
#[serde(default)]
|
||||
pub watch: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
@@ -602,6 +677,7 @@ mod tests {
|
||||
server_ca_cert_path: Some("some_server_path".to_string()),
|
||||
client_cert_path: Some("some_cert_path".to_string()),
|
||||
client_key_path: Some("some_key_path".to_string()),
|
||||
watch: false,
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
@@ -623,6 +699,7 @@ mod tests {
|
||||
server_ca_cert_path: Some("some_server_path".to_string()),
|
||||
client_cert_path: Some("some_cert_path".to_string()),
|
||||
client_key_path: Some("some_key_path".to_string()),
|
||||
watch: false,
|
||||
}),
|
||||
max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
|
||||
max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
|
||||
|
||||
@@ -38,6 +38,15 @@ pub enum Error {
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to watch config file path: {}", path))]
|
||||
FileWatch {
|
||||
path: String,
|
||||
#[snafu(source)]
|
||||
error: notify::Error,
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display(
|
||||
"Write type mismatch, column name: {}, expected: {}, actual: {}",
|
||||
column_name,
|
||||
@@ -108,6 +117,7 @@ impl ErrorExt for Error {
|
||||
match self {
|
||||
Error::InvalidTlsConfig { .. }
|
||||
| Error::InvalidConfigFilePath { .. }
|
||||
| Error::FileWatch { .. }
|
||||
| Error::TypeMismatch { .. }
|
||||
| Error::InvalidFlightData { .. }
|
||||
| Error::NotSupported { .. } => StatusCode::InvalidArguments,
|
||||
|
||||
@@ -16,6 +16,7 @@ pub mod channel_manager;
|
||||
pub mod error;
|
||||
pub mod flight;
|
||||
pub mod precision;
|
||||
pub mod reloadable_tls;
|
||||
pub mod select;
|
||||
|
||||
pub use arrow_flight::FlightData;
|
||||
|
||||
163
src/common/grpc/src/reloadable_tls.rs
Normal file
163
src/common/grpc/src/reloadable_tls.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::Path;
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::mpsc::channel;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use common_telemetry::{error, info};
|
||||
use notify::{EventKind, RecursiveMode, Watcher};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{FileWatchSnafu, Result};
|
||||
|
||||
/// A trait for loading TLS configuration from an option type
|
||||
pub trait TlsConfigLoader<T> {
|
||||
type Error;
|
||||
|
||||
/// Load the TLS configuration
|
||||
fn load(&self) -> StdResult<Option<T>, Self::Error>;
|
||||
|
||||
/// Get paths to certificate files for watching
|
||||
fn watch_paths(&self) -> Vec<&Path>;
|
||||
|
||||
/// Check if watching is enabled
|
||||
fn watch_enabled(&self) -> bool;
|
||||
}
|
||||
|
||||
/// A mutable container for TLS config
|
||||
///
|
||||
/// This struct allows dynamic reloading of certificates and keys.
|
||||
/// It's generic over the config type (e.g., ServerConfig, ClientTlsConfig)
|
||||
/// and the option type (e.g., TlsOption, ClientTlsOption).
|
||||
#[derive(Debug)]
|
||||
pub struct ReloadableTlsConfig<T, O>
|
||||
where
|
||||
O: TlsConfigLoader<T>,
|
||||
{
|
||||
tls_option: O,
|
||||
config: RwLock<Option<T>>,
|
||||
version: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<T, O> ReloadableTlsConfig<T, O>
|
||||
where
|
||||
O: TlsConfigLoader<T>,
|
||||
{
|
||||
/// Create config by loading configuration from the option type
|
||||
pub fn try_new(tls_option: O) -> StdResult<Self, O::Error> {
|
||||
let config = tls_option.load()?;
|
||||
Ok(Self {
|
||||
tls_option,
|
||||
config: RwLock::new(config),
|
||||
version: AtomicUsize::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
/// Reread certificates and keys from file system.
|
||||
pub fn reload(&self) -> StdResult<(), O::Error> {
|
||||
let config = self.tls_option.load()?;
|
||||
*self.config.write().unwrap() = config;
|
||||
self.version.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the config held by this container
|
||||
pub fn get_config(&self) -> Option<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.config.read().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Get associated option
|
||||
pub fn get_tls_option(&self) -> &O {
|
||||
&self.tls_option
|
||||
}
|
||||
|
||||
/// Get version of current config
|
||||
///
|
||||
/// this version will auto increase when config get reloaded.
|
||||
pub fn get_version(&self) -> usize {
|
||||
self.version.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
/// Watch TLS configuration files for changes and reload automatically
|
||||
///
|
||||
/// This is a generic function that works with any ReloadableTlsConfig.
|
||||
/// When changes are detected, it calls the provided callback after reloading.
|
||||
///
|
||||
/// T: the original TLS config
|
||||
/// O: the compiled TLS option
|
||||
/// F: the hook function to be called after reloading
|
||||
/// E: the error type for the loading operation
|
||||
pub fn maybe_watch_tls_config<T, O, F, E>(
|
||||
tls_config: Arc<ReloadableTlsConfig<T, O>>,
|
||||
on_reload: F,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
O: TlsConfigLoader<T, Error = E> + Send + Sync + 'static,
|
||||
E: std::error::Error + Send + Sync + 'static,
|
||||
F: Fn() + Send + 'static,
|
||||
{
|
||||
if !tls_config.get_tls_option().watch_enabled() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let tls_config_for_watcher = tls_config.clone();
|
||||
|
||||
let (tx, rx) = channel::<notify::Result<notify::Event>>();
|
||||
let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
|
||||
|
||||
// Watch all paths returned by the TlsConfigLoader
|
||||
for path in tls_config.get_tls_option().watch_paths() {
|
||||
watcher
|
||||
.watch(path, RecursiveMode::NonRecursive)
|
||||
.with_context(|_| FileWatchSnafu {
|
||||
path: path.display().to_string(),
|
||||
})?;
|
||||
}
|
||||
|
||||
info!("Spawning background task for watching TLS cert/key file changes");
|
||||
std::thread::spawn(move || {
|
||||
let _watcher = watcher;
|
||||
loop {
|
||||
match rx.recv() {
|
||||
Ok(Ok(event)) => {
|
||||
if let EventKind::Modify(_) | EventKind::Create(_) = event.kind {
|
||||
info!("Detected TLS cert/key file change: {:?}", event);
|
||||
if let Err(err) = tls_config_for_watcher.reload() {
|
||||
error!("Failed to reload TLS config: {}", err);
|
||||
} else {
|
||||
info!("Reloaded TLS cert/key file successfully.");
|
||||
on_reload();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
error!("Failed to watch TLS cert/key file: {}", err);
|
||||
}
|
||||
Err(err) => {
|
||||
error!("TLS cert/key file watcher channel closed: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -13,14 +13,15 @@
|
||||
// limitations under the License.
|
||||
|
||||
use common_grpc::channel_manager::{
|
||||
ChannelConfig, ChannelManager, ClientTlsOption, load_tls_config,
|
||||
ChannelConfig, ChannelManager, ClientTlsOption, load_client_tls_config,
|
||||
maybe_watch_client_tls_config,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mtls_config() {
|
||||
// test no config
|
||||
let config = ChannelConfig::new();
|
||||
let re = load_tls_config(config.client_tls.as_ref());
|
||||
let re = load_client_tls_config(config.client_tls.clone());
|
||||
assert!(re.is_ok());
|
||||
assert!(re.unwrap().is_none());
|
||||
|
||||
@@ -30,9 +31,10 @@ async fn test_mtls_config() {
|
||||
server_ca_cert_path: Some("tests/tls/wrong_ca.pem".to_string()),
|
||||
client_cert_path: Some("tests/tls/wrong_client.pem".to_string()),
|
||||
client_key_path: Some("tests/tls/wrong_client.key".to_string()),
|
||||
watch: false,
|
||||
});
|
||||
|
||||
let re = load_tls_config(config.client_tls.as_ref());
|
||||
let re = load_client_tls_config(config.client_tls.clone());
|
||||
assert!(re.is_err());
|
||||
|
||||
// test corrupted file content
|
||||
@@ -41,9 +43,10 @@ async fn test_mtls_config() {
|
||||
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
|
||||
client_cert_path: Some("tests/tls/client.pem".to_string()),
|
||||
client_key_path: Some("tests/tls/corrupted".to_string()),
|
||||
watch: false,
|
||||
});
|
||||
|
||||
let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap();
|
||||
let tls_config = load_client_tls_config(config.client_tls.clone()).unwrap();
|
||||
let re = ChannelManager::with_config(config, tls_config);
|
||||
|
||||
let re = re.get("127.0.0.1:0");
|
||||
@@ -55,10 +58,112 @@ async fn test_mtls_config() {
|
||||
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
|
||||
client_cert_path: Some("tests/tls/client.pem".to_string()),
|
||||
client_key_path: Some("tests/tls/client.key".to_string()),
|
||||
watch: false,
|
||||
});
|
||||
|
||||
let tls_config = load_tls_config(config.client_tls.as_ref()).unwrap();
|
||||
let tls_config = load_client_tls_config(config.client_tls.clone()).unwrap();
|
||||
let re = ChannelManager::with_config(config, tls_config);
|
||||
let re = re.get("127.0.0.1:0");
|
||||
let _ = re.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reloadable_client_tls_config() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let cert_path = dir.path().join("client.pem");
|
||||
let key_path = dir.path().join("client.key");
|
||||
|
||||
std::fs::copy("tests/tls/client.pem", &cert_path).expect("failed to copy cert to tmpdir");
|
||||
std::fs::copy("tests/tls/client.key", &key_path).expect("failed to copy key to tmpdir");
|
||||
|
||||
assert!(std::fs::exists(&cert_path).unwrap());
|
||||
assert!(std::fs::exists(&key_path).unwrap());
|
||||
|
||||
let client_tls_option = ClientTlsOption {
|
||||
enabled: true,
|
||||
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
|
||||
client_cert_path: Some(
|
||||
cert_path
|
||||
.clone()
|
||||
.into_os_string()
|
||||
.into_string()
|
||||
.expect("failed to convert path to string"),
|
||||
),
|
||||
client_key_path: Some(
|
||||
key_path
|
||||
.clone()
|
||||
.into_os_string()
|
||||
.into_string()
|
||||
.expect("failed to convert path to string"),
|
||||
),
|
||||
watch: true,
|
||||
};
|
||||
|
||||
let reloadable_config = load_client_tls_config(Some(client_tls_option))
|
||||
.expect("failed to load tls config")
|
||||
.expect("tls config should be present");
|
||||
|
||||
let config = ChannelConfig::new();
|
||||
let manager = ChannelManager::with_config(config, Some(reloadable_config.clone()));
|
||||
|
||||
maybe_watch_client_tls_config(reloadable_config.clone(), manager.clone())
|
||||
.expect("failed to watch client config");
|
||||
|
||||
assert_eq!(0, reloadable_config.get_version());
|
||||
assert!(reloadable_config.get_config().is_some());
|
||||
|
||||
// Create a channel to verify it gets cleared on reload
|
||||
let _ = manager.get("127.0.0.1:0").expect("failed to get channel");
|
||||
|
||||
// Simulate file change by copying a different key file
|
||||
let tmp_file = key_path.with_extension("tmp");
|
||||
std::fs::copy("tests/tls/server.key", &tmp_file).expect("Failed to copy temp key file");
|
||||
std::fs::rename(&tmp_file, &key_path).expect("Failed to rename temp key file");
|
||||
|
||||
const MAX_RETRIES: usize = 30;
|
||||
let mut retries = 0;
|
||||
let mut version_updated = false;
|
||||
|
||||
while retries < MAX_RETRIES {
|
||||
if reloadable_config.get_version() > 0 {
|
||||
version_updated = true;
|
||||
break;
|
||||
}
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
retries += 1;
|
||||
}
|
||||
|
||||
assert!(version_updated, "TLS config did not reload in time");
|
||||
assert!(reloadable_config.get_version() > 0);
|
||||
assert!(reloadable_config.get_config().is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_channel_manager_with_reloadable_tls() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let client_tls_option = ClientTlsOption {
|
||||
enabled: true,
|
||||
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
|
||||
client_cert_path: Some("tests/tls/client.pem".to_string()),
|
||||
client_key_path: Some("tests/tls/client.key".to_string()),
|
||||
watch: false,
|
||||
};
|
||||
|
||||
let reloadable_config = load_client_tls_config(Some(client_tls_option))
|
||||
.expect("failed to load tls config")
|
||||
.expect("tls config should be present");
|
||||
|
||||
let config = ChannelConfig::new();
|
||||
let manager = ChannelManager::with_config(config, Some(reloadable_config.clone()));
|
||||
|
||||
// Test that we can get a channel
|
||||
let channel = manager.get("127.0.0.1:0");
|
||||
assert!(channel.is_ok());
|
||||
|
||||
// Test that config is properly set
|
||||
assert_eq!(0, reloadable_config.get_version());
|
||||
assert!(reloadable_config.get_config().is_some());
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ use api::v1::query_request::Query;
|
||||
use api::v1::{CreateTableExpr, QueryRequest};
|
||||
use client::{Client, Database};
|
||||
use common_error::ext::{BoxedError, ErrorExt};
|
||||
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_tls_config};
|
||||
use common_grpc::channel_manager::{ChannelConfig, ChannelManager, load_client_tls_config};
|
||||
use common_meta::cluster::{NodeInfo, NodeInfoKey, Role};
|
||||
use common_meta::peer::Peer;
|
||||
use common_meta::rpc::store::RangeRequest;
|
||||
@@ -124,7 +124,7 @@ impl FrontendClient {
|
||||
.connect_timeout(batch_opts.grpc_conn_timeout)
|
||||
.timeout(batch_opts.query_timeout);
|
||||
|
||||
let tls_config = load_tls_config(batch_opts.frontend_tls.as_ref())
|
||||
let tls_config = load_client_tls_config(batch_opts.frontend_tls.clone())
|
||||
.context(InvalidClientConfigSnafu)?;
|
||||
ChannelManager::with_config(cfg, tls_config)
|
||||
},
|
||||
|
||||
@@ -36,7 +36,7 @@ use servers::postgres::PostgresServer;
|
||||
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdapter;
|
||||
use servers::query_handler::sql::ServerSqlQueryHandlerAdapter;
|
||||
use servers::server::{Server, ServerHandlers};
|
||||
use servers::tls::{ReloadableTlsServerConfig, maybe_watch_tls_config};
|
||||
use servers::tls::{ReloadableTlsServerConfig, maybe_watch_server_tls_config};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{self, Result, StartServerSnafu, TomlFormatSnafu};
|
||||
@@ -258,7 +258,7 @@ where
|
||||
);
|
||||
|
||||
// will not watch if watch is disabled in tls option
|
||||
maybe_watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
|
||||
maybe_watch_server_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
|
||||
|
||||
let mysql_server = MysqlServer::create_server(
|
||||
common_runtime::global_runtime(),
|
||||
@@ -287,7 +287,7 @@ where
|
||||
ReloadableTlsServerConfig::try_new(opts.tls.clone()).context(StartServerSnafu)?,
|
||||
);
|
||||
|
||||
maybe_watch_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
|
||||
maybe_watch_server_tls_config(tls_server_config.clone()).context(StartServerSnafu)?;
|
||||
|
||||
let pg_server = Box::new(PostgresServer::new(
|
||||
ServerSqlQueryHandlerAdapter::arc(instance.clone()),
|
||||
|
||||
@@ -99,7 +99,7 @@ impl MysqlSpawnConfig {
|
||||
}
|
||||
|
||||
fn tls(&self) -> Option<Arc<ServerConfig>> {
|
||||
self.tls.get_server_config()
|
||||
self.tls.get_config()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ impl PostgresServer {
|
||||
let process_manager = self.process_manager.clone();
|
||||
accepting_stream.for_each(move |tcp_stream| {
|
||||
let io_runtime = io_runtime.clone();
|
||||
let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
|
||||
let tls_acceptor = tls_server_config.get_config().map(TlsAcceptor::from);
|
||||
let handler_maker = handler_maker.clone();
|
||||
let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);
|
||||
|
||||
|
||||
@@ -15,12 +15,10 @@
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Error as IoError, ErrorKind};
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::mpsc::channel;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_telemetry::{error, info};
|
||||
use notify::{EventKind, RecursiveMode, Watcher};
|
||||
use common_grpc::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader};
|
||||
use common_telemetry::error;
|
||||
use rustls::ServerConfig;
|
||||
use rustls_pemfile::{Item, certs, read_one};
|
||||
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
|
||||
@@ -28,7 +26,7 @@ use serde::{Deserialize, Serialize};
|
||||
use snafu::ResultExt;
|
||||
use strum::EnumString;
|
||||
|
||||
use crate::error::{FileWatchSnafu, InternalIoSnafu, Result};
|
||||
use crate::error::{InternalIoSnafu, Result};
|
||||
|
||||
/// TlsMode is used for Mysql and Postgres server start up.
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
|
||||
@@ -149,96 +147,34 @@ impl TlsOption {
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable container for TLS server config
|
||||
///
|
||||
/// This struct allows dynamic reloading of server certificates and keys
|
||||
pub struct ReloadableTlsServerConfig {
|
||||
tls_option: TlsOption,
|
||||
config: RwLock<Option<Arc<ServerConfig>>>,
|
||||
version: AtomicUsize,
|
||||
}
|
||||
impl TlsConfigLoader<Arc<ServerConfig>> for TlsOption {
|
||||
type Error = crate::error::Error;
|
||||
|
||||
impl ReloadableTlsServerConfig {
|
||||
/// Create server config by loading configuration from `TlsOption`
|
||||
pub fn try_new(tls_option: TlsOption) -> Result<ReloadableTlsServerConfig> {
|
||||
let server_config = tls_option.setup()?;
|
||||
Ok(Self {
|
||||
tls_option,
|
||||
config: RwLock::new(server_config.map(Arc::new)),
|
||||
version: AtomicUsize::new(0),
|
||||
})
|
||||
fn load(&self) -> Result<Option<Arc<ServerConfig>>> {
|
||||
Ok(self.setup()?.map(Arc::new))
|
||||
}
|
||||
|
||||
/// Reread server certificates and keys from file system.
|
||||
pub fn reload(&self) -> Result<()> {
|
||||
let server_config = self.tls_option.setup()?;
|
||||
*self.config.write().unwrap() = server_config.map(Arc::new);
|
||||
self.version.fetch_add(1, Ordering::Relaxed);
|
||||
Ok(())
|
||||
fn watch_paths(&self) -> Vec<&Path> {
|
||||
vec![self.cert_path(), self.key_path()]
|
||||
}
|
||||
|
||||
/// Get the server config hold by this container
|
||||
pub fn get_server_config(&self) -> Option<Arc<ServerConfig>> {
|
||||
self.config.read().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Get associated `TlsOption`
|
||||
pub fn get_tls_option(&self) -> &TlsOption {
|
||||
&self.tls_option
|
||||
}
|
||||
|
||||
/// Get version of current config
|
||||
///
|
||||
/// this version will auto increase when server config get reloaded.
|
||||
pub fn get_version(&self) -> usize {
|
||||
self.version.load(Ordering::Relaxed)
|
||||
fn watch_enabled(&self) -> bool {
|
||||
self.mode != TlsMode::Disable && self.watch
|
||||
}
|
||||
}
|
||||
|
||||
pub fn maybe_watch_tls_config(tls_server_config: Arc<ReloadableTlsServerConfig>) -> Result<()> {
|
||||
if !tls_server_config.get_tls_option().watch_enabled() {
|
||||
return Ok(());
|
||||
}
|
||||
/// Type alias for server-side reloadable TLS config
|
||||
pub type ReloadableTlsServerConfig = ReloadableTlsConfig<Arc<ServerConfig>, TlsOption>;
|
||||
|
||||
let tls_server_config_for_watcher = tls_server_config.clone();
|
||||
|
||||
let (tx, rx) = channel::<notify::Result<notify::Event>>();
|
||||
let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
|
||||
|
||||
let cert_path = tls_server_config.get_tls_option().cert_path();
|
||||
watcher
|
||||
.watch(cert_path, RecursiveMode::NonRecursive)
|
||||
.with_context(|_| FileWatchSnafu {
|
||||
path: cert_path.display().to_string(),
|
||||
})?;
|
||||
|
||||
let key_path = tls_server_config.get_tls_option().key_path();
|
||||
watcher
|
||||
.watch(key_path, RecursiveMode::NonRecursive)
|
||||
.with_context(|_| FileWatchSnafu {
|
||||
path: key_path.display().to_string(),
|
||||
})?;
|
||||
|
||||
std::thread::spawn(move || {
|
||||
let _watcher = watcher;
|
||||
while let Ok(res) = rx.recv() {
|
||||
if let Ok(event) = res {
|
||||
match event.kind {
|
||||
EventKind::Modify(_) | EventKind::Create(_) => {
|
||||
info!("Detected TLS cert/key file change: {:?}", event);
|
||||
if let Err(err) = tls_server_config_for_watcher.reload() {
|
||||
error!(err; "Failed to reload TLS server config");
|
||||
} else {
|
||||
info!("Reloaded TLS cert/key file successfully.");
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
/// Convenience function for watching server TLS configuration
|
||||
pub fn maybe_watch_server_tls_config(
|
||||
tls_server_config: Arc<ReloadableTlsServerConfig>,
|
||||
) -> Result<()> {
|
||||
common_grpc::reloadable_tls::maybe_watch_tls_config(tls_server_config, || {}).map_err(|e| {
|
||||
crate::error::Error::Internal {
|
||||
err_msg: format!("Failed to watch TLS config: {}", e),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -434,10 +370,11 @@ mod tests {
|
||||
let server_config = Arc::new(
|
||||
ReloadableTlsServerConfig::try_new(server_tls).expect("failed to create server config"),
|
||||
);
|
||||
maybe_watch_tls_config(server_config.clone()).expect("failed to watch server config");
|
||||
maybe_watch_server_tls_config(server_config.clone())
|
||||
.expect("failed to watch server config");
|
||||
|
||||
assert_eq!(0, server_config.get_version());
|
||||
assert!(server_config.get_server_config().is_some());
|
||||
assert!(server_config.get_config().is_some());
|
||||
|
||||
let tmp_file = key_path.with_extension("tmp");
|
||||
std::fs::copy("tests/ssl/server-pkcs8.key", &tmp_file)
|
||||
@@ -459,6 +396,6 @@ mod tests {
|
||||
|
||||
assert!(version_updated, "TLS config did not reload in time");
|
||||
assert!(server_config.get_version() > 0);
|
||||
assert!(server_config.get_server_config().is_some());
|
||||
assert!(server_config.get_config().is_some());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -971,6 +971,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) {
|
||||
server_ca_cert_path: Some(ca_path),
|
||||
client_cert_path: Some(client_cert_path),
|
||||
client_key_path: Some(client_key_path),
|
||||
watch: false,
|
||||
};
|
||||
{
|
||||
let grpc_client =
|
||||
|
||||
Reference in New Issue
Block a user