diff --git a/src/auth/src/user_provider.rs b/src/auth/src/user_provider.rs index 03ec034f3a..4fab604e62 100644 --- a/src/auth/src/user_provider.rs +++ b/src/auth/src/user_provider.rs @@ -59,17 +59,20 @@ pub trait UserProvider: Send + Sync { } } -fn load_credential_from_file(filepath: &str) -> Result>> { +fn load_credential_from_file(filepath: &str) -> Result>>> { // check valid path let path = Path::new(filepath); + if !path.exists() { + return Ok(None); + } + ensure!( - path.exists() && path.is_file(), + path.is_file(), InvalidConfigSnafu { value: filepath, - msg: "UserProvider file must be a valid file path", + msg: "UserProvider file must be a file", } ); - let file = File::open(path).context(IoSnafu)?; let credential = io::BufReader::new(file) .lines() @@ -91,7 +94,7 @@ fn load_credential_from_file(filepath: &str) -> Result>> } ); - Ok(credential) + Ok(Some(credential)) } fn authenticate_with_credential( diff --git a/src/auth/src/user_provider/static_user_provider.rs b/src/auth/src/user_provider/static_user_provider.rs index d93ec8ee7a..9e05671219 100644 --- a/src/auth/src/user_provider/static_user_provider.rs +++ b/src/auth/src/user_provider/static_user_provider.rs @@ -35,7 +35,11 @@ impl StaticUserProvider { })?; return match mode { "file" => { - let users = load_credential_from_file(content)?; + let users = load_credential_from_file(content)? + .context(InvalidConfigSnafu { + value: content.to_string(), + msg: "StaticFileUserProvider must be a valid file path", + })?; Ok(StaticUserProvider { users }) } "cmd" => content @@ -64,12 +68,8 @@ impl UserProvider for StaticUserProvider { STATIC_USER_PROVIDER } - async fn authenticate( - &self, - input_id: Identity<'_>, - input_pwd: Password<'_>, - ) -> Result { - authenticate_with_credential(&self.users, input_id, input_pwd) + async fn authenticate(&self, id: Identity<'_>, pwd: Password<'_>) -> Result { + authenticate_with_credential(&self.users, id, pwd) } async fn authorize( diff --git a/src/auth/src/user_provider/watch_file_user_provider.rs b/src/auth/src/user_provider/watch_file_user_provider.rs index 6a8f59f87b..d3abe37a8a 100644 --- a/src/auth/src/user_provider/watch_file_user_provider.rs +++ b/src/auth/src/user_provider/watch_file_user_provider.rs @@ -16,22 +16,26 @@ use std::collections::HashMap; use std::path::Path; use std::sync::mpsc::channel; use std::sync::{Arc, Mutex}; -use std::time::Duration; use async_trait::async_trait; -use common_telemetry::info; +use common_telemetry::{error, info, warn}; use notify::{EventKind, RecursiveMode, Watcher}; use snafu::ResultExt; use crate::error::{FileWatchSnafu, Result}; +use crate::user_info::DefaultUserInfo; use crate::user_provider::{authenticate_with_credential, load_credential_from_file}; use crate::{Identity, Password, UserInfoRef, UserProvider}; pub(crate) const WATCH_FILE_USER_PROVIDER: &str = "watch_file_user_provider"; +type WatchedCredentialRef = Arc>>>>; + /// A user provider that reads user credential from a file and watches the file for changes. +/// +/// Empty file is invalid; but file not exist means every user can be authenticated. pub(crate) struct WatchFileUserProvider { - users: Arc>>>, + users: WatchedCredentialRef, } impl WatchFileUserProvider { @@ -54,10 +58,21 @@ impl WatchFileUserProvider { let _hold = debouncer; while let Ok(res) = rx.recv() { if let Ok(event) = res { - if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) { + if matches!( + event.kind, + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) + ) { info!("User provider file {} changed", &filepath); - if let Ok(credential) = load_credential_from_file(&filepath) { - *users.lock().expect("users credential must be valid") = credential; + match load_credential_from_file(&filepath) { + Ok(credential) => { + *users.lock().expect("users credential must be valid") = credential; + } + Err(err) => { + error!( + ?err, + "fail to load credential from changed file; keep the old one" + ) + } } } } @@ -76,7 +91,16 @@ impl UserProvider for WatchFileUserProvider { async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result { let users = self.users.lock().expect("users credential must be valid"); - authenticate_with_credential(&users, id, password) + if let Some(users) = users.as_ref() { + authenticate_with_credential(users, id, password) + } else { + match id { + Identity::UserId(id, _) => { + warn!(id, "User provider file is empty, allow all users"); + Ok(DefaultUserInfo::with_name(id)) + } + } + } } async fn authorize(&self, _: &str, _: &str, _: &UserInfoRef) -> Result<()> { @@ -148,8 +172,18 @@ admin=654321", assert!(lw.write_all(b"root=654321",).is_ok()); lw.flush().unwrap(); } - sleep(Duration::from_secs(2)).await; // wait the watcher to apply the change + sleep(Duration::from_secs(1)).await; // wait the watcher to apply the change + test_authenticate(&provider, "root", "123456", false).await; test_authenticate(&provider, "root", "654321", true).await; test_authenticate(&provider, "admin", "654321", false).await; + + { + // remove the tmp file + std::fs::remove_file(&file_path).unwrap(); + } + sleep(Duration::from_secs(1)).await; // wait the watcher to apply the change + test_authenticate(&provider, "root", "123456", true).await; + test_authenticate(&provider, "root", "654321", true).await; + test_authenticate(&provider, "admin", "654321", true).await; } }