mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-07 05:42:57 +00:00
feat: impl static_user_provider (#739)
* feat: add MemUserProvider and impl auth * feat: impl user_provider option in fe and standalone mode * chore: add file impl for mem provider * chore: remove mem opts * chore: minor change * chore: refac pg server to use user_provider as indicator for using pwd auth * chore: fix test * chore: extract common code * chore: add unit test * chore: rebase develop * chore: add user provider to http server * chore: minor rename * chore: change to ref when convert to anymap * chore: fix according to clippy * chore: remove clone on startcommand * chore: fix cr issue * chore: update tempdir use * chore: change TryFrom to normal func while parsing anymap * chore: minor change * chore: remove to_lowercase
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -127,6 +127,12 @@ version = "1.0.65"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
|
||||
|
||||
[[package]]
|
||||
name = "anymap"
|
||||
version = "1.0.0-beta.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f1f8f5a6f3d50d89e3797d7593a50f96bb2aaa20ca0cc7be1fb673232c91d72"
|
||||
|
||||
[[package]]
|
||||
name = "api"
|
||||
version = "0.1.0"
|
||||
@@ -1276,6 +1282,7 @@ dependencies = [
|
||||
name = "cmd"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anymap",
|
||||
"build-data",
|
||||
"clap 3.2.22",
|
||||
"common-error",
|
||||
@@ -2475,6 +2482,7 @@ dependencies = [
|
||||
name = "frontend"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anymap",
|
||||
"api",
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
@@ -6114,6 +6122,7 @@ dependencies = [
|
||||
"common-telemetry",
|
||||
"common-time",
|
||||
"datatypes",
|
||||
"digest",
|
||||
"futures",
|
||||
"hex",
|
||||
"http-body",
|
||||
@@ -6138,10 +6147,12 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"session",
|
||||
"sha1",
|
||||
"snafu",
|
||||
"snap",
|
||||
"strum 0.24.1",
|
||||
"table",
|
||||
"tempdir",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
|
||||
@@ -10,6 +10,7 @@ name = "greptime"
|
||||
path = "src/bin/greptime.rs"
|
||||
|
||||
[dependencies]
|
||||
anymap = "1.0.0-beta.2"
|
||||
clap = { version = "3.1", features = ["derive"] }
|
||||
common-error = { path = "../common/error" }
|
||||
common-telemetry = { path = "../common/telemetry", features = [
|
||||
|
||||
@@ -55,6 +55,12 @@ pub enum Error {
|
||||
|
||||
#[snafu(display("Illegal config: {}", msg))]
|
||||
IllegalConfig { msg: String, backtrace: Backtrace },
|
||||
|
||||
#[snafu(display("Illegal auth config: {}", source))]
|
||||
IllegalAuthConfig {
|
||||
#[snafu(backtrace)]
|
||||
source: servers::auth::Error,
|
||||
},
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -69,6 +75,7 @@ impl ErrorExt for Error {
|
||||
StatusCode::InvalidArguments
|
||||
}
|
||||
Error::IllegalConfig { .. } => StatusCode::InvalidArguments,
|
||||
Error::IllegalAuthConfig { .. } => StatusCode::InvalidArguments,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anymap::AnyMap;
|
||||
use clap::Parser;
|
||||
use frontend::frontend::{Frontend, FrontendOptions};
|
||||
use frontend::grpc::GrpcOptions;
|
||||
@@ -23,12 +24,13 @@ use frontend::mysql::MysqlOptions;
|
||||
use frontend::opentsdb::OpentsdbOptions;
|
||||
use frontend::postgres::PostgresOptions;
|
||||
use meta_client::MetaClientOpts;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::http::HttpOptions;
|
||||
use servers::tls::{TlsMode, TlsOption};
|
||||
use servers::Mode;
|
||||
use servers::{auth, Mode};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{self, Result};
|
||||
use crate::error::{self, IllegalAuthConfigSnafu, Result};
|
||||
use crate::toml_loader;
|
||||
|
||||
#[derive(Parser)]
|
||||
@@ -80,21 +82,35 @@ pub struct StartCommand {
|
||||
tls_cert_path: Option<String>,
|
||||
#[clap(long)]
|
||||
tls_key_path: Option<String>,
|
||||
#[clap(long)]
|
||||
user_provider: Option<String>,
|
||||
}
|
||||
|
||||
impl StartCommand {
|
||||
async fn run(self) -> Result<()> {
|
||||
let plugins = load_frontend_plugins(&self.user_provider)?;
|
||||
let opts: FrontendOptions = self.try_into()?;
|
||||
let mut frontend = Frontend::new(
|
||||
opts.clone(),
|
||||
Instance::try_new_distributed(&opts)
|
||||
.await
|
||||
.context(error::StartFrontendSnafu)?,
|
||||
plugins,
|
||||
);
|
||||
frontend.start().await.context(error::StartFrontendSnafu)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_frontend_plugins(user_provider: &Option<String>) -> Result<AnyMap> {
|
||||
let mut plugins = AnyMap::new();
|
||||
|
||||
if let Some(provider) = user_provider {
|
||||
let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?;
|
||||
plugins.insert::<UserProviderRef>(provider);
|
||||
}
|
||||
Ok(plugins)
|
||||
}
|
||||
|
||||
impl TryFrom<StartCommand> for FrontendOptions {
|
||||
type Error = error::Error;
|
||||
|
||||
@@ -160,6 +176,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use servers::auth::{Identity, Password, UserProviderRef};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@@ -176,6 +194,7 @@ mod tests {
|
||||
tls_mode: None,
|
||||
tls_cert_path: None,
|
||||
tls_key_path: None,
|
||||
user_provider: None,
|
||||
};
|
||||
|
||||
let opts: FrontendOptions = command.try_into().unwrap();
|
||||
@@ -228,6 +247,7 @@ mod tests {
|
||||
tls_mode: None,
|
||||
tls_cert_path: None,
|
||||
tls_key_path: None,
|
||||
user_provider: None,
|
||||
};
|
||||
|
||||
let fe_opts = FrontendOptions::try_from(command).unwrap();
|
||||
@@ -241,4 +261,34 @@ mod tests {
|
||||
fe_opts.http_options.as_ref().unwrap().timeout
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_try_from_start_command_to_anymap() {
|
||||
let command = StartCommand {
|
||||
http_addr: None,
|
||||
grpc_addr: None,
|
||||
mysql_addr: None,
|
||||
postgres_addr: None,
|
||||
opentsdb_addr: None,
|
||||
influxdb_enable: None,
|
||||
config_file: None,
|
||||
metasrv_addr: None,
|
||||
tls_mode: None,
|
||||
tls_cert_path: None,
|
||||
tls_key_path: None,
|
||||
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
|
||||
};
|
||||
|
||||
let plugins = load_frontend_plugins(&command.user_provider);
|
||||
assert!(plugins.is_ok());
|
||||
let plugins = plugins.unwrap();
|
||||
let provider = plugins.get::<UserProviderRef>();
|
||||
assert!(provider.is_some());
|
||||
|
||||
let provider = provider.unwrap();
|
||||
let result = provider
|
||||
.auth(Identity::UserId("test", None), Password::PlainText("test"))
|
||||
.await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anymap::AnyMap;
|
||||
use clap::Parser;
|
||||
use common_telemetry::info;
|
||||
use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig};
|
||||
@@ -33,6 +34,7 @@ use servers::Mode;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{Error, IllegalConfigSnafu, Result, StartDatanodeSnafu, StartFrontendSnafu};
|
||||
use crate::frontend::load_frontend_plugins;
|
||||
use crate::toml_loader;
|
||||
|
||||
#[derive(Parser)]
|
||||
@@ -142,12 +144,15 @@ struct StartCommand {
|
||||
tls_cert_path: Option<String>,
|
||||
#[clap(long)]
|
||||
tls_key_path: Option<String>,
|
||||
#[clap(long)]
|
||||
user_provider: Option<String>,
|
||||
}
|
||||
|
||||
impl StartCommand {
|
||||
async fn run(self) -> Result<()> {
|
||||
let enable_memory_catalog = self.enable_memory_catalog;
|
||||
let config_file = self.config_file.clone();
|
||||
let plugins = load_frontend_plugins(&self.user_provider)?;
|
||||
let fe_opts = FrontendOptions::try_from(self)?;
|
||||
let dn_opts: DatanodeOptions = {
|
||||
let mut opts: StandaloneOptions = if let Some(path) = config_file {
|
||||
@@ -167,7 +172,7 @@ impl StartCommand {
|
||||
let mut datanode = Datanode::new(dn_opts.clone())
|
||||
.await
|
||||
.context(StartDatanodeSnafu)?;
|
||||
let mut frontend = build_frontend(fe_opts, datanode.get_instance()).await?;
|
||||
let mut frontend = build_frontend(fe_opts, plugins, datanode.get_instance()).await?;
|
||||
|
||||
// Start datanode instance before starting services, to avoid requests come in before internal components are started.
|
||||
datanode
|
||||
@@ -184,12 +189,13 @@ impl StartCommand {
|
||||
/// Build frontend instance in standalone mode
|
||||
async fn build_frontend(
|
||||
fe_opts: FrontendOptions,
|
||||
plugins: AnyMap,
|
||||
datanode_instance: InstanceRef,
|
||||
) -> Result<Frontend<FeInstance>> {
|
||||
let mut frontend_instance = FeInstance::new_standalone(datanode_instance.clone());
|
||||
frontend_instance.set_catalog_manager(datanode_instance.catalog_manager().clone());
|
||||
frontend_instance.set_script_handler(datanode_instance);
|
||||
Ok(Frontend::new(fe_opts, frontend_instance))
|
||||
Ok(Frontend::new(fe_opts, frontend_instance, plugins))
|
||||
}
|
||||
|
||||
impl TryFrom<StartCommand> for FrontendOptions {
|
||||
@@ -274,6 +280,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use servers::auth::{Identity, Password, UserProviderRef};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@@ -293,6 +301,7 @@ mod tests {
|
||||
tls_mode: None,
|
||||
tls_cert_path: None,
|
||||
tls_key_path: None,
|
||||
user_provider: None,
|
||||
};
|
||||
|
||||
let fe_opts = FrontendOptions::try_from(cmd).unwrap();
|
||||
@@ -316,4 +325,33 @@ mod tests {
|
||||
assert_eq!(2, fe_opts.mysql_options.as_ref().unwrap().runtime_size);
|
||||
assert!(fe_opts.influxdb_options.as_ref().unwrap().enable);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_try_from_start_command_to_anymap() {
|
||||
let command = StartCommand {
|
||||
http_addr: None,
|
||||
rpc_addr: None,
|
||||
mysql_addr: None,
|
||||
postgres_addr: None,
|
||||
opentsdb_addr: None,
|
||||
config_file: None,
|
||||
influxdb_enable: false,
|
||||
enable_memory_catalog: false,
|
||||
tls_mode: None,
|
||||
tls_cert_path: None,
|
||||
tls_key_path: None,
|
||||
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
|
||||
};
|
||||
|
||||
let plugins = load_frontend_plugins(&command.user_provider);
|
||||
assert!(plugins.is_ok());
|
||||
let plugins = plugins.unwrap();
|
||||
let provider = plugins.get::<UserProviderRef>();
|
||||
assert!(provider.is_some());
|
||||
let provider = provider.unwrap();
|
||||
let result = provider
|
||||
.auth(Identity::UserId("test", None), Password::PlainText("test"))
|
||||
.await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
anymap = "1.0.0-beta.2"
|
||||
api = { path = "../api" }
|
||||
async-stream = "0.3"
|
||||
async-trait = "0.1"
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_telemetry::info;
|
||||
use anymap::AnyMap;
|
||||
use meta_client::MetaClientOpts;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use servers::auth::UserProviderRef;
|
||||
@@ -67,29 +67,18 @@ where
|
||||
{
|
||||
opts: FrontendOptions,
|
||||
instance: Option<T>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
plugins: AnyMap,
|
||||
}
|
||||
|
||||
impl<T> Frontend<T>
|
||||
where
|
||||
T: FrontendInstance,
|
||||
{
|
||||
pub fn new(opts: FrontendOptions, instance: T) -> Self {
|
||||
impl<T: FrontendInstance> Frontend<T> {
|
||||
pub fn new(opts: FrontendOptions, instance: T, plugins: AnyMap) -> Self {
|
||||
Self {
|
||||
opts,
|
||||
instance: Some(instance),
|
||||
user_provider: None,
|
||||
plugins,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_user_provider(&mut self, user_provider: Option<UserProviderRef>) {
|
||||
info!(
|
||||
"Configured user provider: {:?}",
|
||||
user_provider.as_ref().map(|u| u.name())
|
||||
);
|
||||
self.user_provider = user_provider;
|
||||
}
|
||||
|
||||
pub async fn start(&mut self) -> Result<()> {
|
||||
let mut instance = self
|
||||
.instance
|
||||
@@ -100,6 +89,9 @@ where
|
||||
instance.start().await?;
|
||||
|
||||
let instance = Arc::new(instance);
|
||||
Services::start(&self.opts, instance, self.user_provider.clone()).await
|
||||
|
||||
let provider = self.plugins.get::<UserProviderRef>().cloned();
|
||||
|
||||
Services::start(&self.opts, instance, provider).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ use servers::tls::TlsOption;
|
||||
pub struct PostgresOptions {
|
||||
pub addr: String,
|
||||
pub runtime_size: usize,
|
||||
pub check_pwd: bool,
|
||||
#[serde(default = "Default::default")]
|
||||
pub tls: Arc<TlsOption>,
|
||||
}
|
||||
@@ -31,7 +30,6 @@ impl Default for PostgresOptions {
|
||||
Self {
|
||||
addr: "127.0.0.1:4003".to_string(),
|
||||
runtime_size: 2,
|
||||
check_pwd: false,
|
||||
tls: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,10 +99,9 @@ impl Services {
|
||||
|
||||
let pg_server = Box::new(PostgresServer::new(
|
||||
instance.clone(),
|
||||
opts.check_pwd,
|
||||
opts.tls.clone(),
|
||||
pg_io_runtime,
|
||||
user_provider,
|
||||
user_provider.clone(),
|
||||
)) as Box<dyn Server>;
|
||||
|
||||
Some((pg_server, pg_addr))
|
||||
@@ -132,6 +131,10 @@ impl Services {
|
||||
let http_addr = parse_addr(&http_options.addr)?;
|
||||
|
||||
let mut http_server = HttpServer::new(instance.clone(), http_options.clone());
|
||||
if let Some(user_provider) = user_provider {
|
||||
http_server.set_user_provider(user_provider);
|
||||
}
|
||||
|
||||
if opentsdb_server_and_addr.is_some() {
|
||||
http_server.set_opentsdb_handler(instance.clone());
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ common-runtime = { path = "../common/runtime" }
|
||||
common-telemetry = { path = "../common/telemetry" }
|
||||
common-time = { path = "../common/time" }
|
||||
datatypes = { path = "../datatypes" }
|
||||
digest = "0.10"
|
||||
futures = "0.3"
|
||||
hex = { version = "0.4" }
|
||||
http-body = "0.4"
|
||||
@@ -43,6 +44,7 @@ schemars = "0.8"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
session = { path = "../session" }
|
||||
sha1 = "0.10"
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
snap = "1"
|
||||
strum = { version = "0.24", features = ["derive"] }
|
||||
@@ -67,6 +69,7 @@ rand = "0.8"
|
||||
script = { path = "../script", features = ["python"] }
|
||||
serde_json = "1.0"
|
||||
table = { path = "../table" }
|
||||
tempdir = "0.3"
|
||||
tokio-postgres = "0.7"
|
||||
tokio-postgres-rustls = "0.9"
|
||||
tokio-test = "0.4"
|
||||
|
||||
@@ -12,13 +12,17 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod user_provider;
|
||||
|
||||
pub const DEFAULT_USERNAME: &str = "greptime";
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::prelude::ErrorExt;
|
||||
use common_error::status_code::StatusCode;
|
||||
use snafu::{Backtrace, ErrorCompat, Snafu};
|
||||
use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
|
||||
|
||||
use crate::auth::user_provider::StaticUserProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait UserProvider: Send + Sync {
|
||||
@@ -73,11 +77,40 @@ impl UserInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef, Error> {
|
||||
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
|
||||
value: opt.to_string(),
|
||||
msg: "UserProviderOption must be in format `<option>:<value>`",
|
||||
})?;
|
||||
match name {
|
||||
user_provider::STATIC_USER_PROVIDER => {
|
||||
let provider =
|
||||
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
|
||||
Ok(provider)
|
||||
}
|
||||
_ => InvalidConfigSnafu {
|
||||
value: name.to_string(),
|
||||
msg: "Invalid UserProviderOption",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
pub enum Error {
|
||||
#[snafu(display("User not found"))]
|
||||
UserNotFound { backtrace: Backtrace },
|
||||
#[snafu(display("Invalid config value: {}, {}", value, msg))]
|
||||
InvalidConfig {
|
||||
value: String,
|
||||
msg: String,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Encounter IO error, source: {}", source))]
|
||||
IOErr { source: std::io::Error },
|
||||
|
||||
#[snafu(display("User not found, username: {}", username))]
|
||||
UserNotFound { username: String },
|
||||
|
||||
#[snafu(display("Unsupported password type: {}", password_type))]
|
||||
UnsupportedPasswordType {
|
||||
@@ -85,20 +118,23 @@ pub enum Error {
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Username and password does not match"))]
|
||||
UserPasswordMismatch { backtrace: Backtrace },
|
||||
#[snafu(display("Username and password does not match, username: {}", username))]
|
||||
UserPasswordMismatch { username: String },
|
||||
}
|
||||
|
||||
impl ErrorExt for Error {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Error::InvalidConfig { .. } => StatusCode::InvalidArguments,
|
||||
Error::IOErr { .. } => StatusCode::Internal,
|
||||
|
||||
Error::UserNotFound { .. } => StatusCode::UserNotFound,
|
||||
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
|
||||
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
|
||||
}
|
||||
}
|
||||
|
||||
fn backtrace_opt(&self) -> Option<&common_error::snafu::Backtrace> {
|
||||
fn backtrace_opt(&self) -> Option<&Backtrace> {
|
||||
ErrorCompat::backtrace(self)
|
||||
}
|
||||
|
||||
@@ -133,10 +169,16 @@ pub mod test {
|
||||
username: "greptime".to_string(),
|
||||
});
|
||||
} else {
|
||||
return super::UserPasswordMismatchSnafu {}.fail();
|
||||
return super::UserPasswordMismatchSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
} else {
|
||||
return super::UserNotFoundSnafu {}.fail();
|
||||
return super::UserNotFoundSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
}
|
||||
_ => super::UnsupportedPasswordTypeSnafu {
|
||||
|
||||
253
src/servers/src/auth/user_provider.rs
Normal file
253
src/servers/src/auth/user_provider.rs
Normal file
@@ -0,0 +1,253 @@
|
||||
// Copyright 2022 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io;
|
||||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use digest;
|
||||
use digest::Digest;
|
||||
use sha1::Sha1;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
use crate::auth::{
|
||||
Error, HashedPassword, IOErrSnafu, Identity, InvalidConfigSnafu, Password, Salt,
|
||||
UnsupportedPasswordTypeSnafu, UserInfo, UserNotFoundSnafu, UserPasswordMismatchSnafu,
|
||||
UserProvider,
|
||||
};
|
||||
|
||||
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";
|
||||
|
||||
impl TryFrom<&str> for StaticUserProvider {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let (mode, content) = value.split_once(':').context(InvalidConfigSnafu {
|
||||
value: value.to_string(),
|
||||
msg: "StaticUserProviderOption must be in format `<option>:<value>`",
|
||||
})?;
|
||||
return match mode {
|
||||
"file" => {
|
||||
// check valid path
|
||||
let path = Path::new(content);
|
||||
ensure!(path.exists() && path.is_file(), InvalidConfigSnafu {
|
||||
value: content.to_string(),
|
||||
msg: "StaticUserProviderOption file must be a valid file path",
|
||||
});
|
||||
|
||||
let file = File::open(path).context(IOErrSnafu)?;
|
||||
let credential = io::BufReader::new(file)
|
||||
.lines()
|
||||
.filter_map(|line| line.ok())
|
||||
.filter_map(|line| {
|
||||
if let Some((k, v)) = line.split_once('=') {
|
||||
Some((k.to_string(), v.as_bytes().to_vec()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<HashMap<String, Vec<u8>>>();
|
||||
|
||||
ensure!(!credential.is_empty(), InvalidConfigSnafu {
|
||||
value: content.to_string(),
|
||||
msg: "StaticUserProviderOption file must contains at least one valid credential",
|
||||
});
|
||||
|
||||
Ok(StaticUserProvider { users: credential, })
|
||||
}
|
||||
"cmd" => content
|
||||
.split(',')
|
||||
.map(|kv| {
|
||||
let (k, v) = kv.split_once('=').context(InvalidConfigSnafu {
|
||||
value: kv.to_string(),
|
||||
msg: "StaticUserProviderOption cmd values must be in format `user=pwd[,user=pwd]`",
|
||||
})?;
|
||||
Ok((k.to_string(), v.as_bytes().to_vec()))
|
||||
})
|
||||
.collect::<Result<HashMap<String, Vec<u8>>, Error>>()
|
||||
.map(|users| StaticUserProvider { users }),
|
||||
_ => InvalidConfigSnafu {
|
||||
value: mode.to_string(),
|
||||
msg: "StaticUserProviderOption must be in format `file:<path>` or `cmd:<values>`",
|
||||
}
|
||||
.fail(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StaticUserProvider {
|
||||
users: HashMap<String, Vec<u8>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl UserProvider for StaticUserProvider {
|
||||
fn name(&self) -> &str {
|
||||
STATIC_USER_PROVIDER
|
||||
}
|
||||
|
||||
async fn auth(
|
||||
&self,
|
||||
input_id: Identity<'_>,
|
||||
input_pwd: Password<'_>,
|
||||
) -> Result<UserInfo, Error> {
|
||||
match input_id {
|
||||
Identity::UserId(username, _) => {
|
||||
let save_pwd = self.users.get(username).context(UserNotFoundSnafu {
|
||||
username: username.to_string(),
|
||||
})?;
|
||||
|
||||
match input_pwd {
|
||||
Password::PlainText(pwd) => {
|
||||
return if save_pwd == pwd.as_bytes() {
|
||||
Ok(UserInfo {
|
||||
username: username.to_string(),
|
||||
})
|
||||
} else {
|
||||
UserPasswordMismatchSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
Password::MysqlNativePassword(auth_data, salt) => {
|
||||
auth_mysql(auth_data, salt, username.to_string(), save_pwd)
|
||||
}
|
||||
Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu {
|
||||
password_type: "pg_md5",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_mysql(
|
||||
auth_data: HashedPassword,
|
||||
salt: Salt,
|
||||
username: String,
|
||||
save_pwd: &[u8],
|
||||
) -> Result<UserInfo, Error> {
|
||||
// ref: https://github.com/mysql/mysql-server/blob/a246bad76b9271cb4333634e954040a970222e0a/sql/auth/password.cc#L62
|
||||
let hash_stage_2 = double_sha1(save_pwd);
|
||||
let tmp = sha1_two(salt, &hash_stage_2);
|
||||
// xor auth_data and tmp
|
||||
let mut xor_result = [0u8; 20];
|
||||
for i in 0..20 {
|
||||
xor_result[i] = auth_data[i] ^ tmp[i];
|
||||
}
|
||||
let candidate_stage_2 = sha1_one(&xor_result);
|
||||
if candidate_stage_2 == hash_stage_2 {
|
||||
Ok(UserInfo { username })
|
||||
} else {
|
||||
UserPasswordMismatchSnafu { username }.fail()
|
||||
}
|
||||
}
|
||||
|
||||
fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(input_1);
|
||||
hasher.update(input_2);
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
fn sha1_one(data: &[u8]) -> Vec<u8> {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(data);
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
fn double_sha1(data: &[u8]) -> Vec<u8> {
|
||||
sha1_one(&sha1_one(data))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
use std::fs::File;
|
||||
use std::io::{LineWriter, Write};
|
||||
|
||||
use tempdir::TempDir;
|
||||
|
||||
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
|
||||
use crate::auth::{Identity, Password, UserProvider};
|
||||
|
||||
#[test]
|
||||
fn test_sha() {
|
||||
let sha_1_answer: Vec<u8> = vec![
|
||||
124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
|
||||
27,
|
||||
];
|
||||
let sha_1 = sha1_one("123456".as_bytes());
|
||||
assert_eq!(sha_1, sha_1_answer);
|
||||
|
||||
let double_sha1_answer: Vec<u8> = vec![
|
||||
107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
|
||||
42, 217,
|
||||
];
|
||||
let double_sha1 = double_sha1("123456".as_bytes());
|
||||
assert_eq!(double_sha1, double_sha1_answer);
|
||||
|
||||
let sha1_2_answer: Vec<u8> = vec![
|
||||
132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
|
||||
37, 204,
|
||||
];
|
||||
let sha1_2 = sha1_two("123456".as_bytes(), "654321".as_bytes());
|
||||
assert_eq!(sha1_2, sha1_2_answer);
|
||||
}
|
||||
|
||||
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
|
||||
let re = provider
|
||||
.auth(
|
||||
Identity::UserId(username, None),
|
||||
Password::PlainText(password),
|
||||
)
|
||||
.await;
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_inline_provider() {
|
||||
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
|
||||
test_auth(&provider, "root", "123456").await;
|
||||
test_auth(&provider, "admin", "654321").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_provider() {
|
||||
let dir = TempDir::new("test_file_provider").unwrap();
|
||||
let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());
|
||||
{
|
||||
// write a tmp file
|
||||
let file = File::create(&file_path);
|
||||
assert!(file.is_ok());
|
||||
let file = file.unwrap();
|
||||
let mut lw = LineWriter::new(file);
|
||||
assert!(lw
|
||||
.write_all(
|
||||
b"root=123456
|
||||
admin=654321",
|
||||
)
|
||||
.is_ok());
|
||||
assert!(lw.flush().is_ok());
|
||||
}
|
||||
|
||||
let param = format!("file:{}", file_path);
|
||||
let provider = StaticUserProvider::try_from(param.as_str()).unwrap();
|
||||
test_auth(&provider, "root", "123456").await;
|
||||
test_auth(&provider, "admin", "654321").await;
|
||||
}
|
||||
}
|
||||
@@ -107,16 +107,14 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
|
||||
pub struct PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier,
|
||||
param_provider: GreptimeDBStartupParameters,
|
||||
with_pwd: bool,
|
||||
force_tls: bool,
|
||||
}
|
||||
|
||||
impl PgAuthStartupHandler {
|
||||
pub fn new(with_pwd: bool, user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
|
||||
pub fn new(user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
|
||||
PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier { user_provider },
|
||||
param_provider: GreptimeDBStartupParameters::new(),
|
||||
with_pwd,
|
||||
force_tls,
|
||||
}
|
||||
}
|
||||
@@ -151,7 +149,7 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
return Ok(());
|
||||
}
|
||||
auth::save_startup_parameters_to_metadata(client, startup);
|
||||
if self.with_pwd {
|
||||
if self.verifier.user_provider.is_some() {
|
||||
client.set_state(PgWireConnectionState::AuthenticationInProgress);
|
||||
client
|
||||
.send(PgWireBackendMessage::Authentication(
|
||||
|
||||
@@ -43,14 +43,12 @@ impl PostgresServer {
|
||||
/// Creates a new Postgres server with provided query_handler and async runtime
|
||||
pub fn new(
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
check_pwd: bool,
|
||||
tls: Arc<TlsOption>,
|
||||
io_runtime: Arc<Runtime>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
) -> PostgresServer {
|
||||
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
|
||||
let startup_handler = Arc::new(PgAuthStartupHandler::new(
|
||||
check_pwd,
|
||||
user_provider,
|
||||
tls.should_force_tls(),
|
||||
));
|
||||
|
||||
@@ -16,9 +16,10 @@ use std::sync::Arc;
|
||||
|
||||
use api::v1::InsertExpr;
|
||||
use async_trait::async_trait;
|
||||
use axum::Router;
|
||||
use axum::{http, Router};
|
||||
use axum_test_helper::TestClient;
|
||||
use common_query::Output;
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::error::Result;
|
||||
use servers::http::{HttpOptions, HttpServer};
|
||||
use servers::influxdb::InfluxdbRequest;
|
||||
@@ -53,6 +54,9 @@ impl SqlQueryHandler for DummyInstance {
|
||||
fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
|
||||
let instance = Arc::new(DummyInstance { tx });
|
||||
let mut server = HttpServer::new(instance.clone(), HttpOptions::default());
|
||||
let up = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
|
||||
server.set_user_provider(Arc::new(up));
|
||||
|
||||
server.set_influxdb_handler(instance);
|
||||
server.make_app()
|
||||
}
|
||||
@@ -68,6 +72,10 @@ async fn test_influxdb_write() {
|
||||
let result = client
|
||||
.post("/v1/influxdb/write")
|
||||
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
|
||||
)
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(result.status(), 204);
|
||||
@@ -76,6 +84,10 @@ async fn test_influxdb_write() {
|
||||
let result = client
|
||||
.post("/v1/influxdb/write?db=influxdb")
|
||||
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
|
||||
)
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(result.status(), 204);
|
||||
@@ -85,6 +97,10 @@ async fn test_influxdb_write() {
|
||||
let result = client
|
||||
.post("/v1/influxdb/write")
|
||||
.body("monitor, host=host1 cpu=1.2 1664370459457010101")
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
|
||||
)
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(result.status(), 400);
|
||||
|
||||
@@ -23,6 +23,7 @@ use mysql_async::prelude::*;
|
||||
use mysql_async::SslOpts;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::error::Result;
|
||||
use servers::mysql::server::MysqlServer;
|
||||
use servers::server::Server;
|
||||
@@ -42,11 +43,13 @@ fn create_mysql_server(table: MemTable, tls: Arc<TlsOption>) -> Result<Box<dyn S
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let provider = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
|
||||
|
||||
Ok(MysqlServer::create_server(
|
||||
query_handler,
|
||||
io_runtime,
|
||||
tls,
|
||||
None,
|
||||
Some(Arc::new(provider)),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -85,10 +88,10 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
let server_port = server_addr.port();
|
||||
|
||||
let mut join_handles = vec![];
|
||||
for index in 0..2 {
|
||||
for _ in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port, index == 1, false).await {
|
||||
match create_connection(server_port, false).await {
|
||||
Ok(mut connection) => {
|
||||
let result: u32 = connection
|
||||
.query_first("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -125,7 +128,7 @@ async fn test_query_all_datatypes() -> Result<()> {
|
||||
let server_tls = Arc::new(TlsOption::default());
|
||||
let client_tls = false;
|
||||
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
do_test_query_all_datatypes(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -138,7 +141,7 @@ async fn test_server_prefer_secure_client_plain() -> Result<()> {
|
||||
});
|
||||
|
||||
let client_tls = false;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
do_test_query_all_datatypes(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -151,7 +154,7 @@ async fn test_server_prefer_secure_client_secure() -> Result<()> {
|
||||
});
|
||||
|
||||
let client_tls = true;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
do_test_query_all_datatypes(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -164,7 +167,7 @@ async fn test_server_require_secure_client_secure() -> Result<()> {
|
||||
});
|
||||
|
||||
let client_tls = true;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
do_test_query_all_datatypes(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -194,16 +197,12 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let r = create_connection(server_addr.port(), client_tls, false).await;
|
||||
let r = create_connection(server_addr.port(), client_tls).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_test_query_all_datatypes(
|
||||
server_tls: Arc<TlsOption>,
|
||||
with_pwd: bool,
|
||||
client_tls: bool,
|
||||
) -> Result<()> {
|
||||
async fn do_test_query_all_datatypes(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let TestingData {
|
||||
column_schemas,
|
||||
@@ -220,7 +219,7 @@ async fn do_test_query_all_datatypes(
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let mut connection = create_connection(server_addr.port(), client_tls, with_pwd)
|
||||
let mut connection = create_connection(server_addr.port(), client_tls)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -258,13 +257,11 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
let threads = 4;
|
||||
let expect_executed_queries_per_worker = 1000;
|
||||
let mut join_handles = vec![];
|
||||
for index in 0..threads {
|
||||
for _ in 0..threads {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut connection = create_connection(server_port, index % 2 == 0, false)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut connection = create_connection(server_port, false).await.unwrap();
|
||||
for _ in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
let result: u32 = connection
|
||||
@@ -279,9 +276,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
connection = create_connection(server_port, index % 2 == 0, false)
|
||||
.await
|
||||
.unwrap();
|
||||
connection = create_connection(server_port, false).await.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
@@ -295,16 +290,14 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(
|
||||
port: u16,
|
||||
with_pwd: bool,
|
||||
ssl: bool,
|
||||
) -> mysql_async::Result<mysql_async::Conn> {
|
||||
async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result<mysql_async::Conn> {
|
||||
let mut opts = mysql_async::OptsBuilder::default()
|
||||
.ip_or_hostname("127.0.0.1")
|
||||
.tcp_port(port)
|
||||
.prefer_socket(false)
|
||||
.wait_timeout(Some(1000));
|
||||
.wait_timeout(Some(1000))
|
||||
.user(Some("greptime".to_string()))
|
||||
.pass(Some("greptime".to_string()));
|
||||
|
||||
if ssl {
|
||||
let ssl_opts = SslOpts::default()
|
||||
@@ -313,9 +306,5 @@ async fn create_connection(
|
||||
opts = opts.ssl_opts(ssl_opts)
|
||||
}
|
||||
|
||||
if with_pwd {
|
||||
opts = opts.pass(Some("default_pwd".to_string()));
|
||||
}
|
||||
|
||||
mysql_async::Conn::new(opts).await
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::{Certificate, Error, ServerName};
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::error::Result;
|
||||
use servers::postgres::PostgresServer;
|
||||
use servers::server::Server;
|
||||
@@ -44,12 +46,19 @@ fn create_postgres_server(
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
let user_provider: Option<UserProviderRef> = if check_pwd {
|
||||
Some(Arc::new(
|
||||
StaticUserProvider::try_from("cmd:test_user=test_pwd").unwrap(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Box::new(PostgresServer::new(
|
||||
query_handler,
|
||||
check_pwd,
|
||||
tls,
|
||||
io_runtime,
|
||||
None,
|
||||
user_provider,
|
||||
)))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user