diff --git a/Cargo.lock b/Cargo.lock index eeb75e0f07..cb732d1ab0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -687,8 +687,10 @@ dependencies = [ "async-trait", "common-error", "common-macro", + "common-telemetry", "common-test-util", "digest", + "notify", "secrecy", "sha1", "snafu", diff --git a/Cargo.toml b/Cargo.toml index cebad1ef89..54354000b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -110,6 +110,7 @@ lazy_static = "1.4" meter-core = { git = "https://github.com/GreptimeTeam/greptime-meter.git", rev = "80b72716dcde47ec4161478416a5c6c21343364d" } mockall = "0.11.4" moka = "0.12" +notify = "6.1" num_cpus = "1.16" once_cell = "1.18" opentelemetry-proto = { git = "https://github.com/waynexia/opentelemetry-rust.git", rev = "33841b38dda79b15f2024952be5f32533325ca02", features = [ diff --git a/src/auth/Cargo.toml b/src/auth/Cargo.toml index a3e9f199a1..c10a38e86f 100644 --- a/src/auth/Cargo.toml +++ b/src/auth/Cargo.toml @@ -16,7 +16,9 @@ api.workspace = true async-trait.workspace = true common-error.workspace = true common-macro.workspace = true +common-telemetry.workspace = true digest = "0.10" +notify.workspace = true secrecy = { version = "0.8", features = ["serde", "alloc"] } sha1 = "0.10" snafu.workspace = true diff --git a/src/auth/src/common.rs b/src/auth/src/common.rs index 109a98175d..d8b70cea68 100644 --- a/src/auth/src/common.rs +++ b/src/auth/src/common.rs @@ -22,6 +22,9 @@ use snafu::{ensure, OptionExt}; use crate::error::{IllegalParamSnafu, InvalidConfigSnafu, Result, UserPasswordMismatchSnafu}; use crate::user_info::DefaultUserInfo; use crate::user_provider::static_user_provider::{StaticUserProvider, STATIC_USER_PROVIDER}; +use crate::user_provider::watch_file_user_provider::{ + WatchFileUserProvider, WATCH_FILE_USER_PROVIDER, +}; use crate::{UserInfoRef, UserProviderRef}; pub(crate) const DEFAULT_USERNAME: &str = "greptime"; @@ -43,6 +46,9 @@ pub fn user_provider_from_option(opt: &String) -> Result { StaticUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)?; Ok(provider) } + WATCH_FILE_USER_PROVIDER => { + WatchFileUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef) + } _ => InvalidConfigSnafu { value: name.to_string(), msg: "Invalid UserProviderOption", diff --git a/src/auth/src/error.rs b/src/auth/src/error.rs index 529d711659..bb9f37e83b 100644 --- a/src/auth/src/error.rs +++ b/src/auth/src/error.rs @@ -64,6 +64,13 @@ pub enum Error { username: String, }, + #[snafu(display("Failed to initialize a watcher for file {}", path))] + FileWatch { + path: String, + #[snafu(source)] + error: notify::Error, + }, + #[snafu(display("User is not authorized to perform this action"))] PermissionDenied { location: Location }, } @@ -73,6 +80,7 @@ impl ErrorExt for Error { match self { Error::InvalidConfig { .. } => StatusCode::InvalidArguments, Error::IllegalParam { .. } => StatusCode::InvalidArguments, + Error::FileWatch { .. } => StatusCode::InvalidArguments, Error::InternalState { .. } => StatusCode::Unexpected, Error::Io { .. } => StatusCode::Internal, Error::AuthBackend { .. } => StatusCode::Internal, diff --git a/src/auth/src/user_provider.rs b/src/auth/src/user_provider.rs index 1acf499a8d..4fab604e62 100644 --- a/src/auth/src/user_provider.rs +++ b/src/auth/src/user_provider.rs @@ -13,10 +13,24 @@ // limitations under the License. pub(crate) mod static_user_provider; +pub(crate) mod watch_file_user_provider; + +use std::collections::HashMap; +use std::fs::File; +use std::io; +use std::io::BufRead; +use std::path::Path; + +use secrecy::ExposeSecret; +use snafu::{ensure, OptionExt, ResultExt}; use crate::common::{Identity, Password}; -use crate::error::Result; -use crate::UserInfoRef; +use crate::error::{ + IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Result, UnsupportedPasswordTypeSnafu, + UserNotFoundSnafu, UserPasswordMismatchSnafu, +}; +use crate::user_info::DefaultUserInfo; +use crate::{auth_mysql, UserInfoRef}; #[async_trait::async_trait] pub trait UserProvider: Send + Sync { @@ -44,3 +58,88 @@ pub trait UserProvider: Send + Sync { Ok(user_info) } } + +fn load_credential_from_file(filepath: &str) -> Result>>> { + // check valid path + let path = Path::new(filepath); + if !path.exists() { + return Ok(None); + } + + ensure!( + path.is_file(), + InvalidConfigSnafu { + value: filepath, + msg: "UserProvider file must be a file", + } + ); + let file = File::open(path).context(IoSnafu)?; + let credential = io::BufReader::new(file) + .lines() + .map_while(std::result::Result::ok) + .filter_map(|line| { + if let Some((k, v)) = line.split_once('=') { + Some((k.to_string(), v.as_bytes().to_vec())) + } else { + None + } + }) + .collect::>>(); + + ensure!( + !credential.is_empty(), + InvalidConfigSnafu { + value: filepath, + msg: "UserProvider's file must contains at least one valid credential", + } + ); + + Ok(Some(credential)) +} + +fn authenticate_with_credential( + users: &HashMap>, + input_id: Identity<'_>, + input_pwd: Password<'_>, +) -> Result { + match input_id { + Identity::UserId(username, _) => { + ensure!( + !username.is_empty(), + IllegalParamSnafu { + msg: "blank username" + } + ); + let save_pwd = users.get(username).context(UserNotFoundSnafu { + username: username.to_string(), + })?; + + match input_pwd { + Password::PlainText(pwd) => { + ensure!( + !pwd.expose_secret().is_empty(), + IllegalParamSnafu { + msg: "blank password" + } + ); + if save_pwd == pwd.expose_secret().as_bytes() { + Ok(DefaultUserInfo::with_name(username)) + } else { + UserPasswordMismatchSnafu { + username: username.to_string(), + } + .fail() + } + } + Password::MysqlNativePassword(auth_data, salt) => { + auth_mysql(auth_data, salt, username, save_pwd) + .map(|_| DefaultUserInfo::with_name(username)) + } + Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu { + password_type: "pg_md5", + } + .fail(), + } + } + } +} diff --git a/src/auth/src/user_provider/static_user_provider.rs b/src/auth/src/user_provider/static_user_provider.rs index e6d4743894..9e05671219 100644 --- a/src/auth/src/user_provider/static_user_provider.rs +++ b/src/auth/src/user_provider/static_user_provider.rs @@ -13,21 +13,13 @@ // limitations under the License. use std::collections::HashMap; -use std::fs::File; -use std::io; -use std::io::BufRead; -use std::path::Path; use async_trait::async_trait; -use secrecy::ExposeSecret; -use snafu::{ensure, OptionExt, ResultExt}; +use snafu::OptionExt; -use crate::error::{ - IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Result, UnsupportedPasswordTypeSnafu, - UserNotFoundSnafu, UserPasswordMismatchSnafu, -}; -use crate::user_info::DefaultUserInfo; -use crate::{auth_mysql, Identity, Password, UserInfoRef, UserProvider}; +use crate::error::{InvalidConfigSnafu, Result}; +use crate::user_provider::{authenticate_with_credential, load_credential_from_file}; +use crate::{Identity, Password, UserInfoRef, UserProvider}; pub(crate) const STATIC_USER_PROVIDER: &str = "static_user_provider"; @@ -43,32 +35,12 @@ impl StaticUserProvider { })?; return match mode { "file" => { - // check valid path - let path = Path::new(content); - ensure!(path.exists() && path.is_file(), InvalidConfigSnafu { - value: content.to_string(), - msg: "StaticUserProviderOption file must be a valid file path", - }); - - let file = File::open(path).context(IoSnafu)?; - let credential = io::BufReader::new(file) - .lines() - .map_while(std::result::Result::ok) - .filter_map(|line| { - if let Some((k, v)) = line.split_once('=') { - Some((k.to_string(), v.as_bytes().to_vec())) - } else { - None - } - }) - .collect::>>(); - - ensure!(!credential.is_empty(), InvalidConfigSnafu { - value: content.to_string(), - msg: "StaticUserProviderOption file must contains at least one valid credential", - }); - - Ok(StaticUserProvider { users: credential, }) + let users = load_credential_from_file(content)? + .context(InvalidConfigSnafu { + value: content.to_string(), + msg: "StaticFileUserProvider must be a valid file path", + })?; + Ok(StaticUserProvider { users }) } "cmd" => content .split(',') @@ -96,51 +68,8 @@ impl UserProvider for StaticUserProvider { STATIC_USER_PROVIDER } - async fn authenticate( - &self, - input_id: Identity<'_>, - input_pwd: Password<'_>, - ) -> Result { - match input_id { - Identity::UserId(username, _) => { - ensure!( - !username.is_empty(), - IllegalParamSnafu { - msg: "blank username" - } - ); - let save_pwd = self.users.get(username).context(UserNotFoundSnafu { - username: username.to_string(), - })?; - - match input_pwd { - Password::PlainText(pwd) => { - ensure!( - !pwd.expose_secret().is_empty(), - IllegalParamSnafu { - msg: "blank password" - } - ); - return if save_pwd == pwd.expose_secret().as_bytes() { - Ok(DefaultUserInfo::with_name(username)) - } else { - UserPasswordMismatchSnafu { - username: username.to_string(), - } - .fail() - }; - } - Password::MysqlNativePassword(auth_data, salt) => { - auth_mysql(auth_data, salt, username, save_pwd) - .map(|_| DefaultUserInfo::with_name(username)) - } - Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu { - password_type: "pg_md5", - } - .fail(), - } - } - } + async fn authenticate(&self, id: Identity<'_>, pwd: Password<'_>) -> Result { + authenticate_with_credential(&self.users, id, pwd) } async fn authorize( diff --git a/src/auth/src/user_provider/watch_file_user_provider.rs b/src/auth/src/user_provider/watch_file_user_provider.rs new file mode 100644 index 0000000000..4a654f2f31 --- /dev/null +++ b/src/auth/src/user_provider/watch_file_user_provider.rs @@ -0,0 +1,215 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::path::Path; +use std::sync::mpsc::channel; +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use common_telemetry::{info, warn}; +use notify::{EventKind, RecursiveMode, Watcher}; +use snafu::{ensure, ResultExt}; + +use crate::error::{FileWatchSnafu, InvalidConfigSnafu, Result}; +use crate::user_info::DefaultUserInfo; +use crate::user_provider::{authenticate_with_credential, load_credential_from_file}; +use crate::{Identity, Password, UserInfoRef, UserProvider}; + +pub(crate) const WATCH_FILE_USER_PROVIDER: &str = "watch_file_user_provider"; + +type WatchedCredentialRef = Arc>>>>; + +/// A user provider that reads user credential from a file and watches the file for changes. +/// +/// Empty file is invalid; but file not exist means every user can be authenticated. +pub(crate) struct WatchFileUserProvider { + users: WatchedCredentialRef, +} + +impl WatchFileUserProvider { + pub fn new(filepath: &str) -> Result { + let credential = load_credential_from_file(filepath)?; + let users = Arc::new(Mutex::new(credential)); + let this = WatchFileUserProvider { + users: users.clone(), + }; + + let (tx, rx) = channel::>(); + let mut debouncer = + notify::recommended_watcher(tx).context(FileWatchSnafu { path: "" })?; + let mut dir = Path::new(filepath).to_path_buf(); + ensure!( + dir.pop(), + InvalidConfigSnafu { + value: filepath, + msg: "UserProvider path must be a file path", + } + ); + debouncer + .watch(&dir, RecursiveMode::NonRecursive) + .context(FileWatchSnafu { path: filepath })?; + + let filepath = filepath.to_string(); + std::thread::spawn(move || { + let filename = Path::new(&filepath).file_name(); + let _hold = debouncer; + while let Ok(res) = rx.recv() { + if let Ok(event) = res { + let is_this_file = event.paths.iter().any(|p| p.file_name() == filename); + let is_relevant_event = matches!( + event.kind, + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) + ); + if is_this_file && is_relevant_event { + info!(?event.kind, "User provider file {} changed", &filepath); + match load_credential_from_file(&filepath) { + Ok(credential) => { + let mut users = + users.lock().expect("users credential must be valid"); + #[cfg(not(test))] + info!("User provider file {filepath} reloaded"); + #[cfg(test)] + info!("User provider file {filepath} reloaded: {credential:?}"); + *users = credential; + } + Err(err) => { + warn!( + ?err, + "Fail to load credential from file {filepath}; keep the old one", + ) + } + } + } + } + } + }); + + Ok(this) + } +} + +#[async_trait] +impl UserProvider for WatchFileUserProvider { + fn name(&self) -> &str { + WATCH_FILE_USER_PROVIDER + } + + async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result { + let users = self.users.lock().expect("users credential must be valid"); + if let Some(users) = users.as_ref() { + authenticate_with_credential(users, id, password) + } else { + match id { + Identity::UserId(id, _) => { + warn!(id, "User provider file not exist, allow all users"); + Ok(DefaultUserInfo::with_name(id)) + } + } + } + } + + async fn authorize(&self, _: &str, _: &str, _: &UserInfoRef) -> Result<()> { + // default allow all + Ok(()) + } +} + +#[cfg(test)] +pub mod test { + use std::time::{Duration, Instant}; + + use common_test_util::temp_dir::create_temp_dir; + use tokio::time::sleep; + + use crate::user_provider::watch_file_user_provider::WatchFileUserProvider; + use crate::user_provider::{Identity, Password}; + use crate::UserProvider; + + async fn test_authenticate( + provider: &dyn UserProvider, + username: &str, + password: &str, + ok: bool, + timeout: Option, + ) { + if let Some(timeout) = timeout { + let deadline = Instant::now().checked_add(timeout).unwrap(); + loop { + let re = provider + .authenticate( + Identity::UserId(username, None), + Password::PlainText(password.to_string().into()), + ) + .await; + if re.is_ok() == ok { + break; + } else if Instant::now() < deadline { + sleep(Duration::from_millis(100)).await; + } else { + panic!("timeout (username: {username}, password: {password}, expected: {ok})"); + } + } + } else { + let re = provider + .authenticate( + Identity::UserId(username, None), + Password::PlainText(password.to_string().into()), + ) + .await; + assert_eq!( + re.is_ok(), + ok, + "username: {}, password: {}", + username, + password + ); + } + } + + #[tokio::test] + async fn test_file_provider() { + common_telemetry::init_default_ut_logging(); + + let dir = create_temp_dir("test_file_provider"); + let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap()); + + // write a tmp file + assert!(std::fs::write(&file_path, "root=123456\nadmin=654321\n").is_ok()); + let provider = WatchFileUserProvider::new(file_path.as_str()).unwrap(); + let timeout = Duration::from_secs(60); + + test_authenticate(&provider, "root", "123456", true, None).await; + test_authenticate(&provider, "admin", "654321", true, None).await; + test_authenticate(&provider, "root", "654321", false, None).await; + + // update the tmp file + assert!(std::fs::write(&file_path, "root=654321\n").is_ok()); + test_authenticate(&provider, "root", "123456", false, Some(timeout)).await; + test_authenticate(&provider, "root", "654321", true, Some(timeout)).await; + test_authenticate(&provider, "admin", "654321", false, Some(timeout)).await; + + // remove the tmp file + assert!(std::fs::remove_file(&file_path).is_ok()); + test_authenticate(&provider, "root", "123456", true, Some(timeout)).await; + test_authenticate(&provider, "root", "654321", true, Some(timeout)).await; + test_authenticate(&provider, "admin", "654321", true, Some(timeout)).await; + + // recreate the tmp file + assert!(std::fs::write(&file_path, "root=123456\n").is_ok()); + test_authenticate(&provider, "root", "123456", true, Some(timeout)).await; + test_authenticate(&provider, "root", "654321", false, Some(timeout)).await; + test_authenticate(&provider, "admin", "654321", false, Some(timeout)).await; + } +} diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 4a54a49b27..43dbc55703 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -56,7 +56,7 @@ influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", bran itertools.workspace = true lazy_static.workspace = true mime_guess = "2.0" -notify = "6.1" +notify.workspace = true object-pool = "0.5" once_cell.workspace = true openmetrics-parser = "0.4" diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 0546d2a262..2d47547e65 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -449,8 +449,9 @@ pub enum Error { ))] UnexpectedPhysicalTable { location: Location }, - #[snafu(display("Failed to initialize a watcher for file"))] + #[snafu(display("Failed to initialize a watcher for file {}", path))] FileWatch { + path: String, #[snafu(source)] error: notify::Error, }, diff --git a/src/servers/src/tls.rs b/src/servers/src/tls.rs index f36970a42b..2055081012 100644 --- a/src/servers/src/tls.rs +++ b/src/servers/src/tls.rs @@ -200,21 +200,21 @@ pub fn maybe_watch_tls_config(tls_server_config: Arc) let tls_server_config_for_watcher = tls_server_config.clone(); let (tx, rx) = channel::>(); - let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu)?; + let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "" })?; + let cert_path = tls_server_config.get_tls_option().cert_path(); watcher - .watch( - tls_server_config.get_tls_option().cert_path(), - RecursiveMode::NonRecursive, - ) - .context(FileWatchSnafu)?; + .watch(cert_path, RecursiveMode::NonRecursive) + .with_context(|_| FileWatchSnafu { + path: cert_path.display().to_string(), + })?; + let key_path = tls_server_config.get_tls_option().key_path(); watcher - .watch( - tls_server_config.get_tls_option().key_path(), - RecursiveMode::NonRecursive, - ) - .context(FileWatchSnafu)?; + .watch(key_path, RecursiveMode::NonRecursive) + .with_context(|_| FileWatchSnafu { + path: key_path.display().to_string(), + })?; std::thread::spawn(move || { let _watcher = watcher;