From 64e74916b91e4385fc27e36407b24d6198cc4b7a Mon Sep 17 00:00:00 2001 From: shuiyisong <113876041+shuiyisong@users.noreply.github.com> Date: Mon, 15 Dec 2025 10:53:21 +0800 Subject: [PATCH] fix: TLS option validate and merge (#7401) * chore: unify gRPC server tls behaviour Signed-off-by: shuiyisong * fix: test Signed-off-by: shuiyisong * chore: add validate and merge tls Signed-off-by: shuiyisong * chore: remove mut in func sig and add back test Signed-off-by: shuiyisong * fix: test Signed-off-by: shuiyisong --------- Signed-off-by: shuiyisong --- src/cmd/src/frontend.rs | 8 +- src/cmd/src/standalone.rs | 12 +-- src/servers/src/grpc/builder.rs | 8 +- src/servers/src/tls.rs | 160 ++++++++++++++++++++++++++++++++ tests-integration/tests/grpc.rs | 3 +- 5 files changed, 174 insertions(+), 17 deletions(-) diff --git a/src/cmd/src/frontend.rs b/src/cmd/src/frontend.rs index 172b86d0a8..fa36a99ed4 100644 --- a/src/cmd/src/frontend.rs +++ b/src/cmd/src/frontend.rs @@ -52,7 +52,7 @@ use plugins::frontend::context::{ }; use servers::addrs; use servers::grpc::GrpcOptions; -use servers::tls::{TlsMode, TlsOption}; +use servers::tls::{TlsMode, TlsOption, merge_tls_option}; use snafu::{OptionExt, ResultExt}; use tracing_appender::non_blocking::WorkerGuard; @@ -256,7 +256,7 @@ impl StartCommand { if let Some(addr) = &self.rpc_bind_addr { opts.grpc.bind_addr.clone_from(addr); - opts.grpc.tls = tls_opts.clone(); + opts.grpc.tls = merge_tls_option(&opts.grpc.tls, tls_opts.clone()); } if let Some(addr) = &self.rpc_server_addr { @@ -291,13 +291,13 @@ impl StartCommand { if let Some(addr) = &self.mysql_addr { opts.mysql.enable = true; opts.mysql.addr.clone_from(addr); - opts.mysql.tls = tls_opts.clone(); + opts.mysql.tls = merge_tls_option(&opts.mysql.tls, tls_opts.clone()); } if let Some(addr) = &self.postgres_addr { opts.postgres.enable = true; opts.postgres.addr.clone_from(addr); - opts.postgres.tls = tls_opts; + opts.postgres.tls = merge_tls_option(&opts.postgres.tls, tls_opts.clone()); } if let Some(enable) = self.influxdb_enable { diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs index 1ef16a830f..012680ac08 100644 --- a/src/cmd/src/standalone.rs +++ b/src/cmd/src/standalone.rs @@ -62,7 +62,7 @@ use plugins::frontend::context::{ CatalogManagerConfigureContext, StandaloneCatalogManagerConfigureContext, }; use plugins::standalone::context::DdlManagerConfigureContext; -use servers::tls::{TlsMode, TlsOption}; +use servers::tls::{TlsMode, TlsOption, merge_tls_option}; use snafu::ResultExt; use standalone::StandaloneInformationExtension; use standalone::options::StandaloneOptions; @@ -293,19 +293,20 @@ impl StartCommand { ), }.fail(); } - opts.grpc.bind_addr.clone_from(addr) + opts.grpc.bind_addr.clone_from(addr); + opts.grpc.tls = merge_tls_option(&opts.grpc.tls, tls_opts.clone()); } if let Some(addr) = &self.mysql_addr { opts.mysql.enable = true; opts.mysql.addr.clone_from(addr); - opts.mysql.tls = tls_opts.clone(); + opts.mysql.tls = merge_tls_option(&opts.mysql.tls, tls_opts.clone()); } if let Some(addr) = &self.postgres_addr { opts.postgres.enable = true; opts.postgres.addr.clone_from(addr); - opts.postgres.tls = tls_opts; + opts.postgres.tls = merge_tls_option(&opts.postgres.tls, tls_opts.clone()); } if self.influxdb_enable { @@ -765,7 +766,6 @@ mod tests { user_provider: Some("static_user_provider:cmd:test=test".to_string()), mysql_addr: Some("127.0.0.1:4002".to_string()), postgres_addr: Some("127.0.0.1:4003".to_string()), - tls_watch: true, ..Default::default() }; @@ -782,8 +782,6 @@ mod tests { assert_eq!("./greptimedb_data/test/logs", opts.logging.dir); assert_eq!("debug", opts.logging.level.unwrap()); - assert!(opts.mysql.tls.watch); - assert!(opts.postgres.tls.watch); } #[test] diff --git a/src/servers/src/grpc/builder.rs b/src/servers/src/grpc/builder.rs index e90ebe1fb5..129f07c3c5 100644 --- a/src/servers/src/grpc/builder.rs +++ b/src/servers/src/grpc/builder.rs @@ -23,8 +23,9 @@ use auth::UserProviderRef; use axum::extract::Request; use axum::response::IntoResponse; use axum::routing::Route; -use common_grpc::error::{Error, InvalidConfigFilePathSnafu, Result}; +use common_grpc::error::{InvalidConfigFilePathSnafu, Result}; use common_runtime::Runtime; +use common_telemetry::warn; use otel_arrow_rust::proto::opentelemetry::arrow::v1::arrow_metrics_service_server::ArrowMetricsServiceServer; use snafu::ResultExt; use tokio::sync::Mutex; @@ -195,10 +196,7 @@ impl GrpcServerBuilder { // tonic does not support watching for tls config changes // so we don't support it either for now if tls_option.watch { - return Err(Error::NotSupported { - feat: "Certificates watch and reloading for gRPC is not supported at the moment" - .to_string(), - }); + warn!("Certificates watch and reloading for gRPC is NOT supported at the moment"); } self.tls_config = if tls_option.should_force_tls() { let cert = std::fs::read_to_string(tls_option.cert_path) diff --git a/src/servers/src/tls.rs b/src/servers/src/tls.rs index f4afc477b7..ba4025ab74 100644 --- a/src/servers/src/tls.rs +++ b/src/servers/src/tls.rs @@ -91,6 +91,47 @@ impl TlsOption { tls_option } + /// Validates the TLS configuration. + /// + /// Returns an error if: + /// - TLS mode is enabled (not `Disable`) but `cert_path` or `key_path` is empty + /// - TLS mode is `VerifyCa` or `VerifyFull` but `ca_cert_path` is empty + pub fn validate(&self) -> Result<()> { + if self.mode == TlsMode::Disable { + return Ok(()); + } + + // When TLS is enabled, cert_path and key_path are required + if self.cert_path.is_empty() { + return Err(crate::error::Error::Internal { + err_msg: format!( + "TLS mode is {:?} but cert_path is not configured", + self.mode + ), + }); + } + + if self.key_path.is_empty() { + return Err(crate::error::Error::Internal { + err_msg: format!("TLS mode is {:?} but key_path is not configured", self.mode), + }); + } + + // For VerifyCa and VerifyFull modes, ca_cert_path is required for client verification + if matches!(self.mode, TlsMode::VerifyCa | TlsMode::VerifyFull) + && self.ca_cert_path.is_empty() + { + return Err(crate::error::Error::Internal { + err_msg: format!( + "TLS mode is {:?} but ca_cert_path is not configured", + self.mode + ), + }); + } + + Ok(()) + } + pub fn setup(&self) -> Result> { if let TlsMode::Disable = self.mode { return Ok(None); @@ -147,6 +188,13 @@ impl TlsOption { } } +pub fn merge_tls_option(main: &TlsOption, other: TlsOption) -> TlsOption { + if other.mode != TlsMode::Disable && other.validate().is_ok() { + return other; + } + main.clone() +} + impl TlsConfigLoader> for TlsOption { type Error = crate::error::Error; @@ -183,6 +231,118 @@ mod tests { use crate::install_ring_crypto_provider; use crate::tls::TlsMode::Disable; + #[test] + fn test_validate_disable_mode() { + let tls = TlsOption { + mode: TlsMode::Disable, + cert_path: String::new(), + key_path: String::new(), + ca_cert_path: String::new(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + + #[test] + fn test_validate_missing_cert_path() { + let tls = TlsOption { + mode: TlsMode::Require, + cert_path: String::new(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + let err = tls.validate().unwrap_err(); + assert!(err.to_string().contains("cert_path")); + } + + #[test] + fn test_validate_missing_key_path() { + let tls = TlsOption { + mode: TlsMode::Require, + cert_path: "/path/to/cert".to_string(), + key_path: String::new(), + ca_cert_path: String::new(), + watch: false, + }; + let err = tls.validate().unwrap_err(); + assert!(err.to_string().contains("key_path")); + } + + #[test] + fn test_validate_require_mode_success() { + let tls = TlsOption { + mode: TlsMode::Require, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + + #[test] + fn test_validate_verify_ca_missing_ca_cert() { + let tls = TlsOption { + mode: TlsMode::VerifyCa, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + let err = tls.validate().unwrap_err(); + assert!(err.to_string().contains("ca_cert_path")); + } + + #[test] + fn test_validate_verify_full_missing_ca_cert() { + let tls = TlsOption { + mode: TlsMode::VerifyFull, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + let err = tls.validate().unwrap_err(); + assert!(err.to_string().contains("ca_cert_path")); + } + + #[test] + fn test_validate_verify_ca_success() { + let tls = TlsOption { + mode: TlsMode::VerifyCa, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: "/path/to/ca".to_string(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + + #[test] + fn test_validate_verify_full_success() { + let tls = TlsOption { + mode: TlsMode::VerifyFull, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: "/path/to/ca".to_string(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + + #[test] + fn test_validate_prefer_mode() { + let tls = TlsOption { + mode: TlsMode::Prefer, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + #[test] fn test_new_tls_option() { assert_eq!( diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index 52cf6529b4..447d5afe50 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -1010,7 +1010,8 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { let runtime = Runtime::builder().build().unwrap(); let grpc_builder = GrpcServerBuilder::new(config.clone(), runtime).with_tls_config(config.tls); - assert!(grpc_builder.is_err()); + // ok but print warning + assert!(grpc_builder.is_ok()); } let _ = fe_grpc_server.shutdown().await;