feat: impl static_user_provider (#739)

* feat: add MemUserProvider and impl auth

* feat: impl user_provider option in fe and standalone mode

* chore: add file impl for mem provider

* chore: remove mem opts

* chore: minor change

* chore: refac pg server to use user_provider as indicator for using pwd auth

* chore: fix test

* chore: extract common code

* chore: add unit test

* chore: rebase develop

* chore: add user provider to http server

* chore: minor rename

* chore: change to ref when convert to anymap

* chore: fix according to clippy

* chore: remove clone on startcommand

* chore: fix cr issue

* chore: update tempdir use

* chore: change TryFrom to normal func while parsing anymap

* chore: minor change

* chore: remove to_lowercase
This commit is contained in:
shuiyisong
2022-12-14 16:38:29 +08:00
committed by GitHub
parent 756c068166
commit fda9e80cbf
17 changed files with 482 additions and 73 deletions

View File

@@ -22,6 +22,7 @@ common-runtime = { path = "../common/runtime" }
common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
digest = "0.10"
futures = "0.3"
hex = { version = "0.4" }
http-body = "0.4"
@@ -43,6 +44,7 @@ schemars = "0.8"
serde = "1.0"
serde_json = "1.0"
session = { path = "../session" }
sha1 = "0.10"
snafu = { version = "0.7", features = ["backtraces"] }
snap = "1"
strum = { version = "0.24", features = ["derive"] }
@@ -67,6 +69,7 @@ rand = "0.8"
script = { path = "../script", features = ["python"] }
serde_json = "1.0"
table = { path = "../table" }
tempdir = "0.3"
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.9"
tokio-test = "0.4"

View File

@@ -12,13 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod user_provider;
pub const DEFAULT_USERNAME: &str = "greptime";
use std::sync::Arc;
use common_error::prelude::ErrorExt;
use common_error::status_code::StatusCode;
use snafu::{Backtrace, ErrorCompat, Snafu};
use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
use crate::auth::user_provider::StaticUserProvider;
#[async_trait::async_trait]
pub trait UserProvider: Send + Sync {
@@ -73,11 +77,40 @@ impl UserInfo {
}
}
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef, Error> {
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
value: opt.to_string(),
msg: "UserProviderOption must be in format `<option>:<value>`",
})?;
match name {
user_provider::STATIC_USER_PROVIDER => {
let provider =
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
Ok(provider)
}
_ => InvalidConfigSnafu {
value: name.to_string(),
msg: "Invalid UserProviderOption",
}
.fail(),
}
}
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("User not found"))]
UserNotFound { backtrace: Backtrace },
#[snafu(display("Invalid config value: {}, {}", value, msg))]
InvalidConfig {
value: String,
msg: String,
backtrace: Backtrace,
},
#[snafu(display("Encounter IO error, source: {}", source))]
IOErr { source: std::io::Error },
#[snafu(display("User not found, username: {}", username))]
UserNotFound { username: String },
#[snafu(display("Unsupported password type: {}", password_type))]
UnsupportedPasswordType {
@@ -85,20 +118,23 @@ pub enum Error {
backtrace: Backtrace,
},
#[snafu(display("Username and password does not match"))]
UserPasswordMismatch { backtrace: Backtrace },
#[snafu(display("Username and password does not match, username: {}", username))]
UserPasswordMismatch { username: String },
}
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::InvalidConfig { .. } => StatusCode::InvalidArguments,
Error::IOErr { .. } => StatusCode::Internal,
Error::UserNotFound { .. } => StatusCode::UserNotFound,
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
}
}
fn backtrace_opt(&self) -> Option<&common_error::snafu::Backtrace> {
fn backtrace_opt(&self) -> Option<&Backtrace> {
ErrorCompat::backtrace(self)
}
@@ -133,10 +169,16 @@ pub mod test {
username: "greptime".to_string(),
});
} else {
return super::UserPasswordMismatchSnafu {}.fail();
return super::UserPasswordMismatchSnafu {
username: username.to_string(),
}
.fail();
}
} else {
return super::UserNotFoundSnafu {}.fail();
return super::UserNotFoundSnafu {
username: username.to_string(),
}
.fail();
}
}
_ => super::UnsupportedPasswordTypeSnafu {

View File

@@ -0,0 +1,253 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::fs::File;
use std::io;
use std::io::BufRead;
use std::path::Path;
use async_trait::async_trait;
use digest;
use digest::Digest;
use sha1::Sha1;
use snafu::{ensure, OptionExt, ResultExt};
use crate::auth::{
Error, HashedPassword, IOErrSnafu, Identity, InvalidConfigSnafu, Password, Salt,
UnsupportedPasswordTypeSnafu, UserInfo, UserNotFoundSnafu, UserPasswordMismatchSnafu,
UserProvider,
};
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";
impl TryFrom<&str> for StaticUserProvider {
type Error = Error;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let (mode, content) = value.split_once(':').context(InvalidConfigSnafu {
value: value.to_string(),
msg: "StaticUserProviderOption must be in format `<option>:<value>`",
})?;
return match mode {
"file" => {
// check valid path
let path = Path::new(content);
ensure!(path.exists() && path.is_file(), InvalidConfigSnafu {
value: content.to_string(),
msg: "StaticUserProviderOption file must be a valid file path",
});
let file = File::open(path).context(IOErrSnafu)?;
let credential = io::BufReader::new(file)
.lines()
.filter_map(|line| line.ok())
.filter_map(|line| {
if let Some((k, v)) = line.split_once('=') {
Some((k.to_string(), v.as_bytes().to_vec()))
} else {
None
}
})
.collect::<HashMap<String, Vec<u8>>>();
ensure!(!credential.is_empty(), InvalidConfigSnafu {
value: content.to_string(),
msg: "StaticUserProviderOption file must contains at least one valid credential",
});
Ok(StaticUserProvider { users: credential, })
}
"cmd" => content
.split(',')
.map(|kv| {
let (k, v) = kv.split_once('=').context(InvalidConfigSnafu {
value: kv.to_string(),
msg: "StaticUserProviderOption cmd values must be in format `user=pwd[,user=pwd]`",
})?;
Ok((k.to_string(), v.as_bytes().to_vec()))
})
.collect::<Result<HashMap<String, Vec<u8>>, Error>>()
.map(|users| StaticUserProvider { users }),
_ => InvalidConfigSnafu {
value: mode.to_string(),
msg: "StaticUserProviderOption must be in format `file:<path>` or `cmd:<values>`",
}
.fail(),
};
}
}
pub struct StaticUserProvider {
users: HashMap<String, Vec<u8>>,
}
#[async_trait]
impl UserProvider for StaticUserProvider {
fn name(&self) -> &str {
STATIC_USER_PROVIDER
}
async fn auth(
&self,
input_id: Identity<'_>,
input_pwd: Password<'_>,
) -> Result<UserInfo, Error> {
match input_id {
Identity::UserId(username, _) => {
let save_pwd = self.users.get(username).context(UserNotFoundSnafu {
username: username.to_string(),
})?;
match input_pwd {
Password::PlainText(pwd) => {
return if save_pwd == pwd.as_bytes() {
Ok(UserInfo {
username: username.to_string(),
})
} else {
UserPasswordMismatchSnafu {
username: username.to_string(),
}
.fail()
}
}
Password::MysqlNativePassword(auth_data, salt) => {
auth_mysql(auth_data, salt, username.to_string(), save_pwd)
}
Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu {
password_type: "pg_md5",
}
.fail(),
}
}
}
}
}
fn auth_mysql(
auth_data: HashedPassword,
salt: Salt,
username: String,
save_pwd: &[u8],
) -> Result<UserInfo, Error> {
// ref: https://github.com/mysql/mysql-server/blob/a246bad76b9271cb4333634e954040a970222e0a/sql/auth/password.cc#L62
let hash_stage_2 = double_sha1(save_pwd);
let tmp = sha1_two(salt, &hash_stage_2);
// xor auth_data and tmp
let mut xor_result = [0u8; 20];
for i in 0..20 {
xor_result[i] = auth_data[i] ^ tmp[i];
}
let candidate_stage_2 = sha1_one(&xor_result);
if candidate_stage_2 == hash_stage_2 {
Ok(UserInfo { username })
} else {
UserPasswordMismatchSnafu { username }.fail()
}
}
fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
let mut hasher = Sha1::new();
hasher.update(input_1);
hasher.update(input_2);
hasher.finalize().to_vec()
}
fn sha1_one(data: &[u8]) -> Vec<u8> {
let mut hasher = Sha1::new();
hasher.update(data);
hasher.finalize().to_vec()
}
fn double_sha1(data: &[u8]) -> Vec<u8> {
sha1_one(&sha1_one(data))
}
#[cfg(test)]
pub mod test {
use std::fs::File;
use std::io::{LineWriter, Write};
use tempdir::TempDir;
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
use crate::auth::{Identity, Password, UserProvider};
#[test]
fn test_sha() {
let sha_1_answer: Vec<u8> = vec![
124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
27,
];
let sha_1 = sha1_one("123456".as_bytes());
assert_eq!(sha_1, sha_1_answer);
let double_sha1_answer: Vec<u8> = vec![
107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
42, 217,
];
let double_sha1 = double_sha1("123456".as_bytes());
assert_eq!(double_sha1, double_sha1_answer);
let sha1_2_answer: Vec<u8> = vec![
132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
37, 204,
];
let sha1_2 = sha1_two("123456".as_bytes(), "654321".as_bytes());
assert_eq!(sha1_2, sha1_2_answer);
}
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
let re = provider
.auth(
Identity::UserId(username, None),
Password::PlainText(password),
)
.await;
assert!(re.is_ok());
}
#[tokio::test]
async fn test_inline_provider() {
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
test_auth(&provider, "root", "123456").await;
test_auth(&provider, "admin", "654321").await;
}
#[tokio::test]
async fn test_file_provider() {
let dir = TempDir::new("test_file_provider").unwrap();
let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());
{
// write a tmp file
let file = File::create(&file_path);
assert!(file.is_ok());
let file = file.unwrap();
let mut lw = LineWriter::new(file);
assert!(lw
.write_all(
b"root=123456
admin=654321",
)
.is_ok());
assert!(lw.flush().is_ok());
}
let param = format!("file:{}", file_path);
let provider = StaticUserProvider::try_from(param.as_str()).unwrap();
test_auth(&provider, "root", "123456").await;
test_auth(&provider, "admin", "654321").await;
}
}

View File

@@ -107,16 +107,14 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
pub struct PgAuthStartupHandler {
verifier: PgPwdVerifier,
param_provider: GreptimeDBStartupParameters,
with_pwd: bool,
force_tls: bool,
}
impl PgAuthStartupHandler {
pub fn new(with_pwd: bool, user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
pub fn new(user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
PgAuthStartupHandler {
verifier: PgPwdVerifier { user_provider },
param_provider: GreptimeDBStartupParameters::new(),
with_pwd,
force_tls,
}
}
@@ -151,7 +149,7 @@ impl StartupHandler for PgAuthStartupHandler {
return Ok(());
}
auth::save_startup_parameters_to_metadata(client, startup);
if self.with_pwd {
if self.verifier.user_provider.is_some() {
client.set_state(PgWireConnectionState::AuthenticationInProgress);
client
.send(PgWireBackendMessage::Authentication(

View File

@@ -43,14 +43,12 @@ impl PostgresServer {
/// Creates a new Postgres server with provided query_handler and async runtime
pub fn new(
query_handler: SqlQueryHandlerRef,
check_pwd: bool,
tls: Arc<TlsOption>,
io_runtime: Arc<Runtime>,
user_provider: Option<UserProviderRef>,
) -> PostgresServer {
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
let startup_handler = Arc::new(PgAuthStartupHandler::new(
check_pwd,
user_provider,
tls.should_force_tls(),
));

View File

@@ -16,9 +16,10 @@ use std::sync::Arc;
use api::v1::InsertExpr;
use async_trait::async_trait;
use axum::Router;
use axum::{http, Router};
use axum_test_helper::TestClient;
use common_query::Output;
use servers::auth::user_provider::StaticUserProvider;
use servers::error::Result;
use servers::http::{HttpOptions, HttpServer};
use servers::influxdb::InfluxdbRequest;
@@ -53,6 +54,9 @@ impl SqlQueryHandler for DummyInstance {
fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
let instance = Arc::new(DummyInstance { tx });
let mut server = HttpServer::new(instance.clone(), HttpOptions::default());
let up = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
server.set_user_provider(Arc::new(up));
server.set_influxdb_handler(instance);
server.make_app()
}
@@ -68,6 +72,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write")
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 204);
@@ -76,6 +84,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write?db=influxdb")
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 204);
@@ -85,6 +97,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write")
.body("monitor, host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 400);

View File

@@ -23,6 +23,7 @@ use mysql_async::prelude::*;
use mysql_async::SslOpts;
use rand::rngs::StdRng;
use rand::Rng;
use servers::auth::user_provider::StaticUserProvider;
use servers::error::Result;
use servers::mysql::server::MysqlServer;
use servers::server::Server;
@@ -42,11 +43,13 @@ fn create_mysql_server(table: MemTable, tls: Arc<TlsOption>) -> Result<Box<dyn S
.unwrap(),
);
let provider = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
Ok(MysqlServer::create_server(
query_handler,
io_runtime,
tls,
None,
Some(Arc::new(provider)),
))
}
@@ -85,10 +88,10 @@ async fn test_shutdown_mysql_server() -> Result<()> {
let server_port = server_addr.port();
let mut join_handles = vec![];
for index in 0..2 {
for _ in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port, index == 1, false).await {
match create_connection(server_port, false).await {
Ok(mut connection) => {
let result: u32 = connection
.query_first("SELECT uint32s FROM numbers LIMIT 1")
@@ -125,7 +128,7 @@ async fn test_query_all_datatypes() -> Result<()> {
let server_tls = Arc::new(TlsOption::default());
let client_tls = false;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -138,7 +141,7 @@ async fn test_server_prefer_secure_client_plain() -> Result<()> {
});
let client_tls = false;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -151,7 +154,7 @@ async fn test_server_prefer_secure_client_secure() -> Result<()> {
});
let client_tls = true;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -164,7 +167,7 @@ async fn test_server_require_secure_client_secure() -> Result<()> {
});
let client_tls = true;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -194,16 +197,12 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let r = create_connection(server_addr.port(), client_tls, false).await;
let r = create_connection(server_addr.port(), client_tls).await;
assert!(r.is_err());
Ok(())
}
async fn do_test_query_all_datatypes(
server_tls: Arc<TlsOption>,
with_pwd: bool,
client_tls: bool,
) -> Result<()> {
async fn do_test_query_all_datatypes(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
common_telemetry::init_default_ut_logging();
let TestingData {
column_schemas,
@@ -220,7 +219,7 @@ async fn do_test_query_all_datatypes(
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection(server_addr.port(), client_tls, with_pwd)
let mut connection = create_connection(server_addr.port(), client_tls)
.await
.unwrap();
@@ -258,13 +257,11 @@ async fn test_query_concurrently() -> Result<()> {
let threads = 4;
let expect_executed_queries_per_worker = 1000;
let mut join_handles = vec![];
for index in 0..threads {
for _ in 0..threads {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut connection = create_connection(server_port, index % 2 == 0, false)
.await
.unwrap();
let mut connection = create_connection(server_port, false).await.unwrap();
for _ in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
let result: u32 = connection
@@ -279,9 +276,7 @@ async fn test_query_concurrently() -> Result<()> {
let should_recreate_conn = expected == 1;
if should_recreate_conn {
connection = create_connection(server_port, index % 2 == 0, false)
.await
.unwrap();
connection = create_connection(server_port, false).await.unwrap();
}
}
expect_executed_queries_per_worker
@@ -295,16 +290,14 @@ async fn test_query_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(
port: u16,
with_pwd: bool,
ssl: bool,
) -> mysql_async::Result<mysql_async::Conn> {
async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result<mysql_async::Conn> {
let mut opts = mysql_async::OptsBuilder::default()
.ip_or_hostname("127.0.0.1")
.tcp_port(port)
.prefer_socket(false)
.wait_timeout(Some(1000));
.wait_timeout(Some(1000))
.user(Some("greptime".to_string()))
.pass(Some("greptime".to_string()));
if ssl {
let ssl_opts = SslOpts::default()
@@ -313,9 +306,5 @@ async fn create_connection(
opts = opts.ssl_opts(ssl_opts)
}
if with_pwd {
opts = opts.pass(Some("default_pwd".to_string()));
}
mysql_async::Conn::new(opts).await
}

View File

@@ -22,6 +22,8 @@ use rand::rngs::StdRng;
use rand::Rng;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, Error, ServerName};
use servers::auth::user_provider::StaticUserProvider;
use servers::auth::UserProviderRef;
use servers::error::Result;
use servers::postgres::PostgresServer;
use servers::server::Server;
@@ -44,12 +46,19 @@ fn create_postgres_server(
.build()
.unwrap(),
);
let user_provider: Option<UserProviderRef> = if check_pwd {
Some(Arc::new(
StaticUserProvider::try_from("cmd:test_user=test_pwd").unwrap(),
))
} else {
None
};
Ok(Box::new(PostgresServer::new(
query_handler,
check_pwd,
tls,
io_runtime,
None,
user_provider,
)))
}