feat: impl static_user_provider (#739)

* feat: add MemUserProvider and impl auth

* feat: impl user_provider option in fe and standalone mode

* chore: add file impl for mem provider

* chore: remove mem opts

* chore: minor change

* chore: refac pg server to use user_provider as indicator for using pwd auth

* chore: fix test

* chore: extract common code

* chore: add unit test

* chore: rebase develop

* chore: add user provider to http server

* chore: minor rename

* chore: change to ref when convert to anymap

* chore: fix according to clippy

* chore: remove clone on startcommand

* chore: fix cr issue

* chore: update tempdir use

* chore: change TryFrom to normal func while parsing anymap

* chore: minor change

* chore: remove to_lowercase
This commit is contained in:
shuiyisong
2022-12-14 16:38:29 +08:00
committed by GitHub
parent 756c068166
commit fda9e80cbf
17 changed files with 482 additions and 73 deletions

11
Cargo.lock generated
View File

@@ -127,6 +127,12 @@ version = "1.0.65"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
[[package]]
name = "anymap"
version = "1.0.0-beta.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f1f8f5a6f3d50d89e3797d7593a50f96bb2aaa20ca0cc7be1fb673232c91d72"
[[package]]
name = "api"
version = "0.1.0"
@@ -1276,6 +1282,7 @@ dependencies = [
name = "cmd"
version = "0.1.0"
dependencies = [
"anymap",
"build-data",
"clap 3.2.22",
"common-error",
@@ -2475,6 +2482,7 @@ dependencies = [
name = "frontend"
version = "0.1.0"
dependencies = [
"anymap",
"api",
"async-stream",
"async-trait",
@@ -6114,6 +6122,7 @@ dependencies = [
"common-telemetry",
"common-time",
"datatypes",
"digest",
"futures",
"hex",
"http-body",
@@ -6138,10 +6147,12 @@ dependencies = [
"serde",
"serde_json",
"session",
"sha1",
"snafu",
"snap",
"strum 0.24.1",
"table",
"tempdir",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",

View File

@@ -10,6 +10,7 @@ name = "greptime"
path = "src/bin/greptime.rs"
[dependencies]
anymap = "1.0.0-beta.2"
clap = { version = "3.1", features = ["derive"] }
common-error = { path = "../common/error" }
common-telemetry = { path = "../common/telemetry", features = [

View File

@@ -55,6 +55,12 @@ pub enum Error {
#[snafu(display("Illegal config: {}", msg))]
IllegalConfig { msg: String, backtrace: Backtrace },
#[snafu(display("Illegal auth config: {}", source))]
IllegalAuthConfig {
#[snafu(backtrace)]
source: servers::auth::Error,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -69,6 +75,7 @@ impl ErrorExt for Error {
StatusCode::InvalidArguments
}
Error::IllegalConfig { .. } => StatusCode::InvalidArguments,
Error::IllegalAuthConfig { .. } => StatusCode::InvalidArguments,
}
}

View File

@@ -14,6 +14,7 @@
use std::sync::Arc;
use anymap::AnyMap;
use clap::Parser;
use frontend::frontend::{Frontend, FrontendOptions};
use frontend::grpc::GrpcOptions;
@@ -23,12 +24,13 @@ use frontend::mysql::MysqlOptions;
use frontend::opentsdb::OpentsdbOptions;
use frontend::postgres::PostgresOptions;
use meta_client::MetaClientOpts;
use servers::auth::UserProviderRef;
use servers::http::HttpOptions;
use servers::tls::{TlsMode, TlsOption};
use servers::Mode;
use servers::{auth, Mode};
use snafu::ResultExt;
use crate::error::{self, Result};
use crate::error::{self, IllegalAuthConfigSnafu, Result};
use crate::toml_loader;
#[derive(Parser)]
@@ -80,21 +82,35 @@ pub struct StartCommand {
tls_cert_path: Option<String>,
#[clap(long)]
tls_key_path: Option<String>,
#[clap(long)]
user_provider: Option<String>,
}
impl StartCommand {
async fn run(self) -> Result<()> {
let plugins = load_frontend_plugins(&self.user_provider)?;
let opts: FrontendOptions = self.try_into()?;
let mut frontend = Frontend::new(
opts.clone(),
Instance::try_new_distributed(&opts)
.await
.context(error::StartFrontendSnafu)?,
plugins,
);
frontend.start().await.context(error::StartFrontendSnafu)
}
}
pub fn load_frontend_plugins(user_provider: &Option<String>) -> Result<AnyMap> {
let mut plugins = AnyMap::new();
if let Some(provider) = user_provider {
let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?;
plugins.insert::<UserProviderRef>(provider);
}
Ok(plugins)
}
impl TryFrom<StartCommand> for FrontendOptions {
type Error = error::Error;
@@ -160,6 +176,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
mod tests {
use std::time::Duration;
use servers::auth::{Identity, Password, UserProviderRef};
use super::*;
#[test]
@@ -176,6 +194,7 @@ mod tests {
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
};
let opts: FrontendOptions = command.try_into().unwrap();
@@ -228,6 +247,7 @@ mod tests {
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
};
let fe_opts = FrontendOptions::try_from(command).unwrap();
@@ -241,4 +261,34 @@ mod tests {
fe_opts.http_options.as_ref().unwrap().timeout
);
}
#[tokio::test]
async fn test_try_from_start_command_to_anymap() {
let command = StartCommand {
http_addr: None,
grpc_addr: None,
mysql_addr: None,
postgres_addr: None,
opentsdb_addr: None,
influxdb_enable: None,
config_file: None,
metasrv_addr: None,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
};
let plugins = load_frontend_plugins(&command.user_provider);
assert!(plugins.is_ok());
let plugins = plugins.unwrap();
let provider = plugins.get::<UserProviderRef>();
assert!(provider.is_some());
let provider = provider.unwrap();
let result = provider
.auth(Identity::UserId("test", None), Password::PlainText("test"))
.await;
assert!(result.is_ok());
}
}

View File

@@ -14,6 +14,7 @@
use std::sync::Arc;
use anymap::AnyMap;
use clap::Parser;
use common_telemetry::info;
use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig};
@@ -33,6 +34,7 @@ use servers::Mode;
use snafu::ResultExt;
use crate::error::{Error, IllegalConfigSnafu, Result, StartDatanodeSnafu, StartFrontendSnafu};
use crate::frontend::load_frontend_plugins;
use crate::toml_loader;
#[derive(Parser)]
@@ -142,12 +144,15 @@ struct StartCommand {
tls_cert_path: Option<String>,
#[clap(long)]
tls_key_path: Option<String>,
#[clap(long)]
user_provider: Option<String>,
}
impl StartCommand {
async fn run(self) -> Result<()> {
let enable_memory_catalog = self.enable_memory_catalog;
let config_file = self.config_file.clone();
let plugins = load_frontend_plugins(&self.user_provider)?;
let fe_opts = FrontendOptions::try_from(self)?;
let dn_opts: DatanodeOptions = {
let mut opts: StandaloneOptions = if let Some(path) = config_file {
@@ -167,7 +172,7 @@ impl StartCommand {
let mut datanode = Datanode::new(dn_opts.clone())
.await
.context(StartDatanodeSnafu)?;
let mut frontend = build_frontend(fe_opts, datanode.get_instance()).await?;
let mut frontend = build_frontend(fe_opts, plugins, datanode.get_instance()).await?;
// Start datanode instance before starting services, to avoid requests come in before internal components are started.
datanode
@@ -184,12 +189,13 @@ impl StartCommand {
/// Build frontend instance in standalone mode
async fn build_frontend(
fe_opts: FrontendOptions,
plugins: AnyMap,
datanode_instance: InstanceRef,
) -> Result<Frontend<FeInstance>> {
let mut frontend_instance = FeInstance::new_standalone(datanode_instance.clone());
frontend_instance.set_catalog_manager(datanode_instance.catalog_manager().clone());
frontend_instance.set_script_handler(datanode_instance);
Ok(Frontend::new(fe_opts, frontend_instance))
Ok(Frontend::new(fe_opts, frontend_instance, plugins))
}
impl TryFrom<StartCommand> for FrontendOptions {
@@ -274,6 +280,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
mod tests {
use std::time::Duration;
use servers::auth::{Identity, Password, UserProviderRef};
use super::*;
#[test]
@@ -293,6 +301,7 @@ mod tests {
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
};
let fe_opts = FrontendOptions::try_from(cmd).unwrap();
@@ -316,4 +325,33 @@ mod tests {
assert_eq!(2, fe_opts.mysql_options.as_ref().unwrap().runtime_size);
assert!(fe_opts.influxdb_options.as_ref().unwrap().enable);
}
#[tokio::test]
async fn test_try_from_start_command_to_anymap() {
let command = StartCommand {
http_addr: None,
rpc_addr: None,
mysql_addr: None,
postgres_addr: None,
opentsdb_addr: None,
config_file: None,
influxdb_enable: false,
enable_memory_catalog: false,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
};
let plugins = load_frontend_plugins(&command.user_provider);
assert!(plugins.is_ok());
let plugins = plugins.unwrap();
let provider = plugins.get::<UserProviderRef>();
assert!(provider.is_some());
let provider = provider.unwrap();
let result = provider
.auth(Identity::UserId("test", None), Password::PlainText("test"))
.await;
assert!(result.is_ok());
}
}

View File

@@ -5,6 +5,7 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
anymap = "1.0.0-beta.2"
api = { path = "../api" }
async-stream = "0.3"
async-trait = "0.1"

View File

@@ -14,7 +14,7 @@
use std::sync::Arc;
use common_telemetry::info;
use anymap::AnyMap;
use meta_client::MetaClientOpts;
use serde::{Deserialize, Serialize};
use servers::auth::UserProviderRef;
@@ -67,29 +67,18 @@ where
{
opts: FrontendOptions,
instance: Option<T>,
user_provider: Option<UserProviderRef>,
plugins: AnyMap,
}
impl<T> Frontend<T>
where
T: FrontendInstance,
{
pub fn new(opts: FrontendOptions, instance: T) -> Self {
impl<T: FrontendInstance> Frontend<T> {
pub fn new(opts: FrontendOptions, instance: T, plugins: AnyMap) -> Self {
Self {
opts,
instance: Some(instance),
user_provider: None,
plugins,
}
}
pub fn set_user_provider(&mut self, user_provider: Option<UserProviderRef>) {
info!(
"Configured user provider: {:?}",
user_provider.as_ref().map(|u| u.name())
);
self.user_provider = user_provider;
}
pub async fn start(&mut self) -> Result<()> {
let mut instance = self
.instance
@@ -100,6 +89,9 @@ where
instance.start().await?;
let instance = Arc::new(instance);
Services::start(&self.opts, instance, self.user_provider.clone()).await
let provider = self.plugins.get::<UserProviderRef>().cloned();
Services::start(&self.opts, instance, provider).await
}
}

View File

@@ -21,7 +21,6 @@ use servers::tls::TlsOption;
pub struct PostgresOptions {
pub addr: String,
pub runtime_size: usize,
pub check_pwd: bool,
#[serde(default = "Default::default")]
pub tls: Arc<TlsOption>,
}
@@ -31,7 +30,6 @@ impl Default for PostgresOptions {
Self {
addr: "127.0.0.1:4003".to_string(),
runtime_size: 2,
check_pwd: false,
tls: Default::default(),
}
}

View File

@@ -99,10 +99,9 @@ impl Services {
let pg_server = Box::new(PostgresServer::new(
instance.clone(),
opts.check_pwd,
opts.tls.clone(),
pg_io_runtime,
user_provider,
user_provider.clone(),
)) as Box<dyn Server>;
Some((pg_server, pg_addr))
@@ -132,6 +131,10 @@ impl Services {
let http_addr = parse_addr(&http_options.addr)?;
let mut http_server = HttpServer::new(instance.clone(), http_options.clone());
if let Some(user_provider) = user_provider {
http_server.set_user_provider(user_provider);
}
if opentsdb_server_and_addr.is_some() {
http_server.set_opentsdb_handler(instance.clone());
}

View File

@@ -22,6 +22,7 @@ common-runtime = { path = "../common/runtime" }
common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
digest = "0.10"
futures = "0.3"
hex = { version = "0.4" }
http-body = "0.4"
@@ -43,6 +44,7 @@ schemars = "0.8"
serde = "1.0"
serde_json = "1.0"
session = { path = "../session" }
sha1 = "0.10"
snafu = { version = "0.7", features = ["backtraces"] }
snap = "1"
strum = { version = "0.24", features = ["derive"] }
@@ -67,6 +69,7 @@ rand = "0.8"
script = { path = "../script", features = ["python"] }
serde_json = "1.0"
table = { path = "../table" }
tempdir = "0.3"
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.9"
tokio-test = "0.4"

View File

@@ -12,13 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod user_provider;
pub const DEFAULT_USERNAME: &str = "greptime";
use std::sync::Arc;
use common_error::prelude::ErrorExt;
use common_error::status_code::StatusCode;
use snafu::{Backtrace, ErrorCompat, Snafu};
use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
use crate::auth::user_provider::StaticUserProvider;
#[async_trait::async_trait]
pub trait UserProvider: Send + Sync {
@@ -73,11 +77,40 @@ impl UserInfo {
}
}
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef, Error> {
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
value: opt.to_string(),
msg: "UserProviderOption must be in format `<option>:<value>`",
})?;
match name {
user_provider::STATIC_USER_PROVIDER => {
let provider =
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
Ok(provider)
}
_ => InvalidConfigSnafu {
value: name.to_string(),
msg: "Invalid UserProviderOption",
}
.fail(),
}
}
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("User not found"))]
UserNotFound { backtrace: Backtrace },
#[snafu(display("Invalid config value: {}, {}", value, msg))]
InvalidConfig {
value: String,
msg: String,
backtrace: Backtrace,
},
#[snafu(display("Encounter IO error, source: {}", source))]
IOErr { source: std::io::Error },
#[snafu(display("User not found, username: {}", username))]
UserNotFound { username: String },
#[snafu(display("Unsupported password type: {}", password_type))]
UnsupportedPasswordType {
@@ -85,20 +118,23 @@ pub enum Error {
backtrace: Backtrace,
},
#[snafu(display("Username and password does not match"))]
UserPasswordMismatch { backtrace: Backtrace },
#[snafu(display("Username and password does not match, username: {}", username))]
UserPasswordMismatch { username: String },
}
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::InvalidConfig { .. } => StatusCode::InvalidArguments,
Error::IOErr { .. } => StatusCode::Internal,
Error::UserNotFound { .. } => StatusCode::UserNotFound,
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
}
}
fn backtrace_opt(&self) -> Option<&common_error::snafu::Backtrace> {
fn backtrace_opt(&self) -> Option<&Backtrace> {
ErrorCompat::backtrace(self)
}
@@ -133,10 +169,16 @@ pub mod test {
username: "greptime".to_string(),
});
} else {
return super::UserPasswordMismatchSnafu {}.fail();
return super::UserPasswordMismatchSnafu {
username: username.to_string(),
}
.fail();
}
} else {
return super::UserNotFoundSnafu {}.fail();
return super::UserNotFoundSnafu {
username: username.to_string(),
}
.fail();
}
}
_ => super::UnsupportedPasswordTypeSnafu {

View File

@@ -0,0 +1,253 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::fs::File;
use std::io;
use std::io::BufRead;
use std::path::Path;
use async_trait::async_trait;
use digest;
use digest::Digest;
use sha1::Sha1;
use snafu::{ensure, OptionExt, ResultExt};
use crate::auth::{
Error, HashedPassword, IOErrSnafu, Identity, InvalidConfigSnafu, Password, Salt,
UnsupportedPasswordTypeSnafu, UserInfo, UserNotFoundSnafu, UserPasswordMismatchSnafu,
UserProvider,
};
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";
impl TryFrom<&str> for StaticUserProvider {
type Error = Error;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let (mode, content) = value.split_once(':').context(InvalidConfigSnafu {
value: value.to_string(),
msg: "StaticUserProviderOption must be in format `<option>:<value>`",
})?;
return match mode {
"file" => {
// check valid path
let path = Path::new(content);
ensure!(path.exists() && path.is_file(), InvalidConfigSnafu {
value: content.to_string(),
msg: "StaticUserProviderOption file must be a valid file path",
});
let file = File::open(path).context(IOErrSnafu)?;
let credential = io::BufReader::new(file)
.lines()
.filter_map(|line| line.ok())
.filter_map(|line| {
if let Some((k, v)) = line.split_once('=') {
Some((k.to_string(), v.as_bytes().to_vec()))
} else {
None
}
})
.collect::<HashMap<String, Vec<u8>>>();
ensure!(!credential.is_empty(), InvalidConfigSnafu {
value: content.to_string(),
msg: "StaticUserProviderOption file must contains at least one valid credential",
});
Ok(StaticUserProvider { users: credential, })
}
"cmd" => content
.split(',')
.map(|kv| {
let (k, v) = kv.split_once('=').context(InvalidConfigSnafu {
value: kv.to_string(),
msg: "StaticUserProviderOption cmd values must be in format `user=pwd[,user=pwd]`",
})?;
Ok((k.to_string(), v.as_bytes().to_vec()))
})
.collect::<Result<HashMap<String, Vec<u8>>, Error>>()
.map(|users| StaticUserProvider { users }),
_ => InvalidConfigSnafu {
value: mode.to_string(),
msg: "StaticUserProviderOption must be in format `file:<path>` or `cmd:<values>`",
}
.fail(),
};
}
}
pub struct StaticUserProvider {
users: HashMap<String, Vec<u8>>,
}
#[async_trait]
impl UserProvider for StaticUserProvider {
fn name(&self) -> &str {
STATIC_USER_PROVIDER
}
async fn auth(
&self,
input_id: Identity<'_>,
input_pwd: Password<'_>,
) -> Result<UserInfo, Error> {
match input_id {
Identity::UserId(username, _) => {
let save_pwd = self.users.get(username).context(UserNotFoundSnafu {
username: username.to_string(),
})?;
match input_pwd {
Password::PlainText(pwd) => {
return if save_pwd == pwd.as_bytes() {
Ok(UserInfo {
username: username.to_string(),
})
} else {
UserPasswordMismatchSnafu {
username: username.to_string(),
}
.fail()
}
}
Password::MysqlNativePassword(auth_data, salt) => {
auth_mysql(auth_data, salt, username.to_string(), save_pwd)
}
Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu {
password_type: "pg_md5",
}
.fail(),
}
}
}
}
}
fn auth_mysql(
auth_data: HashedPassword,
salt: Salt,
username: String,
save_pwd: &[u8],
) -> Result<UserInfo, Error> {
// ref: https://github.com/mysql/mysql-server/blob/a246bad76b9271cb4333634e954040a970222e0a/sql/auth/password.cc#L62
let hash_stage_2 = double_sha1(save_pwd);
let tmp = sha1_two(salt, &hash_stage_2);
// xor auth_data and tmp
let mut xor_result = [0u8; 20];
for i in 0..20 {
xor_result[i] = auth_data[i] ^ tmp[i];
}
let candidate_stage_2 = sha1_one(&xor_result);
if candidate_stage_2 == hash_stage_2 {
Ok(UserInfo { username })
} else {
UserPasswordMismatchSnafu { username }.fail()
}
}
fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
let mut hasher = Sha1::new();
hasher.update(input_1);
hasher.update(input_2);
hasher.finalize().to_vec()
}
fn sha1_one(data: &[u8]) -> Vec<u8> {
let mut hasher = Sha1::new();
hasher.update(data);
hasher.finalize().to_vec()
}
fn double_sha1(data: &[u8]) -> Vec<u8> {
sha1_one(&sha1_one(data))
}
#[cfg(test)]
pub mod test {
use std::fs::File;
use std::io::{LineWriter, Write};
use tempdir::TempDir;
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
use crate::auth::{Identity, Password, UserProvider};
#[test]
fn test_sha() {
let sha_1_answer: Vec<u8> = vec![
124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
27,
];
let sha_1 = sha1_one("123456".as_bytes());
assert_eq!(sha_1, sha_1_answer);
let double_sha1_answer: Vec<u8> = vec![
107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
42, 217,
];
let double_sha1 = double_sha1("123456".as_bytes());
assert_eq!(double_sha1, double_sha1_answer);
let sha1_2_answer: Vec<u8> = vec![
132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
37, 204,
];
let sha1_2 = sha1_two("123456".as_bytes(), "654321".as_bytes());
assert_eq!(sha1_2, sha1_2_answer);
}
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
let re = provider
.auth(
Identity::UserId(username, None),
Password::PlainText(password),
)
.await;
assert!(re.is_ok());
}
#[tokio::test]
async fn test_inline_provider() {
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
test_auth(&provider, "root", "123456").await;
test_auth(&provider, "admin", "654321").await;
}
#[tokio::test]
async fn test_file_provider() {
let dir = TempDir::new("test_file_provider").unwrap();
let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());
{
// write a tmp file
let file = File::create(&file_path);
assert!(file.is_ok());
let file = file.unwrap();
let mut lw = LineWriter::new(file);
assert!(lw
.write_all(
b"root=123456
admin=654321",
)
.is_ok());
assert!(lw.flush().is_ok());
}
let param = format!("file:{}", file_path);
let provider = StaticUserProvider::try_from(param.as_str()).unwrap();
test_auth(&provider, "root", "123456").await;
test_auth(&provider, "admin", "654321").await;
}
}

View File

@@ -107,16 +107,14 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
pub struct PgAuthStartupHandler {
verifier: PgPwdVerifier,
param_provider: GreptimeDBStartupParameters,
with_pwd: bool,
force_tls: bool,
}
impl PgAuthStartupHandler {
pub fn new(with_pwd: bool, user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
pub fn new(user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
PgAuthStartupHandler {
verifier: PgPwdVerifier { user_provider },
param_provider: GreptimeDBStartupParameters::new(),
with_pwd,
force_tls,
}
}
@@ -151,7 +149,7 @@ impl StartupHandler for PgAuthStartupHandler {
return Ok(());
}
auth::save_startup_parameters_to_metadata(client, startup);
if self.with_pwd {
if self.verifier.user_provider.is_some() {
client.set_state(PgWireConnectionState::AuthenticationInProgress);
client
.send(PgWireBackendMessage::Authentication(

View File

@@ -43,14 +43,12 @@ impl PostgresServer {
/// Creates a new Postgres server with provided query_handler and async runtime
pub fn new(
query_handler: SqlQueryHandlerRef,
check_pwd: bool,
tls: Arc<TlsOption>,
io_runtime: Arc<Runtime>,
user_provider: Option<UserProviderRef>,
) -> PostgresServer {
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
let startup_handler = Arc::new(PgAuthStartupHandler::new(
check_pwd,
user_provider,
tls.should_force_tls(),
));

View File

@@ -16,9 +16,10 @@ use std::sync::Arc;
use api::v1::InsertExpr;
use async_trait::async_trait;
use axum::Router;
use axum::{http, Router};
use axum_test_helper::TestClient;
use common_query::Output;
use servers::auth::user_provider::StaticUserProvider;
use servers::error::Result;
use servers::http::{HttpOptions, HttpServer};
use servers::influxdb::InfluxdbRequest;
@@ -53,6 +54,9 @@ impl SqlQueryHandler for DummyInstance {
fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
let instance = Arc::new(DummyInstance { tx });
let mut server = HttpServer::new(instance.clone(), HttpOptions::default());
let up = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
server.set_user_provider(Arc::new(up));
server.set_influxdb_handler(instance);
server.make_app()
}
@@ -68,6 +72,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write")
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 204);
@@ -76,6 +84,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write?db=influxdb")
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 204);
@@ -85,6 +97,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write")
.body("monitor, host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 400);

View File

@@ -23,6 +23,7 @@ use mysql_async::prelude::*;
use mysql_async::SslOpts;
use rand::rngs::StdRng;
use rand::Rng;
use servers::auth::user_provider::StaticUserProvider;
use servers::error::Result;
use servers::mysql::server::MysqlServer;
use servers::server::Server;
@@ -42,11 +43,13 @@ fn create_mysql_server(table: MemTable, tls: Arc<TlsOption>) -> Result<Box<dyn S
.unwrap(),
);
let provider = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
Ok(MysqlServer::create_server(
query_handler,
io_runtime,
tls,
None,
Some(Arc::new(provider)),
))
}
@@ -85,10 +88,10 @@ async fn test_shutdown_mysql_server() -> Result<()> {
let server_port = server_addr.port();
let mut join_handles = vec![];
for index in 0..2 {
for _ in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port, index == 1, false).await {
match create_connection(server_port, false).await {
Ok(mut connection) => {
let result: u32 = connection
.query_first("SELECT uint32s FROM numbers LIMIT 1")
@@ -125,7 +128,7 @@ async fn test_query_all_datatypes() -> Result<()> {
let server_tls = Arc::new(TlsOption::default());
let client_tls = false;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -138,7 +141,7 @@ async fn test_server_prefer_secure_client_plain() -> Result<()> {
});
let client_tls = false;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -151,7 +154,7 @@ async fn test_server_prefer_secure_client_secure() -> Result<()> {
});
let client_tls = true;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -164,7 +167,7 @@ async fn test_server_require_secure_client_secure() -> Result<()> {
});
let client_tls = true;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -194,16 +197,12 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let r = create_connection(server_addr.port(), client_tls, false).await;
let r = create_connection(server_addr.port(), client_tls).await;
assert!(r.is_err());
Ok(())
}
async fn do_test_query_all_datatypes(
server_tls: Arc<TlsOption>,
with_pwd: bool,
client_tls: bool,
) -> Result<()> {
async fn do_test_query_all_datatypes(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
common_telemetry::init_default_ut_logging();
let TestingData {
column_schemas,
@@ -220,7 +219,7 @@ async fn do_test_query_all_datatypes(
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection(server_addr.port(), client_tls, with_pwd)
let mut connection = create_connection(server_addr.port(), client_tls)
.await
.unwrap();
@@ -258,13 +257,11 @@ async fn test_query_concurrently() -> Result<()> {
let threads = 4;
let expect_executed_queries_per_worker = 1000;
let mut join_handles = vec![];
for index in 0..threads {
for _ in 0..threads {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut connection = create_connection(server_port, index % 2 == 0, false)
.await
.unwrap();
let mut connection = create_connection(server_port, false).await.unwrap();
for _ in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
let result: u32 = connection
@@ -279,9 +276,7 @@ async fn test_query_concurrently() -> Result<()> {
let should_recreate_conn = expected == 1;
if should_recreate_conn {
connection = create_connection(server_port, index % 2 == 0, false)
.await
.unwrap();
connection = create_connection(server_port, false).await.unwrap();
}
}
expect_executed_queries_per_worker
@@ -295,16 +290,14 @@ async fn test_query_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(
port: u16,
with_pwd: bool,
ssl: bool,
) -> mysql_async::Result<mysql_async::Conn> {
async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result<mysql_async::Conn> {
let mut opts = mysql_async::OptsBuilder::default()
.ip_or_hostname("127.0.0.1")
.tcp_port(port)
.prefer_socket(false)
.wait_timeout(Some(1000));
.wait_timeout(Some(1000))
.user(Some("greptime".to_string()))
.pass(Some("greptime".to_string()));
if ssl {
let ssl_opts = SslOpts::default()
@@ -313,9 +306,5 @@ async fn create_connection(
opts = opts.ssl_opts(ssl_opts)
}
if with_pwd {
opts = opts.pass(Some("default_pwd".to_string()));
}
mysql_async::Conn::new(opts).await
}

View File

@@ -22,6 +22,8 @@ use rand::rngs::StdRng;
use rand::Rng;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, Error, ServerName};
use servers::auth::user_provider::StaticUserProvider;
use servers::auth::UserProviderRef;
use servers::error::Result;
use servers::postgres::PostgresServer;
use servers::server::Server;
@@ -44,12 +46,19 @@ fn create_postgres_server(
.build()
.unwrap(),
);
let user_provider: Option<UserProviderRef> = if check_pwd {
Some(Arc::new(
StaticUserProvider::try_from("cmd:test_user=test_pwd").unwrap(),
))
} else {
None
};
Ok(Box::new(PostgresServer::new(
query_handler,
check_pwd,
tls,
io_runtime,
None,
user_provider,
)))
}