feat: add authorize to UserProvider trait (#879)

* feat: add SchemaValidator

* feat: add schema validator to mysql shim

* chore: pass schema validator to http auth layer

* feat: add schema validator to http

* feat: add schema validator to pg

* feat: add schema validator to pg

* feat: add schema validator test

* chore: remove println in test

* chore: use !matches

* refactor: refac authenticate and authorize in http auth

* refactor: refac authenticate and authorize in http auth

* chore: typo

* chore: minor change

* refactor: merge schema_validator into user_providier

* chore: fix license issue

* refactor: change http query param from database to db

* chore: fix cr issue
This commit is contained in:
shuiyisong
2023-01-18 12:42:08 +08:00
committed by GitHub
parent 49d83abc0c
commit 6960739b3d
25 changed files with 712 additions and 310 deletions

1
Cargo.lock generated
View File

@@ -6633,6 +6633,7 @@ dependencies = [
"script",
"serde",
"serde_json",
"serde_urlencoded",
"session",
"sha1",
"snafu",

View File

@@ -287,7 +287,7 @@ mod tests {
let provider = provider.unwrap();
let result = provider
.auth(Identity::UserId("test", None), Password::PlainText("test"))
.authenticate(Identity::UserId("test", None), Password::PlainText("test"))
.await;
assert!(result.is_ok());
}

View File

@@ -350,7 +350,7 @@ mod tests {
assert!(provider.is_some());
let provider = provider.unwrap();
let result = provider
.auth(Identity::UserId("test", None), Password::PlainText("test"))
.authenticate(Identity::UserId("test", None), Password::PlainText("test"))
.await;
assert!(result.is_ok());
}

View File

@@ -77,6 +77,8 @@ pub enum StatusCode {
AuthHeaderNotFound = 7003,
/// Invalid http authorization header
InvalidAuthHeader = 7004,
/// Illegal request to connect catalog-schema
AccessDenied = 7005,
// ====== End of auth related status code =====
}

View File

@@ -16,7 +16,6 @@ use std::sync::Arc;
use meta_client::MetaClientOpts;
use serde::{Deserialize, Serialize};
use servers::auth::UserProviderRef;
use servers::http::HttpOptions;
use servers::Mode;
use snafu::prelude::*;
@@ -92,8 +91,6 @@ impl<T: FrontendInstance> Frontend<T> {
let instance = Arc::new(instance);
// TODO(sunng87): merge this into instance
let provider = self.plugins.get::<UserProviderRef>().cloned();
Services::start(&self.opts, instance, provider).await
Services::start(&self.opts, instance, self.plugins.clone()).await
}
}

View File

@@ -34,6 +34,7 @@ use crate::frontend::FrontendOptions;
use crate::influxdb::InfluxdbOptions;
use crate::instance::FrontendInstance;
use crate::prometheus::PrometheusOptions;
use crate::Plugins;
pub(crate) struct Services;
@@ -41,12 +42,14 @@ impl Services {
pub(crate) async fn start<T>(
opts: &FrontendOptions,
instance: Arc<T>,
user_provider: Option<UserProviderRef>,
plugins: Arc<Plugins>,
) -> Result<()>
where
T: FrontendInstance,
{
info!("Starting frontend servers");
let user_provider = plugins.get::<UserProviderRef>().cloned();
let grpc_server_and_addr = if let Some(opts) = &opts.grpc_options {
let grpc_addr = parse_addr(&opts.addr)?;

View File

@@ -48,6 +48,7 @@ rustls-pemfile = "1.0"
schemars = "0.8"
serde.workspace = true
serde_json = "1.0"
serde_urlencoded = "0.7"
session = { path = "../session" }
sha1 = "0.10"
snafu = { version = "0.7", features = ["backtraces"] }

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod user_provider;
use std::sync::Arc;
use common_error::ext::BoxedError;
@@ -24,11 +22,19 @@ use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
use crate::auth::user_provider::StaticUserProvider;
pub mod user_provider;
#[async_trait::async_trait]
pub trait UserProvider: Send + Sync {
fn name(&self) -> &str;
async fn auth(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo>;
/// [`authenticate`] checks whether a user is valid and allowed to access the database.
async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo>;
/// [`authorize`] checks whether a connection request
/// from a certain user to a certain catalog/schema is legal.
/// This method should be called after [`authenticate`].
async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfo) -> Result<()>;
}
pub type UserProviderRef = Arc<dyn UserProvider>;
@@ -76,6 +82,12 @@ pub enum Error {
#[snafu(display("Invalid config value: {}, {}", value, msg))]
InvalidConfig { value: String, msg: String },
#[snafu(display("Illegal runtime param: {}", msg))]
IllegalParam { msg: String },
#[snafu(display("Internal state error: {}", msg))]
InternalState { msg: String },
#[snafu(display("IO error, source: {}", source))]
Io {
source: std::io::Error,
@@ -96,18 +108,33 @@ pub enum Error {
#[snafu(display("Username and password does not match, username: {}", username))]
UserPasswordMismatch { username: String },
#[snafu(display(
"User {} is not allowed to access catalog {} and schema {}",
username,
catalog,
schema
))]
AccessDenied {
catalog: String,
schema: String,
username: String,
},
}
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::InvalidConfig { .. } => StatusCode::InvalidArguments,
Error::IllegalParam { .. } => StatusCode::InvalidArguments,
Error::InternalState { .. } => StatusCode::Unexpected,
Error::Io { .. } => StatusCode::Internal,
Error::AuthBackend { .. } => StatusCode::Internal,
Error::UserNotFound { .. } => StatusCode::UserNotFound,
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
Error::AccessDenied { .. } => StatusCode::AccessDenied,
}
}
@@ -121,108 +148,3 @@ impl ErrorExt for Error {
}
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
pub mod test {
use super::{Identity, Password, UserInfo, UserProvider};
pub struct MockUserProvider {}
#[async_trait::async_trait]
impl UserProvider for MockUserProvider {
fn name(&self) -> &str {
"mock_user_provider"
}
async fn auth(
&self,
id: Identity<'_>,
password: Password<'_>,
) -> Result<UserInfo, super::Error> {
match id {
Identity::UserId(username, _host) => match password {
Password::PlainText(password) => {
if username == "greptime" {
if password == "greptime" {
return Ok(UserInfo::new("greptime"));
} else {
return super::UserPasswordMismatchSnafu {
username: username.to_string(),
}
.fail();
}
} else {
return super::UserNotFoundSnafu {
username: username.to_string(),
}
.fail();
}
}
_ => super::UnsupportedPasswordTypeSnafu {
password_type: "mysql_native_password",
}
.fail(),
},
}
}
}
}
#[cfg(test)]
mod tests {
use super::test::MockUserProvider;
use super::{Identity, Password, UserProvider};
use crate::auth;
#[tokio::test]
async fn test_auth_by_plain_text() {
let user_provider = MockUserProvider {};
assert_eq!("mock_user_provider", user_provider.name());
// auth success
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::PlainText("greptime"),
)
.await;
assert!(auth_result.is_ok());
assert_eq!("greptime", auth_result.unwrap().username());
// auth failed, unsupported password type
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::MysqlNativePassword(b"hashed_value", b"salt"),
)
.await;
assert!(auth_result.is_err());
matches!(
auth_result.err().unwrap(),
auth::Error::UnsupportedPasswordType { .. }
);
// auth failed, err: user not exist.
let auth_result = user_provider
.auth(
Identity::UserId("not_exist_username", None),
Password::PlainText("greptime"),
)
.await;
assert!(auth_result.is_err());
matches!(auth_result.err().unwrap(), auth::Error::UserNotFound { .. });
// auth failed, err: wrong password
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::PlainText("wrong_password"),
)
.await;
assert!(auth_result.is_err());
matches!(
auth_result.err().unwrap(),
auth::Error::UserPasswordMismatch { .. }
);
}
}

View File

@@ -99,7 +99,11 @@ impl UserProvider for StaticUserProvider {
STATIC_USER_PROVIDER
}
async fn auth(&self, input_id: Identity<'_>, input_pwd: Password<'_>) -> Result<UserInfo> {
async fn authenticate(
&self,
input_id: Identity<'_>,
input_pwd: Password<'_>,
) -> Result<UserInfo> {
match input_id {
Identity::UserId(username, _) => {
let save_pwd = self.users.get(username).context(UserNotFoundSnafu {
@@ -129,6 +133,11 @@ impl UserProvider for StaticUserProvider {
}
}
}
async fn authorize(&self, _catalog: &str, _schema: &str, _user_info: &UserInfo) -> Result<()> {
// default allow all
Ok(())
}
}
pub fn auth_mysql(
@@ -209,7 +218,7 @@ pub mod test {
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
let re = provider
.auth(
.authenticate(
Identity::UserId(username, None),
Password::PlainText(password),
)

View File

@@ -353,6 +353,12 @@ impl From<std::io::Error> for Error {
}
}
impl From<auth::Error> for Error {
fn from(e: auth::Error) -> Self {
Error::Auth { source: e }
}
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
let (status, error_message) = match self {

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod authorize;
pub mod authorize;
pub mod handler;
pub mod influxdb;
pub mod opentsdb;
@@ -92,8 +92,8 @@ pub(crate) fn query_context_from_db(
}
}
const HTTP_API_VERSION: &str = "v1";
const HTTP_API_PREFIX: &str = "/v1/";
pub const HTTP_API_VERSION: &str = "v1";
pub const HTTP_API_PREFIX: &str = "/v1/";
pub struct HttpServer {
sql_handler: ServerSqlQueryHandlerRef,

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,7 +12,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::marker::PhantomData;
use axum::http::{self, Request, StatusCode};
@@ -23,7 +23,8 @@ use session::context::UserInfo;
use snafu::{OptionExt, ResultExt};
use tower_http::auth::AsyncAuthorizeRequest;
use crate::auth::{Identity, UserProviderRef};
use crate::auth::Error::IllegalParam;
use crate::auth::{Identity, IllegalParamSnafu, InternalStateSnafu, UserProviderRef};
use crate::error::{self, Result};
use crate::http::HTTP_API_PREFIX;
@@ -70,45 +71,84 @@ where
return Ok(request);
};
let (scheme, credential) = match auth_header(&request) {
Ok(auth_header) => auth_header,
// do authenticate
match authenticate(&user_provider, &request).await {
Ok(user_info) => {
request.extensions_mut().insert(user_info);
}
Err(e) => {
error!("failed to get http authorize header, err: {:?}", e);
error!("authenticate failed: {}", e);
return Err(unauthorized_resp());
}
};
}
match scheme {
AuthScheme::Basic => {
let (username, password) = match decode_basic(credential) {
Ok(basic_auth) => basic_auth,
Err(e) => {
error!("failed to decode basic authorize, err: {:?}", e);
return Err(unauthorized_resp());
}
};
match user_provider
.auth(
Identity::UserId(&username, None),
crate::auth::Password::PlainText(&password),
)
.await
{
Ok(user_info) => {
request.extensions_mut().insert(user_info);
Ok(request)
}
Err(e) => {
error!("failed to auth, err: {:?}", e);
Err(unauthorized_resp())
}
}
match authorize(&user_provider, &request).await {
Ok(_) => Ok(request),
Err(e) => {
error!("authorize failed: {}", e);
Err(unauthorized_resp())
}
}
})
}
}
async fn authorize<B: Send + Sync + 'static>(
user_provider: &UserProviderRef,
request: &Request<B>,
) -> crate::auth::Result<()> {
// try get database name
let query = request.uri().query().unwrap_or_default();
let input_database = match serde_urlencoded::from_str::<HashMap<String, String>>(query) {
Ok(query_map) => query_map
.get("db")
.context(IllegalParamSnafu {
msg: "fail to get valid database from http query",
})?
.to_owned(),
Err(e) => IllegalParamSnafu {
msg: format!("fail to parse http query: {e}"),
}
.fail()?,
};
let (catalog, database) =
crate::parse_catalog_and_schema_from_client_database_name(&input_database);
let user_info = request
.extensions()
.get::<UserInfo>()
.context(InternalStateSnafu {
msg: "no user info provided while authorizing",
})?;
user_provider.authorize(catalog, database, user_info).await
}
async fn authenticate<B: Send + Sync + 'static>(
user_provider: &UserProviderRef,
request: &Request<B>,
) -> crate::auth::Result<UserInfo> {
let (scheme, credential) = auth_header(request).map_err(|e| IllegalParam {
msg: format!("failed to get http authorize header, err: {e:?}"),
})?;
match scheme {
AuthScheme::Basic => {
let (username, password) = decode_basic(credential).map_err(|e| IllegalParam {
msg: format!("failed to decode basic authorize, err: {e:?}"),
})?;
Ok(user_provider
.authenticate(
Identity::UserId(&username, None),
crate::auth::Password::PlainText(&password),
)
.await?)
}
}
}
fn unauthorized_resp<RespBody>() -> Response<RespBody>
where
RespBody: Body + Default,
@@ -171,79 +211,7 @@ fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
#[cfg(test)]
mod tests {
use std::marker::PhantomData;
use std::sync::Arc;
use axum::body::BoxBody;
use axum::http;
use hyper::Request;
use session::context::UserInfo;
use tower_http::auth::AsyncAuthorizeRequest;
use super::{auth_header, decode_basic, AuthScheme, HttpAuth};
use crate::auth::test::MockUserProvider;
use crate::auth::UserProvider;
use crate::error;
use crate::error::Result;
#[tokio::test]
async fn test_http_auth() {
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
user_provider: None,
_ty: PhantomData,
};
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_res = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.username(), user_info.username());
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
user_provider: mock_user_provider,
_ty: PhantomData,
};
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.username(), user_info.username());
let req = mock_http_request(None, None).unwrap();
let auth_res = http_auth.authorize(req).await;
assert!(auth_res.is_err());
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_res = http_auth.authorize(wrong_req).await;
assert!(auth_res.is_err());
}
#[tokio::test]
async fn test_whitelist_no_auth() {
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
user_provider: mock_user_provider,
_ty: PhantomData,
};
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
// try auth path first
let req = mock_http_request(None, None).unwrap();
let req = http_auth.authorize(req).await;
assert!(req.is_err());
// try whitelist path
let req = mock_http_request(None, Some("http://localhost/health")).unwrap();
let req = http_auth.authorize(req).await;
assert!(req.is_ok());
}
use super::*;
#[test]
fn test_decode_basic() {

View File

@@ -28,7 +28,7 @@ use crate::http::{ApiState, JsonResponse};
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct SqlQuery {
pub database: Option<String>,
pub db: Option<String>,
pub sql: Option<String>,
}
@@ -43,7 +43,7 @@ pub async fn sql(
let sql_handler = &state.sql_handler;
let start = Instant::now();
let resp = if let Some(sql) = &params.sql {
match super::query_context_from_db(sql_handler.clone(), params.database) {
match super::query_context_from_db(sql_handler.clone(), params.db) {
Ok(query_ctx) => {
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
}

View File

@@ -123,7 +123,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
return false;
}
};
match user_provider.auth(user_id, password).await {
match user_provider.authenticate(user_id, password).await {
Ok(userinfo) => {
user_info = Some(userinfo);
}
@@ -190,6 +190,12 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
error::DatabaseNotFoundSnafu { catalog, schema }
);
if let Some(schema_validator) = &self.user_provider {
schema_validator
.authorize(catalog, schema, &self.session.user_info())
.await?;
}
let context = self.session.context();
context.set_current_catalog(catalog);
context.set_current_schema(database);

View File

@@ -38,6 +38,15 @@ use crate::tls::TlsOption;
// Default size of ResultSet write buffer: 100KB
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
struct MysqlRuntimeOption {
query_handler: ServerSqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
force_tls: bool,
user_provider: Option<UserProviderRef>,
}
type MysqlRuntimeOptionRef = Arc<MysqlRuntimeOption>;
pub struct MysqlServer {
base_server: BaseTcpServer,
query_handler: ServerSqlQueryHandlerRef,
@@ -77,19 +86,19 @@ impl MysqlServer {
let user_provider = user_provider.clone();
let tls_conf = tls_conf.clone();
let mysql_runtime_option = Arc::new(MysqlRuntimeOption {
query_handler,
tls_conf,
force_tls,
user_provider,
});
async move {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
if let Err(error) = Self::handle(
io_stream,
io_runtime,
query_handler,
tls_conf,
force_tls,
user_provider,
)
.await
if let Err(error) =
Self::handle(io_stream, io_runtime, mysql_runtime_option).await
{
error!(error; "Unexpected error when handling TcpStream");
};
@@ -102,15 +111,12 @@ impl MysqlServer {
async fn handle(
stream: TcpStream,
io_runtime: Arc<Runtime>,
query_handler: ServerSqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
force_tls: bool,
user_provider: Option<UserProviderRef>,
runtime_opts: MysqlRuntimeOptionRef,
) -> Result<()> {
info!("MySQL connection coming from: {}", stream.peer_addr()?);
io_runtime .spawn(async move {
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
if let Err(e) = Self::do_handle(stream, query_handler, tls_conf, force_tls, user_provider).await {
if let Err(e) = Self::do_handle(stream, runtime_opts).await {
// TODO(LFC): Write this error to client as well, in MySQL text protocol.
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
@@ -120,28 +126,31 @@ impl MysqlServer {
Ok(())
}
async fn do_handle(
stream: TcpStream,
query_handler: ServerSqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
force_tls: bool,
user_provider: Option<UserProviderRef>,
) -> Result<()> {
let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?, user_provider);
async fn do_handle(stream: TcpStream, runtime_opts: MysqlRuntimeOptionRef) -> Result<()> {
let mut shim = MysqlInstanceShim::create(
runtime_opts.query_handler.clone(),
stream.peer_addr()?,
runtime_opts.user_provider.clone(),
);
let (mut r, w) = stream.into_split();
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
let ops = IntermediaryOptions::default();
let (client_tls, init_params) =
AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &tls_conf).await?;
let (client_tls, init_params) = AsyncMysqlIntermediary::init_before_ssl(
&mut shim,
&mut r,
&mut w,
&runtime_opts.tls_conf,
)
.await?;
if force_tls && !client_tls {
if runtime_opts.force_tls && !client_tls {
return Err(Error::TlsRequired {
server: "mysql".to_owned(),
});
}
match tls_conf {
match runtime_opts.tls_conf.clone() {
Some(tls_conf) if client_tls => {
secure_run_with_options(shim, w, ops, tls_conf, init_params).await
}

View File

@@ -23,6 +23,7 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use session::context::{UserInfo, DEFAULT_USERNAME};
use snafu::ResultExt;
use crate::auth::{Identity, Password, UserProviderRef};
@@ -30,14 +31,15 @@ use crate::error;
use crate::error::Result;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
struct PgPwdVerifier {
struct PgLoginVerifier {
user_provider: Option<UserProviderRef>,
}
#[allow(dead_code)]
struct LoginInfo {
user: Option<String>,
database: Option<String>,
catalog: Option<String>,
schema: Option<String>,
host: String,
}
@@ -48,27 +50,31 @@ impl LoginInfo {
{
LoginInfo {
user: client.metadata().get(super::METADATA_USER).map(Into::into),
database: client
catalog: client
.metadata()
.get(super::METADATA_DATABASE)
.get(super::METADATA_CATALOG)
.map(Into::into),
schema: client
.metadata()
.get(super::METADATA_SCHEMA)
.map(Into::into),
host: client.socket_addr().ip().to_string(),
}
}
}
impl PgPwdVerifier {
async fn verify_pwd(&self, password: &str, login: LoginInfo) -> Result<bool> {
impl PgLoginVerifier {
async fn verify_pwd(&self, password: &str, login: &LoginInfo) -> Result<bool> {
if let Some(user_provider) = &self.user_provider {
let user_name = match login.user {
let user_name = match &login.user {
Some(name) => name,
None => return Ok(false),
};
// TODO(fys): pass user_info to context
let _user_info = user_provider
.auth(
Identity::UserId(&user_name, None),
.authenticate(
Identity::UserId(user_name, None),
Password::PlainText(password),
)
.await
@@ -76,6 +82,29 @@ impl PgPwdVerifier {
}
Ok(true)
}
async fn authorize(&self, login: &LoginInfo) -> Result<bool> {
// at this time, username in login info should be valid
// TODO(shuiyisong): change to use actually user_info from session
if let Some(user_provider) = &self.user_provider {
let user_name = match &login.user {
Some(name) => name,
None => return Ok(false),
};
let catalog = match &login.catalog {
Some(name) => name,
None => return Ok(false),
};
let schema = match &login.schema {
Some(name) => name,
None => return Ok(false),
};
user_provider
.authorize(catalog, schema, &UserInfo::new(user_name))
.await?;
}
Ok(true)
}
}
struct GreptimeDBStartupParameters {
@@ -106,7 +135,7 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
}
pub struct PgAuthStartupHandler {
verifier: PgPwdVerifier,
verifier: PgLoginVerifier,
param_provider: GreptimeDBStartupParameters,
force_tls: bool,
query_handler: ServerSqlQueryHandlerRef,
@@ -119,7 +148,7 @@ impl PgAuthStartupHandler {
query_handler: ServerSqlQueryHandlerRef,
) -> Self {
PgAuthStartupHandler {
verifier: PgPwdVerifier { user_provider },
verifier: PgLoginVerifier { user_provider },
param_provider: GreptimeDBStartupParameters::new(),
force_tls,
query_handler,
@@ -173,22 +202,50 @@ impl StartupHandler for PgAuthStartupHandler {
))
.await?;
} else {
// no user is provided, use default user
// and still do authorization
let mut login_info = LoginInfo::from_client_info(client);
login_info.user = Some(DEFAULT_USERNAME.to_string());
let authorize_result = self.verifier.authorize(&login_info).await;
if !matches!(authorize_result, Ok(true)) {
return send_error(
client,
"FATAL",
"28P01",
"password authorization failed".to_owned(),
)
.await;
}
auth::finish_authentication(client, &self.param_provider).await;
}
}
PgWireFrontendMessage::Password(ref pwd) => {
let login_info = LoginInfo::from_client_info(client);
if let Ok(true) = self.verifier.verify_pwd(pwd.password(), login_info).await {
auth::finish_authentication(client, &self.param_provider).await
} else {
send_error(
// do authenticate
let authenticate_result =
self.verifier.verify_pwd(pwd.password(), &login_info).await;
if !matches!(authenticate_result, Ok(true)) {
return send_error(
client,
"FATAL",
"28P01",
"Password authentication failed".to_owned(),
"password authentication failed".to_owned(),
)
.await?;
.await;
}
// do authorize
let authorize_result = self.verifier.authorize(&login_info).await;
if !matches!(authorize_result, Ok(true)) {
return send_error(
client,
"FATAL",
"28P01",
"password authorization failed".to_owned(),
)
.await;
}
auth::finish_authentication(client, &self.param_provider).await;
}
_ => {}
}

197
src/servers/tests/auth.rs Normal file
View File

@@ -0,0 +1,197 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use servers::auth::user_provider::auth_mysql;
use servers::auth::{
AccessDeniedSnafu, Identity, Password, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu,
UserPasswordMismatchSnafu, UserProvider,
};
use session::context::UserInfo;
pub struct DatabaseAuthInfo<'a> {
pub catalog: &'a str,
pub schema: &'a str,
pub username: &'a str,
}
pub struct MockUserProvider {
pub catalog: String,
pub schema: String,
pub username: String,
}
impl Default for MockUserProvider {
fn default() -> Self {
MockUserProvider {
catalog: "greptime".to_owned(),
schema: "public".to_owned(),
username: "greptime".to_owned(),
}
}
}
impl MockUserProvider {
pub fn set_authorization_info(&mut self, info: DatabaseAuthInfo) {
self.catalog = info.catalog.to_owned();
self.schema = info.schema.to_owned();
self.username = info.username.to_owned();
}
}
#[async_trait::async_trait]
impl UserProvider for MockUserProvider {
fn name(&self) -> &str {
"mock_user_provider"
}
async fn authenticate(
&self,
id: Identity<'_>,
password: Password<'_>,
) -> servers::auth::Result<UserInfo> {
match id {
Identity::UserId(username, _host) => match password {
Password::PlainText(password) => {
if username == "greptime" {
if password == "greptime" {
Ok(UserInfo::new("greptime"))
} else {
UserPasswordMismatchSnafu {
username: username.to_string(),
}
.fail()
}
} else {
UserNotFoundSnafu {
username: username.to_string(),
}
.fail()
}
}
Password::MysqlNativePassword(auth_data, salt) => {
auth_mysql(auth_data, salt, username, "greptime".as_bytes())
.map(|_| UserInfo::new(username))
}
_ => UnsupportedPasswordTypeSnafu {
password_type: "mysql_native_password",
}
.fail(),
},
}
}
async fn authorize(
&self,
catalog: &str,
schema: &str,
user_info: &UserInfo,
) -> servers::auth::Result<()> {
if catalog == self.catalog && schema == self.schema && user_info.username() == self.username
{
Ok(())
} else {
AccessDeniedSnafu {
catalog: catalog.to_string(),
schema: schema.to_string(),
username: user_info.username().to_string(),
}
.fail()
}
}
}
#[tokio::test]
async fn test_auth_by_plain_text() {
let user_provider = MockUserProvider::default();
assert_eq!("mock_user_provider", user_provider.name());
// auth success
let auth_result = user_provider
.authenticate(
Identity::UserId("greptime", None),
Password::PlainText("greptime"),
)
.await;
assert!(auth_result.is_ok());
assert_eq!("greptime", auth_result.unwrap().username());
// auth failed, unsupported password type
let auth_result = user_provider
.authenticate(
Identity::UserId("greptime", None),
Password::PgMD5(b"hashed_value", b"salt"),
)
.await;
assert!(auth_result.is_err());
matches!(
auth_result.err().unwrap(),
servers::auth::Error::UnsupportedPasswordType { .. }
);
// auth failed, err: user not exist.
let auth_result = user_provider
.authenticate(
Identity::UserId("not_exist_username", None),
Password::PlainText("greptime"),
)
.await;
assert!(auth_result.is_err());
matches!(
auth_result.err().unwrap(),
servers::auth::Error::UserNotFound { .. }
);
// auth failed, err: wrong password
let auth_result = user_provider
.authenticate(
Identity::UserId("greptime", None),
Password::PlainText("wrong_password"),
)
.await;
assert!(auth_result.is_err());
matches!(
auth_result.err().unwrap(),
servers::auth::Error::UserPasswordMismatch { .. }
);
}
#[tokio::test]
async fn test_schema_validate() {
let mut validator = MockUserProvider::default();
validator.set_authorization_info(DatabaseAuthInfo {
catalog: "greptime",
schema: "public",
username: "test_user",
});
let right_user = UserInfo::new("test_user");
let wrong_user = UserInfo::default();
// check catalog
let re = validator
.authorize("greptime_wrong", "public", &right_user)
.await;
assert!(re.is_err());
// check schema
let re = validator
.authorize("greptime", "public_wrong", &right_user)
.await;
assert!(re.is_err());
// check username
let re = validator.authorize("greptime", "public", &wrong_user).await;
assert!(re.is_err());
// check ok
let re = validator.authorize("greptime", "public", &right_user).await;
assert!(re.is_ok());
}

View File

@@ -0,0 +1,120 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::body::BoxBody;
use axum::http;
use hyper::Request;
use servers::auth::UserProvider;
use servers::http::authorize::HttpAuth;
use session::context::UserInfo;
use tower_http::auth::AsyncAuthorizeRequest;
use crate::auth::MockUserProvider;
#[tokio::test]
async fn test_http_auth() {
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(None);
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_res = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.username(), user_info.username());
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(mock_user_provider);
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.username(), user_info.username());
let req = mock_http_request(None, None).unwrap();
let auth_res = http_auth.authorize(req).await;
assert!(auth_res.is_err());
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_res = http_auth.authorize(wrong_req).await;
assert!(auth_res.is_err());
}
#[tokio::test]
async fn test_schema_validating() {
// In mock user provider, right username:password == "greptime:greptime"
let provider = MockUserProvider::default();
let mock_user_provider = Some(Arc::new(provider) as Arc<dyn UserProvider>);
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(mock_user_provider);
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
// http://localhost/{http_api_version}/sql?db=greptime
let version = servers::http::HTTP_API_VERSION;
let req = mock_http_request(
Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="),
Some(format!("http://localhost/{version}/sql?db=public").as_str()),
)
.unwrap();
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.username(), user_info.username());
// wrong database
let req = mock_http_request(
Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="),
Some(format!("http://localhost/{version}/sql?db=wrong").as_str()),
)
.unwrap();
let result = http_auth.authorize(req).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_whitelist_no_auth() {
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(mock_user_provider);
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
// try auth path first
let req = mock_http_request(None, None).unwrap();
let req = http_auth.authorize(req).await;
assert!(req.is_err());
// try whitelist path
let req = mock_http_request(None, Some("http://localhost/health")).unwrap();
let req = http_auth.authorize(req).await;
assert!(req.is_ok());
}
// copy from http::authorize
fn mock_http_request(
auth_header: Option<&str>,
uri: Option<&str>,
) -> servers::error::Result<Request<()>> {
let http_api_version = servers::http::HTTP_API_VERSION;
let mut req = Request::builder()
.uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql?db=public").as_str()));
if let Some(auth_header) = auth_header {
req = req.header(http::header::AUTHORIZATION, auth_header);
}
Ok(req.body(()).unwrap())
}

View File

@@ -135,7 +135,7 @@ fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
fn create_query() -> Query<http_handler::SqlQuery> {
Query(http_handler::SqlQuery {
sql: Some("select sum(uint32s) from numbers limit 20".to_string()),
database: None,
db: None,
})
}

View File

@@ -19,7 +19,6 @@ use async_trait::async_trait;
use axum::{http, Router};
use axum_test_helper::TestClient;
use common_query::Output;
use servers::auth::user_provider::StaticUserProvider;
use servers::error::{Error, Result};
use servers::http::{HttpOptions, HttpServer};
use servers::influxdb::InfluxdbRequest;
@@ -28,8 +27,10 @@ use servers::query_handler::InfluxdbLineProtocolHandler;
use session::context::QueryContextRef;
use tokio::sync::mpsc;
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
struct DummyInstance {
tx: mpsc::Sender<(String, String)>,
tx: Arc<mpsc::Sender<(String, String)>>,
}
#[async_trait]
@@ -66,11 +67,18 @@ impl SqlQueryHandler for DummyInstance {
}
}
fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
fn make_test_app(tx: Arc<mpsc::Sender<(String, String)>>, db_name: Option<&str>) -> Router {
let instance = Arc::new(DummyInstance { tx });
let mut server = HttpServer::new(instance.clone(), HttpOptions::default());
let up = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
server.set_user_provider(Arc::new(up));
let mut user_provider = MockUserProvider::default();
if let Some(name) = db_name {
user_provider.set_authorization_info(DatabaseAuthInfo {
catalog: "greptime",
schema: name,
username: "greptime",
})
}
server.set_user_provider(Arc::new(user_provider));
server.set_influxdb_handler(instance);
server.make_app()
@@ -79,13 +87,14 @@ fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
#[tokio::test]
async fn test_influxdb_write() {
let (tx, mut rx) = mpsc::channel(100);
let tx = Arc::new(tx);
let app = make_test_app(tx);
let app = make_test_app(tx.clone(), None);
let client = TestClient::new(app);
// right request
let result = client
.post("/v1/influxdb/write")
.post("/v1/influxdb/write?db=public")
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
@@ -96,6 +105,10 @@ async fn test_influxdb_write() {
assert_eq!(result.status(), 204);
assert!(result.text().await.is_empty());
// make new app for db=influxdb
let app = make_test_app(tx, Some("influxdb"));
let client = TestClient::new(app);
let result = client
.post("/v1/influxdb/write?db=influxdb")
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
@@ -110,7 +123,7 @@ async fn test_influxdb_write() {
// bad request
let result = client
.post("/v1/influxdb/write")
.post("/v1/influxdb/write?db=influxdb")
.body("monitor, host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod authorize;
mod http_handler_test;
mod influxdb_test;
mod opentsdb_test;

View File

@@ -22,21 +22,20 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::Output;
use query::parser::QueryLanguageParser;
use query::{QueryEngineFactory, QueryEngineRef};
use servers::error::{Error, Result};
use servers::query_handler::{ScriptHandler, ScriptHandlerRef};
use table::test_util::MemTable;
mod http;
mod mysql;
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use script::python::{PyEngine, PyScript};
use servers::error::{Error, Result};
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
use servers::query_handler::{ScriptHandler, ScriptHandlerRef};
use session::context::QueryContextRef;
use table::test_util::MemTable;
mod auth;
mod http;
mod interceptor;
mod mysql;
mod opentsdb;
mod postgres;
mod py_script;
struct DummyInstance {

View File

@@ -24,17 +24,21 @@ use mysql_async::prelude::*;
use mysql_async::SslOpts;
use rand::rngs::StdRng;
use rand::Rng;
use servers::auth::user_provider::StaticUserProvider;
use servers::error::Result;
use servers::mysql::server::MysqlServer;
use servers::server::Server;
use servers::tls::TlsOption;
use table::test_util::MemTable;
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
use crate::create_testing_sql_query_handler;
use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData};
fn create_mysql_server(table: MemTable, tls: TlsOption) -> Result<Box<dyn Server>> {
fn create_mysql_server(
table: MemTable,
tls: TlsOption,
auth_info: Option<DatabaseAuthInfo>,
) -> Result<Box<dyn Server>> {
let query_handler = create_testing_sql_query_handler(table);
let io_runtime = Arc::new(
RuntimeBuilder::default()
@@ -44,7 +48,10 @@ fn create_mysql_server(table: MemTable, tls: TlsOption) -> Result<Box<dyn Server
.unwrap(),
);
let provider = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
let mut provider = MockUserProvider::default();
if let Some(auth_info) = auth_info {
provider.set_authorization_info(auth_info);
}
Ok(MysqlServer::create_server(
query_handler,
@@ -58,7 +65,7 @@ fn create_mysql_server(table: MemTable, tls: TlsOption) -> Result<Box<dyn Server
async fn test_start_mysql_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default())?;
let mysql_server = create_mysql_server(table, Default::default(), None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = mysql_server.start(listening).await;
assert!(result.is_ok());
@@ -71,13 +78,54 @@ async fn test_start_mysql_server() -> Result<()> {
Ok(())
}
#[tokio::test]
async fn test_schema_validation() -> Result<()> {
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default(), Some(auth_info))?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
Ok((mysql_server, server_addr.port()))
}
common_telemetry::init_default_ut_logging();
let (mysql_server, server_port) = generate_server(DatabaseAuthInfo {
catalog: "greptime",
schema: "public",
username: "greptime",
})
.await?;
//TODO(shuiyisong): mysql conn without dbname rejection is not implemented yet, add test later.
let pass = create_connection(server_port, Some("public"), false).await;
assert!(pass.is_ok());
let result = mysql_server.shutdown().await;
assert!(result.is_ok());
// change to another username
let (mysql_server, server_port) = generate_server(DatabaseAuthInfo {
catalog: "greptime",
schema: "public",
username: "no_access_user",
})
.await?;
let fail = create_connection(server_port, Some("public"), false).await;
assert!(fail.is_err());
let result = mysql_server.shutdown().await;
assert!(result.is_ok());
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_shutdown_mysql_server() -> Result<()> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default())?;
let mysql_server = create_mysql_server(table, Default::default(), None)?;
let result = mysql_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -193,7 +241,7 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls)?;
let mysql_server = create_mysql_server(table, server_tls, None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
@@ -219,7 +267,7 @@ async fn test_db_name() -> Result<()> {
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls)?;
let mysql_server = create_mysql_server(table, server_tls, None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
@@ -247,7 +295,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls)?;
let mysql_server = create_mysql_server(table, server_tls, None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
@@ -282,7 +330,7 @@ async fn test_query_concurrently() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default())?;
let mysql_server = create_mysql_server(table, Default::default(), None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let server_port = server_addr.port();
@@ -332,7 +380,7 @@ async fn create_connection(
.tcp_port(port)
.prefer_socket(false)
.wait_timeout(Some(1000))
.db_name(db_name)
.db_name(db_name.or(Some(DEFAULT_SCHEMA_NAME)))
.user(Some("greptime".to_string()))
.pass(Some("greptime".to_string()));

View File

@@ -22,7 +22,6 @@ use rand::rngs::StdRng;
use rand::Rng;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, Error, ServerName};
use servers::auth::user_provider::StaticUserProvider;
use servers::auth::UserProviderRef;
use servers::error::Result;
use servers::postgres::PostgresServer;
@@ -31,12 +30,14 @@ use servers::tls::TlsOption;
use table::test_util::MemTable;
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
use crate::create_testing_instance;
fn create_postgres_server(
table: MemTable,
check_pwd: bool,
tls: TlsOption,
auth_info: Option<DatabaseAuthInfo>,
) -> Result<Box<dyn Server>> {
let instance = Arc::new(create_testing_instance(table));
let io_runtime = Arc::new(
@@ -47,9 +48,11 @@ fn create_postgres_server(
.unwrap(),
);
let user_provider: Option<UserProviderRef> = if check_pwd {
Some(Arc::new(
StaticUserProvider::try_from("cmd:test_user=test_pwd").unwrap(),
))
let mut provider = MockUserProvider::default();
if let Some(info) = auth_info {
provider.set_authorization_info(info);
}
Some(Arc::new(provider))
} else {
None
};
@@ -66,7 +69,7 @@ fn create_postgres_server(
pub async fn test_start_postgres_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, Default::default())?;
let pg_server = create_postgres_server(table, false, Default::default(), None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = pg_server.start(listening).await;
assert!(result.is_ok());
@@ -86,12 +89,52 @@ async fn test_shutdown_pg_server_range() -> Result<()> {
Ok(())
}
#[tokio::test]
async fn test_schema_validating() -> Result<()> {
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
let table = MemTable::default_numbers_table();
let postgres_server =
create_postgres_server(table, true, Default::default(), Some(auth_info))?;
let listening = "127.0.0.1:5432".parse::<SocketAddr>().unwrap();
let server_addr = postgres_server.start(listening).await.unwrap();
let server_port = server_addr.port();
Ok((postgres_server, server_port))
}
common_telemetry::init_default_ut_logging();
let (pg_server, server_port) = generate_server(DatabaseAuthInfo {
catalog: "greptime",
schema: "public",
username: "greptime",
})
.await?;
let pass = create_plain_connection(server_port, true).await;
assert!(pass.is_ok());
let result = pg_server.shutdown().await;
assert!(result.is_ok());
let (pg_server, server_port) = generate_server(DatabaseAuthInfo {
catalog: "greptime",
schema: "public",
username: "no_right_user",
})
.await?;
let fail = create_plain_connection(server_port, true).await;
assert!(fail.is_err());
let result = pg_server.shutdown().await;
assert!(result.is_ok());
Ok(())
}
// #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let postgres_server = create_postgres_server(table, with_pwd, Default::default())?;
let postgres_server = create_postgres_server(table, with_pwd, Default::default(), None)?;
let result = postgres_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -273,7 +316,7 @@ async fn test_using_db() -> Result<()> {
async fn start_test_server(server_tls: TlsOption) -> Result<u16> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, server_tls)?;
let pg_server = create_postgres_server(table, false, server_tls, None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
Ok(server_addr.port())
@@ -301,7 +344,7 @@ async fn create_secure_connection(
) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"sslmode=require host=127.0.0.1 port={port} user=test_user password=test_pwd connect_timeout=2, dbname={DEFAULT_SCHEMA_NAME}",
"sslmode=require host=127.0.0.1 port={port} user=greptime password=greptime connect_timeout=2, dbname={DEFAULT_SCHEMA_NAME}",
)
} else {
format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}")
@@ -328,7 +371,7 @@ async fn create_plain_connection(
) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"host=127.0.0.1 port={port} user=test_user password=test_pwd connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}",
"host=127.0.0.1 port={port} user=greptime password=greptime connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}",
)
} else {
format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}")

View File

@@ -195,7 +195,7 @@ pub async fn test_sql_api(store_type: StorageType) {
// test database given
let res = client
.get("/v1/sql?database=public&sql=select cpu, ts from demo limit 1")
.get("/v1/sql?db=public&sql=select cpu, ts from demo limit 1")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
@@ -214,7 +214,7 @@ pub async fn test_sql_api(store_type: StorageType) {
// test database not found
let res = client
.get("/v1/sql?database=notfound&sql=select cpu, ts from demo limit 1")
.get("/v1/sql?db=notfound&sql=select cpu, ts from demo limit 1")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
@@ -223,7 +223,7 @@ pub async fn test_sql_api(store_type: StorageType) {
// test catalog-schema given
let res = client
.get("/v1/sql?database=greptime-public&sql=select cpu, ts from demo limit 1")
.get("/v1/sql?db=greptime-public&sql=select cpu, ts from demo limit 1")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
@@ -242,7 +242,7 @@ pub async fn test_sql_api(store_type: StorageType) {
// test invalid catalog
let res = client
.get("/v1/sql?database=notfound2-schema&sql=select cpu, ts from demo limit 1")
.get("/v1/sql?db=notfound2-schema&sql=select cpu, ts from demo limit 1")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
@@ -251,7 +251,7 @@ pub async fn test_sql_api(store_type: StorageType) {
// test invalid schema
let res = client
.get("/v1/sql?database=greptime-schema&sql=select cpu, ts from demo limit 1")
.get("/v1/sql?db=greptime-schema&sql=select cpu, ts from demo limit 1")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);