mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-25 09:20:40 +00:00
feat: impl static_user_provider (#739)
* feat: add MemUserProvider and impl auth * feat: impl user_provider option in fe and standalone mode * chore: add file impl for mem provider * chore: remove mem opts * chore: minor change * chore: refac pg server to use user_provider as indicator for using pwd auth * chore: fix test * chore: extract common code * chore: add unit test * chore: rebase develop * chore: add user provider to http server * chore: minor rename * chore: change to ref when convert to anymap * chore: fix according to clippy * chore: remove clone on startcommand * chore: fix cr issue * chore: update tempdir use * chore: change TryFrom to normal func while parsing anymap * chore: minor change * chore: remove to_lowercase
This commit is contained in:
@@ -22,6 +22,7 @@ common-runtime = { path = "../common/runtime" }
|
||||
common-telemetry = { path = "../common/telemetry" }
|
||||
common-time = { path = "../common/time" }
|
||||
datatypes = { path = "../datatypes" }
|
||||
digest = "0.10"
|
||||
futures = "0.3"
|
||||
hex = { version = "0.4" }
|
||||
http-body = "0.4"
|
||||
@@ -43,6 +44,7 @@ schemars = "0.8"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
session = { path = "../session" }
|
||||
sha1 = "0.10"
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
snap = "1"
|
||||
strum = { version = "0.24", features = ["derive"] }
|
||||
@@ -67,6 +69,7 @@ rand = "0.8"
|
||||
script = { path = "../script", features = ["python"] }
|
||||
serde_json = "1.0"
|
||||
table = { path = "../table" }
|
||||
tempdir = "0.3"
|
||||
tokio-postgres = "0.7"
|
||||
tokio-postgres-rustls = "0.9"
|
||||
tokio-test = "0.4"
|
||||
|
||||
@@ -12,13 +12,17 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod user_provider;
|
||||
|
||||
pub const DEFAULT_USERNAME: &str = "greptime";
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::prelude::ErrorExt;
|
||||
use common_error::status_code::StatusCode;
|
||||
use snafu::{Backtrace, ErrorCompat, Snafu};
|
||||
use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
|
||||
|
||||
use crate::auth::user_provider::StaticUserProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait UserProvider: Send + Sync {
|
||||
@@ -73,11 +77,40 @@ impl UserInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef, Error> {
|
||||
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
|
||||
value: opt.to_string(),
|
||||
msg: "UserProviderOption must be in format `<option>:<value>`",
|
||||
})?;
|
||||
match name {
|
||||
user_provider::STATIC_USER_PROVIDER => {
|
||||
let provider =
|
||||
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
|
||||
Ok(provider)
|
||||
}
|
||||
_ => InvalidConfigSnafu {
|
||||
value: name.to_string(),
|
||||
msg: "Invalid UserProviderOption",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
pub enum Error {
|
||||
#[snafu(display("User not found"))]
|
||||
UserNotFound { backtrace: Backtrace },
|
||||
#[snafu(display("Invalid config value: {}, {}", value, msg))]
|
||||
InvalidConfig {
|
||||
value: String,
|
||||
msg: String,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Encounter IO error, source: {}", source))]
|
||||
IOErr { source: std::io::Error },
|
||||
|
||||
#[snafu(display("User not found, username: {}", username))]
|
||||
UserNotFound { username: String },
|
||||
|
||||
#[snafu(display("Unsupported password type: {}", password_type))]
|
||||
UnsupportedPasswordType {
|
||||
@@ -85,20 +118,23 @@ pub enum Error {
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Username and password does not match"))]
|
||||
UserPasswordMismatch { backtrace: Backtrace },
|
||||
#[snafu(display("Username and password does not match, username: {}", username))]
|
||||
UserPasswordMismatch { username: String },
|
||||
}
|
||||
|
||||
impl ErrorExt for Error {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Error::InvalidConfig { .. } => StatusCode::InvalidArguments,
|
||||
Error::IOErr { .. } => StatusCode::Internal,
|
||||
|
||||
Error::UserNotFound { .. } => StatusCode::UserNotFound,
|
||||
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
|
||||
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
|
||||
}
|
||||
}
|
||||
|
||||
fn backtrace_opt(&self) -> Option<&common_error::snafu::Backtrace> {
|
||||
fn backtrace_opt(&self) -> Option<&Backtrace> {
|
||||
ErrorCompat::backtrace(self)
|
||||
}
|
||||
|
||||
@@ -133,10 +169,16 @@ pub mod test {
|
||||
username: "greptime".to_string(),
|
||||
});
|
||||
} else {
|
||||
return super::UserPasswordMismatchSnafu {}.fail();
|
||||
return super::UserPasswordMismatchSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
} else {
|
||||
return super::UserNotFoundSnafu {}.fail();
|
||||
return super::UserNotFoundSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
}
|
||||
_ => super::UnsupportedPasswordTypeSnafu {
|
||||
|
||||
253
src/servers/src/auth/user_provider.rs
Normal file
253
src/servers/src/auth/user_provider.rs
Normal file
@@ -0,0 +1,253 @@
|
||||
// Copyright 2022 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io;
|
||||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use digest;
|
||||
use digest::Digest;
|
||||
use sha1::Sha1;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
use crate::auth::{
|
||||
Error, HashedPassword, IOErrSnafu, Identity, InvalidConfigSnafu, Password, Salt,
|
||||
UnsupportedPasswordTypeSnafu, UserInfo, UserNotFoundSnafu, UserPasswordMismatchSnafu,
|
||||
UserProvider,
|
||||
};
|
||||
|
||||
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";
|
||||
|
||||
impl TryFrom<&str> for StaticUserProvider {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let (mode, content) = value.split_once(':').context(InvalidConfigSnafu {
|
||||
value: value.to_string(),
|
||||
msg: "StaticUserProviderOption must be in format `<option>:<value>`",
|
||||
})?;
|
||||
return match mode {
|
||||
"file" => {
|
||||
// check valid path
|
||||
let path = Path::new(content);
|
||||
ensure!(path.exists() && path.is_file(), InvalidConfigSnafu {
|
||||
value: content.to_string(),
|
||||
msg: "StaticUserProviderOption file must be a valid file path",
|
||||
});
|
||||
|
||||
let file = File::open(path).context(IOErrSnafu)?;
|
||||
let credential = io::BufReader::new(file)
|
||||
.lines()
|
||||
.filter_map(|line| line.ok())
|
||||
.filter_map(|line| {
|
||||
if let Some((k, v)) = line.split_once('=') {
|
||||
Some((k.to_string(), v.as_bytes().to_vec()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<HashMap<String, Vec<u8>>>();
|
||||
|
||||
ensure!(!credential.is_empty(), InvalidConfigSnafu {
|
||||
value: content.to_string(),
|
||||
msg: "StaticUserProviderOption file must contains at least one valid credential",
|
||||
});
|
||||
|
||||
Ok(StaticUserProvider { users: credential, })
|
||||
}
|
||||
"cmd" => content
|
||||
.split(',')
|
||||
.map(|kv| {
|
||||
let (k, v) = kv.split_once('=').context(InvalidConfigSnafu {
|
||||
value: kv.to_string(),
|
||||
msg: "StaticUserProviderOption cmd values must be in format `user=pwd[,user=pwd]`",
|
||||
})?;
|
||||
Ok((k.to_string(), v.as_bytes().to_vec()))
|
||||
})
|
||||
.collect::<Result<HashMap<String, Vec<u8>>, Error>>()
|
||||
.map(|users| StaticUserProvider { users }),
|
||||
_ => InvalidConfigSnafu {
|
||||
value: mode.to_string(),
|
||||
msg: "StaticUserProviderOption must be in format `file:<path>` or `cmd:<values>`",
|
||||
}
|
||||
.fail(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StaticUserProvider {
|
||||
users: HashMap<String, Vec<u8>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl UserProvider for StaticUserProvider {
|
||||
fn name(&self) -> &str {
|
||||
STATIC_USER_PROVIDER
|
||||
}
|
||||
|
||||
async fn auth(
|
||||
&self,
|
||||
input_id: Identity<'_>,
|
||||
input_pwd: Password<'_>,
|
||||
) -> Result<UserInfo, Error> {
|
||||
match input_id {
|
||||
Identity::UserId(username, _) => {
|
||||
let save_pwd = self.users.get(username).context(UserNotFoundSnafu {
|
||||
username: username.to_string(),
|
||||
})?;
|
||||
|
||||
match input_pwd {
|
||||
Password::PlainText(pwd) => {
|
||||
return if save_pwd == pwd.as_bytes() {
|
||||
Ok(UserInfo {
|
||||
username: username.to_string(),
|
||||
})
|
||||
} else {
|
||||
UserPasswordMismatchSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
Password::MysqlNativePassword(auth_data, salt) => {
|
||||
auth_mysql(auth_data, salt, username.to_string(), save_pwd)
|
||||
}
|
||||
Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu {
|
||||
password_type: "pg_md5",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_mysql(
|
||||
auth_data: HashedPassword,
|
||||
salt: Salt,
|
||||
username: String,
|
||||
save_pwd: &[u8],
|
||||
) -> Result<UserInfo, Error> {
|
||||
// ref: https://github.com/mysql/mysql-server/blob/a246bad76b9271cb4333634e954040a970222e0a/sql/auth/password.cc#L62
|
||||
let hash_stage_2 = double_sha1(save_pwd);
|
||||
let tmp = sha1_two(salt, &hash_stage_2);
|
||||
// xor auth_data and tmp
|
||||
let mut xor_result = [0u8; 20];
|
||||
for i in 0..20 {
|
||||
xor_result[i] = auth_data[i] ^ tmp[i];
|
||||
}
|
||||
let candidate_stage_2 = sha1_one(&xor_result);
|
||||
if candidate_stage_2 == hash_stage_2 {
|
||||
Ok(UserInfo { username })
|
||||
} else {
|
||||
UserPasswordMismatchSnafu { username }.fail()
|
||||
}
|
||||
}
|
||||
|
||||
fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(input_1);
|
||||
hasher.update(input_2);
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
fn sha1_one(data: &[u8]) -> Vec<u8> {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(data);
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
fn double_sha1(data: &[u8]) -> Vec<u8> {
|
||||
sha1_one(&sha1_one(data))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
use std::fs::File;
|
||||
use std::io::{LineWriter, Write};
|
||||
|
||||
use tempdir::TempDir;
|
||||
|
||||
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
|
||||
use crate::auth::{Identity, Password, UserProvider};
|
||||
|
||||
#[test]
|
||||
fn test_sha() {
|
||||
let sha_1_answer: Vec<u8> = vec![
|
||||
124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
|
||||
27,
|
||||
];
|
||||
let sha_1 = sha1_one("123456".as_bytes());
|
||||
assert_eq!(sha_1, sha_1_answer);
|
||||
|
||||
let double_sha1_answer: Vec<u8> = vec![
|
||||
107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
|
||||
42, 217,
|
||||
];
|
||||
let double_sha1 = double_sha1("123456".as_bytes());
|
||||
assert_eq!(double_sha1, double_sha1_answer);
|
||||
|
||||
let sha1_2_answer: Vec<u8> = vec![
|
||||
132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
|
||||
37, 204,
|
||||
];
|
||||
let sha1_2 = sha1_two("123456".as_bytes(), "654321".as_bytes());
|
||||
assert_eq!(sha1_2, sha1_2_answer);
|
||||
}
|
||||
|
||||
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
|
||||
let re = provider
|
||||
.auth(
|
||||
Identity::UserId(username, None),
|
||||
Password::PlainText(password),
|
||||
)
|
||||
.await;
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_inline_provider() {
|
||||
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
|
||||
test_auth(&provider, "root", "123456").await;
|
||||
test_auth(&provider, "admin", "654321").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_file_provider() {
|
||||
let dir = TempDir::new("test_file_provider").unwrap();
|
||||
let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());
|
||||
{
|
||||
// write a tmp file
|
||||
let file = File::create(&file_path);
|
||||
assert!(file.is_ok());
|
||||
let file = file.unwrap();
|
||||
let mut lw = LineWriter::new(file);
|
||||
assert!(lw
|
||||
.write_all(
|
||||
b"root=123456
|
||||
admin=654321",
|
||||
)
|
||||
.is_ok());
|
||||
assert!(lw.flush().is_ok());
|
||||
}
|
||||
|
||||
let param = format!("file:{}", file_path);
|
||||
let provider = StaticUserProvider::try_from(param.as_str()).unwrap();
|
||||
test_auth(&provider, "root", "123456").await;
|
||||
test_auth(&provider, "admin", "654321").await;
|
||||
}
|
||||
}
|
||||
@@ -107,16 +107,14 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
|
||||
pub struct PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier,
|
||||
param_provider: GreptimeDBStartupParameters,
|
||||
with_pwd: bool,
|
||||
force_tls: bool,
|
||||
}
|
||||
|
||||
impl PgAuthStartupHandler {
|
||||
pub fn new(with_pwd: bool, user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
|
||||
pub fn new(user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
|
||||
PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier { user_provider },
|
||||
param_provider: GreptimeDBStartupParameters::new(),
|
||||
with_pwd,
|
||||
force_tls,
|
||||
}
|
||||
}
|
||||
@@ -151,7 +149,7 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
return Ok(());
|
||||
}
|
||||
auth::save_startup_parameters_to_metadata(client, startup);
|
||||
if self.with_pwd {
|
||||
if self.verifier.user_provider.is_some() {
|
||||
client.set_state(PgWireConnectionState::AuthenticationInProgress);
|
||||
client
|
||||
.send(PgWireBackendMessage::Authentication(
|
||||
|
||||
@@ -43,14 +43,12 @@ impl PostgresServer {
|
||||
/// Creates a new Postgres server with provided query_handler and async runtime
|
||||
pub fn new(
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
check_pwd: bool,
|
||||
tls: Arc<TlsOption>,
|
||||
io_runtime: Arc<Runtime>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
) -> PostgresServer {
|
||||
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
|
||||
let startup_handler = Arc::new(PgAuthStartupHandler::new(
|
||||
check_pwd,
|
||||
user_provider,
|
||||
tls.should_force_tls(),
|
||||
));
|
||||
|
||||
@@ -16,9 +16,10 @@ use std::sync::Arc;
|
||||
|
||||
use api::v1::InsertExpr;
|
||||
use async_trait::async_trait;
|
||||
use axum::Router;
|
||||
use axum::{http, Router};
|
||||
use axum_test_helper::TestClient;
|
||||
use common_query::Output;
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::error::Result;
|
||||
use servers::http::{HttpOptions, HttpServer};
|
||||
use servers::influxdb::InfluxdbRequest;
|
||||
@@ -53,6 +54,9 @@ impl SqlQueryHandler for DummyInstance {
|
||||
fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
|
||||
let instance = Arc::new(DummyInstance { tx });
|
||||
let mut server = HttpServer::new(instance.clone(), HttpOptions::default());
|
||||
let up = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
|
||||
server.set_user_provider(Arc::new(up));
|
||||
|
||||
server.set_influxdb_handler(instance);
|
||||
server.make_app()
|
||||
}
|
||||
@@ -68,6 +72,10 @@ async fn test_influxdb_write() {
|
||||
let result = client
|
||||
.post("/v1/influxdb/write")
|
||||
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
|
||||
)
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(result.status(), 204);
|
||||
@@ -76,6 +84,10 @@ async fn test_influxdb_write() {
|
||||
let result = client
|
||||
.post("/v1/influxdb/write?db=influxdb")
|
||||
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
|
||||
)
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(result.status(), 204);
|
||||
@@ -85,6 +97,10 @@ async fn test_influxdb_write() {
|
||||
let result = client
|
||||
.post("/v1/influxdb/write")
|
||||
.body("monitor, host=host1 cpu=1.2 1664370459457010101")
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
|
||||
)
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(result.status(), 400);
|
||||
|
||||
@@ -23,6 +23,7 @@ use mysql_async::prelude::*;
|
||||
use mysql_async::SslOpts;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::error::Result;
|
||||
use servers::mysql::server::MysqlServer;
|
||||
use servers::server::Server;
|
||||
@@ -42,11 +43,13 @@ fn create_mysql_server(table: MemTable, tls: Arc<TlsOption>) -> Result<Box<dyn S
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let provider = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
|
||||
|
||||
Ok(MysqlServer::create_server(
|
||||
query_handler,
|
||||
io_runtime,
|
||||
tls,
|
||||
None,
|
||||
Some(Arc::new(provider)),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -85,10 +88,10 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
let server_port = server_addr.port();
|
||||
|
||||
let mut join_handles = vec![];
|
||||
for index in 0..2 {
|
||||
for _ in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port, index == 1, false).await {
|
||||
match create_connection(server_port, false).await {
|
||||
Ok(mut connection) => {
|
||||
let result: u32 = connection
|
||||
.query_first("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -125,7 +128,7 @@ async fn test_query_all_datatypes() -> Result<()> {
|
||||
let server_tls = Arc::new(TlsOption::default());
|
||||
let client_tls = false;
|
||||
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
do_test_query_all_datatypes(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -138,7 +141,7 @@ async fn test_server_prefer_secure_client_plain() -> Result<()> {
|
||||
});
|
||||
|
||||
let client_tls = false;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
do_test_query_all_datatypes(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -151,7 +154,7 @@ async fn test_server_prefer_secure_client_secure() -> Result<()> {
|
||||
});
|
||||
|
||||
let client_tls = true;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
do_test_query_all_datatypes(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -164,7 +167,7 @@ async fn test_server_require_secure_client_secure() -> Result<()> {
|
||||
});
|
||||
|
||||
let client_tls = true;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
do_test_query_all_datatypes(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -194,16 +197,12 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let r = create_connection(server_addr.port(), client_tls, false).await;
|
||||
let r = create_connection(server_addr.port(), client_tls).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_test_query_all_datatypes(
|
||||
server_tls: Arc<TlsOption>,
|
||||
with_pwd: bool,
|
||||
client_tls: bool,
|
||||
) -> Result<()> {
|
||||
async fn do_test_query_all_datatypes(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let TestingData {
|
||||
column_schemas,
|
||||
@@ -220,7 +219,7 @@ async fn do_test_query_all_datatypes(
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let mut connection = create_connection(server_addr.port(), client_tls, with_pwd)
|
||||
let mut connection = create_connection(server_addr.port(), client_tls)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -258,13 +257,11 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
let threads = 4;
|
||||
let expect_executed_queries_per_worker = 1000;
|
||||
let mut join_handles = vec![];
|
||||
for index in 0..threads {
|
||||
for _ in 0..threads {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut connection = create_connection(server_port, index % 2 == 0, false)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut connection = create_connection(server_port, false).await.unwrap();
|
||||
for _ in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
let result: u32 = connection
|
||||
@@ -279,9 +276,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
connection = create_connection(server_port, index % 2 == 0, false)
|
||||
.await
|
||||
.unwrap();
|
||||
connection = create_connection(server_port, false).await.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
@@ -295,16 +290,14 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(
|
||||
port: u16,
|
||||
with_pwd: bool,
|
||||
ssl: bool,
|
||||
) -> mysql_async::Result<mysql_async::Conn> {
|
||||
async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result<mysql_async::Conn> {
|
||||
let mut opts = mysql_async::OptsBuilder::default()
|
||||
.ip_or_hostname("127.0.0.1")
|
||||
.tcp_port(port)
|
||||
.prefer_socket(false)
|
||||
.wait_timeout(Some(1000));
|
||||
.wait_timeout(Some(1000))
|
||||
.user(Some("greptime".to_string()))
|
||||
.pass(Some("greptime".to_string()));
|
||||
|
||||
if ssl {
|
||||
let ssl_opts = SslOpts::default()
|
||||
@@ -313,9 +306,5 @@ async fn create_connection(
|
||||
opts = opts.ssl_opts(ssl_opts)
|
||||
}
|
||||
|
||||
if with_pwd {
|
||||
opts = opts.pass(Some("default_pwd".to_string()));
|
||||
}
|
||||
|
||||
mysql_async::Conn::new(opts).await
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::{Certificate, Error, ServerName};
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::error::Result;
|
||||
use servers::postgres::PostgresServer;
|
||||
use servers::server::Server;
|
||||
@@ -44,12 +46,19 @@ fn create_postgres_server(
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
let user_provider: Option<UserProviderRef> = if check_pwd {
|
||||
Some(Arc::new(
|
||||
StaticUserProvider::try_from("cmd:test_user=test_pwd").unwrap(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Box::new(PostgresServer::new(
|
||||
query_handler,
|
||||
check_pwd,
|
||||
tls,
|
||||
io_runtime,
|
||||
None,
|
||||
user_provider,
|
||||
)))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user