diff --git a/src/cmd/src/frontend.rs b/src/cmd/src/frontend.rs index 172b86d0a8..b939f42c8a 100644 --- a/src/cmd/src/frontend.rs +++ b/src/cmd/src/frontend.rs @@ -52,7 +52,7 @@ use plugins::frontend::context::{ }; use servers::addrs; use servers::grpc::GrpcOptions; -use servers::tls::{TlsMode, TlsOption}; +use servers::tls::{TlsMode, TlsOption, merge_tls_option}; use snafu::{OptionExt, ResultExt}; use tracing_appender::non_blocking::WorkerGuard; @@ -256,7 +256,7 @@ impl StartCommand { if let Some(addr) = &self.rpc_bind_addr { opts.grpc.bind_addr.clone_from(addr); - opts.grpc.tls = tls_opts.clone(); + opts.grpc.tls = merge_tls_option(&opts.grpc.tls, tls_opts.clone()); } if let Some(addr) = &self.rpc_server_addr { @@ -291,13 +291,13 @@ impl StartCommand { if let Some(addr) = &self.mysql_addr { opts.mysql.enable = true; opts.mysql.addr.clone_from(addr); - opts.mysql.tls = tls_opts.clone(); + opts.mysql.tls = merge_tls_option(&opts.mysql.tls, tls_opts.clone()); } if let Some(addr) = &self.postgres_addr { opts.postgres.enable = true; opts.postgres.addr.clone_from(addr); - opts.postgres.tls = tls_opts; + opts.postgres.tls = merge_tls_option(&opts.postgres.tls, tls_opts.clone()); } if let Some(enable) = self.influxdb_enable { @@ -340,7 +340,7 @@ impl StartCommand { let mut opts = opts.component; opts.grpc.detect_server_addr(); let mut plugins = Plugins::new(); - plugins::setup_frontend_plugins(&mut plugins, &plugin_opts, &opts) + plugins::setup_frontend_plugins(&mut plugins, &plugin_opts, &mut opts) .await .context(error::StartFrontendSnafu)?; @@ -591,7 +591,7 @@ mod tests { #[tokio::test] async fn test_try_from_start_command_to_anymap() { - let fe_opts = frontend::frontend::FrontendOptions { + let mut fe_opts = frontend::frontend::FrontendOptions { http: HttpOptions { disable_dashboard: false, ..Default::default() @@ -601,7 +601,7 @@ mod tests { }; let mut plugins = Plugins::new(); - plugins::setup_frontend_plugins(&mut plugins, &[], &fe_opts) + plugins::setup_frontend_plugins(&mut plugins, &[], &mut fe_opts) .await .unwrap(); diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs index a47fd13cec..f10da156a7 100644 --- a/src/cmd/src/standalone.rs +++ b/src/cmd/src/standalone.rs @@ -62,7 +62,7 @@ use plugins::frontend::context::{ CatalogManagerConfigureContext, StandaloneCatalogManagerConfigureContext, }; use plugins::standalone::context::DdlManagerConfigureContext; -use servers::tls::{TlsMode, TlsOption}; +use servers::tls::{TlsMode, TlsOption, merge_tls_option}; use snafu::ResultExt; use standalone::StandaloneInformationExtension; use standalone::options::StandaloneOptions; @@ -294,19 +294,19 @@ impl StartCommand { }.fail(); } opts.grpc.bind_addr.clone_from(addr); - opts.grpc.tls = tls_opts.clone(); + opts.grpc.tls = merge_tls_option(&opts.grpc.tls, tls_opts.clone()); } if let Some(addr) = &self.mysql_addr { opts.mysql.enable = true; opts.mysql.addr.clone_from(addr); - opts.mysql.tls = tls_opts.clone(); + opts.mysql.tls = merge_tls_option(&opts.mysql.tls, tls_opts.clone()); } if let Some(addr) = &self.postgres_addr { opts.postgres.enable = true; opts.postgres.addr.clone_from(addr); - opts.postgres.tls = tls_opts; + opts.postgres.tls = merge_tls_option(&opts.postgres.tls, tls_opts.clone()); } if self.influxdb_enable { @@ -350,10 +350,10 @@ impl StartCommand { .context(error::BuildCliSnafu)?; opts.grpc.detect_server_addr(); - let fe_opts = opts.frontend_options(); + let mut fe_opts = opts.frontend_options(); let dn_opts = opts.datanode_options(); - plugins::setup_frontend_plugins(&mut plugins, &plugin_opts, &fe_opts) + plugins::setup_frontend_plugins(&mut plugins, &plugin_opts, &mut fe_opts) .await .context(error::StartFrontendSnafu)?; @@ -625,13 +625,13 @@ mod tests { #[tokio::test] async fn test_try_from_start_command_to_anymap() { - let fe_opts = FrontendOptions { + let mut fe_opts = FrontendOptions { user_provider: Some("static_user_provider:cmd:test=test".to_string()), ..Default::default() }; let mut plugins = Plugins::new(); - plugins::setup_frontend_plugins(&mut plugins, &[], &fe_opts) + plugins::setup_frontend_plugins(&mut plugins, &[], &mut fe_opts) .await .unwrap(); diff --git a/src/plugins/src/frontend.rs b/src/plugins/src/frontend.rs index 0d1c1af7b9..2bb97efbfd 100644 --- a/src/plugins/src/frontend.rs +++ b/src/plugins/src/frontend.rs @@ -24,7 +24,7 @@ use crate::options::PluginOptions; pub async fn setup_frontend_plugins( plugins: &mut Plugins, _plugin_options: &[PluginOptions], - fe_opts: &FrontendOptions, + fe_opts: &mut FrontendOptions, ) -> Result<()> { if let Some(user_provider) = fe_opts.user_provider.as_ref() { let provider = diff --git a/src/servers/src/tls.rs b/src/servers/src/tls.rs index f4afc477b7..ba4025ab74 100644 --- a/src/servers/src/tls.rs +++ b/src/servers/src/tls.rs @@ -91,6 +91,47 @@ impl TlsOption { tls_option } + /// Validates the TLS configuration. + /// + /// Returns an error if: + /// - TLS mode is enabled (not `Disable`) but `cert_path` or `key_path` is empty + /// - TLS mode is `VerifyCa` or `VerifyFull` but `ca_cert_path` is empty + pub fn validate(&self) -> Result<()> { + if self.mode == TlsMode::Disable { + return Ok(()); + } + + // When TLS is enabled, cert_path and key_path are required + if self.cert_path.is_empty() { + return Err(crate::error::Error::Internal { + err_msg: format!( + "TLS mode is {:?} but cert_path is not configured", + self.mode + ), + }); + } + + if self.key_path.is_empty() { + return Err(crate::error::Error::Internal { + err_msg: format!("TLS mode is {:?} but key_path is not configured", self.mode), + }); + } + + // For VerifyCa and VerifyFull modes, ca_cert_path is required for client verification + if matches!(self.mode, TlsMode::VerifyCa | TlsMode::VerifyFull) + && self.ca_cert_path.is_empty() + { + return Err(crate::error::Error::Internal { + err_msg: format!( + "TLS mode is {:?} but ca_cert_path is not configured", + self.mode + ), + }); + } + + Ok(()) + } + pub fn setup(&self) -> Result> { if let TlsMode::Disable = self.mode { return Ok(None); @@ -147,6 +188,13 @@ impl TlsOption { } } +pub fn merge_tls_option(main: &TlsOption, other: TlsOption) -> TlsOption { + if other.mode != TlsMode::Disable && other.validate().is_ok() { + return other; + } + main.clone() +} + impl TlsConfigLoader> for TlsOption { type Error = crate::error::Error; @@ -183,6 +231,118 @@ mod tests { use crate::install_ring_crypto_provider; use crate::tls::TlsMode::Disable; + #[test] + fn test_validate_disable_mode() { + let tls = TlsOption { + mode: TlsMode::Disable, + cert_path: String::new(), + key_path: String::new(), + ca_cert_path: String::new(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + + #[test] + fn test_validate_missing_cert_path() { + let tls = TlsOption { + mode: TlsMode::Require, + cert_path: String::new(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + let err = tls.validate().unwrap_err(); + assert!(err.to_string().contains("cert_path")); + } + + #[test] + fn test_validate_missing_key_path() { + let tls = TlsOption { + mode: TlsMode::Require, + cert_path: "/path/to/cert".to_string(), + key_path: String::new(), + ca_cert_path: String::new(), + watch: false, + }; + let err = tls.validate().unwrap_err(); + assert!(err.to_string().contains("key_path")); + } + + #[test] + fn test_validate_require_mode_success() { + let tls = TlsOption { + mode: TlsMode::Require, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + + #[test] + fn test_validate_verify_ca_missing_ca_cert() { + let tls = TlsOption { + mode: TlsMode::VerifyCa, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + let err = tls.validate().unwrap_err(); + assert!(err.to_string().contains("ca_cert_path")); + } + + #[test] + fn test_validate_verify_full_missing_ca_cert() { + let tls = TlsOption { + mode: TlsMode::VerifyFull, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + let err = tls.validate().unwrap_err(); + assert!(err.to_string().contains("ca_cert_path")); + } + + #[test] + fn test_validate_verify_ca_success() { + let tls = TlsOption { + mode: TlsMode::VerifyCa, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: "/path/to/ca".to_string(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + + #[test] + fn test_validate_verify_full_success() { + let tls = TlsOption { + mode: TlsMode::VerifyFull, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: "/path/to/ca".to_string(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + + #[test] + fn test_validate_prefer_mode() { + let tls = TlsOption { + mode: TlsMode::Prefer, + cert_path: "/path/to/cert".to_string(), + key_path: "/path/to/key".to_string(), + ca_cert_path: String::new(), + watch: false, + }; + assert!(tls.validate().is_ok()); + } + #[test] fn test_new_tls_option() { assert_eq!(