From 1e83ab8e2a4a56189876b2c8db27e4bbc9b17884 Mon Sep 17 00:00:00 2001 From: tison Date: Fri, 22 Mar 2024 19:24:04 +0800 Subject: [PATCH] impl Signed-off-by: tison --- Cargo.lock | 1 + src/auth/Cargo.toml | 1 + src/auth/src/common.rs | 6 ++ src/auth/src/error.rs | 7 ++ src/auth/src/user_provider.rs | 100 +++++++++++++++++- .../src/user_provider/static_user_provider.rs | 85 ++------------- .../user_provider/watch_file_user_provider.rs | 68 ++++++++++++ 7 files changed, 188 insertions(+), 80 deletions(-) create mode 100644 src/auth/src/user_provider/watch_file_user_provider.rs diff --git a/Cargo.lock b/Cargo.lock index dc63049a70..35bcabe50c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -701,6 +701,7 @@ dependencies = [ "async-trait", "common-error", "common-macro", + "common-telemetry", "common-test-util", "digest", "hex", diff --git a/src/auth/Cargo.toml b/src/auth/Cargo.toml index a8f9b631ef..db8c200545 100644 --- a/src/auth/Cargo.toml +++ b/src/auth/Cargo.toml @@ -16,6 +16,7 @@ api.workspace = true async-trait.workspace = true common-error.workspace = true common-macro.workspace = true +common-telemetry.workspace = true digest = "0.10" hex = { version = "0.4" } notify.workspace = true diff --git a/src/auth/src/common.rs b/src/auth/src/common.rs index 109a98175d..d8b70cea68 100644 --- a/src/auth/src/common.rs +++ b/src/auth/src/common.rs @@ -22,6 +22,9 @@ use snafu::{ensure, OptionExt}; use crate::error::{IllegalParamSnafu, InvalidConfigSnafu, Result, UserPasswordMismatchSnafu}; use crate::user_info::DefaultUserInfo; use crate::user_provider::static_user_provider::{StaticUserProvider, STATIC_USER_PROVIDER}; +use crate::user_provider::watch_file_user_provider::{ + WatchFileUserProvider, WATCH_FILE_USER_PROVIDER, +}; use crate::{UserInfoRef, UserProviderRef}; pub(crate) const DEFAULT_USERNAME: &str = "greptime"; @@ -43,6 +46,9 @@ pub fn user_provider_from_option(opt: &String) -> Result { StaticUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)?; Ok(provider) } + WATCH_FILE_USER_PROVIDER => { + WatchFileUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef) + } _ => InvalidConfigSnafu { value: name.to_string(), msg: "Invalid UserProviderOption", diff --git a/src/auth/src/error.rs b/src/auth/src/error.rs index 529d711659..88351694e7 100644 --- a/src/auth/src/error.rs +++ b/src/auth/src/error.rs @@ -64,6 +64,12 @@ pub enum Error { username: String, }, + #[snafu(display("Failed to initialize a watcher for file"))] + FileWatch { + #[snafu(source)] + error: notify::Error, + }, + #[snafu(display("User is not authorized to perform this action"))] PermissionDenied { location: Location }, } @@ -73,6 +79,7 @@ impl ErrorExt for Error { match self { Error::InvalidConfig { .. } => StatusCode::InvalidArguments, Error::IllegalParam { .. } => StatusCode::InvalidArguments, + Error::FileWatch { .. } => StatusCode::InvalidArguments, Error::InternalState { .. } => StatusCode::Unexpected, Error::Io { .. } => StatusCode::Internal, Error::AuthBackend { .. } => StatusCode::Internal, diff --git a/src/auth/src/user_provider.rs b/src/auth/src/user_provider.rs index 1acf499a8d..ef95056194 100644 --- a/src/auth/src/user_provider.rs +++ b/src/auth/src/user_provider.rs @@ -13,10 +13,24 @@ // limitations under the License. pub(crate) mod static_user_provider; +pub(crate) mod watch_file_user_provider; + +use std::collections::HashMap; +use std::fs::File; +use std::io; +use std::io::BufRead; +use std::path::Path; + +use secrecy::ExposeSecret; +use snafu::{ensure, OptionExt, ResultExt}; use crate::common::{Identity, Password}; -use crate::error::Result; -use crate::UserInfoRef; +use crate::error::{ + IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Result, UnsupportedPasswordTypeSnafu, + UserNotFoundSnafu, UserPasswordMismatchSnafu, +}; +use crate::user_info::DefaultUserInfo; +use crate::{auth_mysql, UserInfoRef}; #[async_trait::async_trait] pub trait UserProvider: Send + Sync { @@ -44,3 +58,85 @@ pub trait UserProvider: Send + Sync { Ok(user_info) } } + +fn load_credential_from_file(filepath: &str) -> Result>> { + // check valid path + let path = Path::new(filepath); + ensure!( + path.exists() && path.is_file(), + InvalidConfigSnafu { + value: filepath.to_string(), + msg: "UserProvider file must be a valid file path", + } + ); + + let file = File::open(path).context(IoSnafu)?; + let credential = io::BufReader::new(file) + .lines() + .map_while(std::result::Result::ok) + .filter_map(|line| { + if let Some((k, v)) = line.split_once('=') { + Some((k.to_string(), v.as_bytes().to_vec())) + } else { + None + } + }) + .collect::>>(); + + ensure!( + !credential.is_empty(), + InvalidConfigSnafu { + value: filepath.to_string(), + msg: "UserProvider's file must contains at least one valid credential", + } + ); + + Ok(credential) +} + +fn authenticate_with_credential( + users: &HashMap>, + input_id: Identity<'_>, + input_pwd: Password<'_>, +) -> Result { + match input_id { + Identity::UserId(username, _) => { + ensure!( + !username.is_empty(), + IllegalParamSnafu { + msg: "blank username" + } + ); + let save_pwd = users.get(username).context(UserNotFoundSnafu { + username: username.to_string(), + })?; + + match input_pwd { + Password::PlainText(pwd) => { + ensure!( + !pwd.expose_secret().is_empty(), + IllegalParamSnafu { + msg: "blank password" + } + ); + return if save_pwd == pwd.expose_secret().as_bytes() { + Ok(DefaultUserInfo::with_name(username)) + } else { + UserPasswordMismatchSnafu { + username: username.to_string(), + } + .fail() + }; + } + Password::MysqlNativePassword(auth_data, salt) => { + auth_mysql(auth_data, salt, username, save_pwd) + .map(|_| DefaultUserInfo::with_name(username)) + } + Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu { + password_type: "pg_md5", + } + .fail(), + } + } + } +} diff --git a/src/auth/src/user_provider/static_user_provider.rs b/src/auth/src/user_provider/static_user_provider.rs index e6d4743894..d93ec8ee7a 100644 --- a/src/auth/src/user_provider/static_user_provider.rs +++ b/src/auth/src/user_provider/static_user_provider.rs @@ -13,21 +13,13 @@ // limitations under the License. use std::collections::HashMap; -use std::fs::File; -use std::io; -use std::io::BufRead; -use std::path::Path; use async_trait::async_trait; -use secrecy::ExposeSecret; -use snafu::{ensure, OptionExt, ResultExt}; +use snafu::OptionExt; -use crate::error::{ - IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Result, UnsupportedPasswordTypeSnafu, - UserNotFoundSnafu, UserPasswordMismatchSnafu, -}; -use crate::user_info::DefaultUserInfo; -use crate::{auth_mysql, Identity, Password, UserInfoRef, UserProvider}; +use crate::error::{InvalidConfigSnafu, Result}; +use crate::user_provider::{authenticate_with_credential, load_credential_from_file}; +use crate::{Identity, Password, UserInfoRef, UserProvider}; pub(crate) const STATIC_USER_PROVIDER: &str = "static_user_provider"; @@ -43,32 +35,8 @@ impl StaticUserProvider { })?; return match mode { "file" => { - // check valid path - let path = Path::new(content); - ensure!(path.exists() && path.is_file(), InvalidConfigSnafu { - value: content.to_string(), - msg: "StaticUserProviderOption file must be a valid file path", - }); - - let file = File::open(path).context(IoSnafu)?; - let credential = io::BufReader::new(file) - .lines() - .map_while(std::result::Result::ok) - .filter_map(|line| { - if let Some((k, v)) = line.split_once('=') { - Some((k.to_string(), v.as_bytes().to_vec())) - } else { - None - } - }) - .collect::>>(); - - ensure!(!credential.is_empty(), InvalidConfigSnafu { - value: content.to_string(), - msg: "StaticUserProviderOption file must contains at least one valid credential", - }); - - Ok(StaticUserProvider { users: credential, }) + let users = load_credential_from_file(content)?; + Ok(StaticUserProvider { users }) } "cmd" => content .split(',') @@ -101,46 +69,7 @@ impl UserProvider for StaticUserProvider { input_id: Identity<'_>, input_pwd: Password<'_>, ) -> Result { - match input_id { - Identity::UserId(username, _) => { - ensure!( - !username.is_empty(), - IllegalParamSnafu { - msg: "blank username" - } - ); - let save_pwd = self.users.get(username).context(UserNotFoundSnafu { - username: username.to_string(), - })?; - - match input_pwd { - Password::PlainText(pwd) => { - ensure!( - !pwd.expose_secret().is_empty(), - IllegalParamSnafu { - msg: "blank password" - } - ); - return if save_pwd == pwd.expose_secret().as_bytes() { - Ok(DefaultUserInfo::with_name(username)) - } else { - UserPasswordMismatchSnafu { - username: username.to_string(), - } - .fail() - }; - } - Password::MysqlNativePassword(auth_data, salt) => { - auth_mysql(auth_data, salt, username, save_pwd) - .map(|_| DefaultUserInfo::with_name(username)) - } - Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu { - password_type: "pg_md5", - } - .fail(), - } - } - } + authenticate_with_credential(&self.users, input_id, input_pwd) } async fn authorize( diff --git a/src/auth/src/user_provider/watch_file_user_provider.rs b/src/auth/src/user_provider/watch_file_user_provider.rs new file mode 100644 index 0000000000..196d9ad2f6 --- /dev/null +++ b/src/auth/src/user_provider/watch_file_user_provider.rs @@ -0,0 +1,68 @@ +use std::collections::HashMap; +use std::path::Path; +use std::sync::mpsc::channel; +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use common_telemetry::info; +use notify::{EventKind, RecursiveMode, Watcher}; +use snafu::ResultExt; + +use crate::error::{FileWatchSnafu, Result}; +use crate::user_provider::{authenticate_with_credential, load_credential_from_file}; +use crate::{Identity, Password, UserInfoRef, UserProvider}; + +pub(crate) const WATCH_FILE_USER_PROVIDER: &str = "watch_file_user_provider"; + +pub(crate) struct WatchFileUserProvider { + users: Arc>>>, +} + +impl WatchFileUserProvider { + pub fn new(filepath: &str) -> Result { + let credential = load_credential_from_file(filepath)?; + let users = Arc::new(Mutex::new(credential)); + let this = WatchFileUserProvider { + users: users.clone(), + }; + + let (tx, rx) = channel::>(); + let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu)?; + watcher + .watch(Path::new(filepath), RecursiveMode::NonRecursive) + .context(FileWatchSnafu)?; + let filepath = filepath.to_string(); + std::thread::spawn(move || { + let _watcher = watcher; + while let Ok(res) = rx.recv() { + if let Ok(event) = res { + if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) { + info!("Detected user provider file change: {:?}", event); + if let Ok(credential) = load_credential_from_file(&filepath) { + *users.lock().expect("users credential must be valid") = credential; + } + } + } + } + }); + + Ok(this) + } +} + +#[async_trait] +impl UserProvider for WatchFileUserProvider { + fn name(&self) -> &str { + WATCH_FILE_USER_PROVIDER + } + + async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result { + let users = self.users.lock().expect("users credential must be valid"); + authenticate_with_credential(&users, id, password) + } + + async fn authorize(&self, _: &str, _: &str, _: &UserInfoRef) -> Result<()> { + // default allow all + Ok(()) + } +}