mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-25 09:20:40 +00:00
feat: add authorize to UserProvider trait (#879)
* feat: add SchemaValidator * feat: add schema validator to mysql shim * chore: pass schema validator to http auth layer * feat: add schema validator to http * feat: add schema validator to pg * feat: add schema validator to pg * feat: add schema validator test * chore: remove println in test * chore: use !matches * refactor: refac authenticate and authorize in http auth * refactor: refac authenticate and authorize in http auth * chore: typo * chore: minor change * refactor: merge schema_validator into user_providier * chore: fix license issue * refactor: change http query param from database to db * chore: fix cr issue
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -6633,6 +6633,7 @@ dependencies = [
|
||||
"script",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"session",
|
||||
"sha1",
|
||||
"snafu",
|
||||
|
||||
@@ -287,7 +287,7 @@ mod tests {
|
||||
|
||||
let provider = provider.unwrap();
|
||||
let result = provider
|
||||
.auth(Identity::UserId("test", None), Password::PlainText("test"))
|
||||
.authenticate(Identity::UserId("test", None), Password::PlainText("test"))
|
||||
.await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
@@ -350,7 +350,7 @@ mod tests {
|
||||
assert!(provider.is_some());
|
||||
let provider = provider.unwrap();
|
||||
let result = provider
|
||||
.auth(Identity::UserId("test", None), Password::PlainText("test"))
|
||||
.authenticate(Identity::UserId("test", None), Password::PlainText("test"))
|
||||
.await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
@@ -77,6 +77,8 @@ pub enum StatusCode {
|
||||
AuthHeaderNotFound = 7003,
|
||||
/// Invalid http authorization header
|
||||
InvalidAuthHeader = 7004,
|
||||
/// Illegal request to connect catalog-schema
|
||||
AccessDenied = 7005,
|
||||
// ====== End of auth related status code =====
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ use std::sync::Arc;
|
||||
|
||||
use meta_client::MetaClientOpts;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::http::HttpOptions;
|
||||
use servers::Mode;
|
||||
use snafu::prelude::*;
|
||||
@@ -92,8 +91,6 @@ impl<T: FrontendInstance> Frontend<T> {
|
||||
let instance = Arc::new(instance);
|
||||
|
||||
// TODO(sunng87): merge this into instance
|
||||
let provider = self.plugins.get::<UserProviderRef>().cloned();
|
||||
|
||||
Services::start(&self.opts, instance, provider).await
|
||||
Services::start(&self.opts, instance, self.plugins.clone()).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ use crate::frontend::FrontendOptions;
|
||||
use crate::influxdb::InfluxdbOptions;
|
||||
use crate::instance::FrontendInstance;
|
||||
use crate::prometheus::PrometheusOptions;
|
||||
use crate::Plugins;
|
||||
|
||||
pub(crate) struct Services;
|
||||
|
||||
@@ -41,12 +42,14 @@ impl Services {
|
||||
pub(crate) async fn start<T>(
|
||||
opts: &FrontendOptions,
|
||||
instance: Arc<T>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
plugins: Arc<Plugins>,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: FrontendInstance,
|
||||
{
|
||||
info!("Starting frontend servers");
|
||||
let user_provider = plugins.get::<UserProviderRef>().cloned();
|
||||
|
||||
let grpc_server_and_addr = if let Some(opts) = &opts.grpc_options {
|
||||
let grpc_addr = parse_addr(&opts.addr)?;
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ rustls-pemfile = "1.0"
|
||||
schemars = "0.8"
|
||||
serde.workspace = true
|
||||
serde_json = "1.0"
|
||||
serde_urlencoded = "0.7"
|
||||
session = { path = "../session" }
|
||||
sha1 = "0.10"
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod user_provider;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::ext::BoxedError;
|
||||
@@ -24,11 +22,19 @@ use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
|
||||
|
||||
use crate::auth::user_provider::StaticUserProvider;
|
||||
|
||||
pub mod user_provider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait UserProvider: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
|
||||
async fn auth(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo>;
|
||||
/// [`authenticate`] checks whether a user is valid and allowed to access the database.
|
||||
async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo>;
|
||||
|
||||
/// [`authorize`] checks whether a connection request
|
||||
/// from a certain user to a certain catalog/schema is legal.
|
||||
/// This method should be called after [`authenticate`].
|
||||
async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfo) -> Result<()>;
|
||||
}
|
||||
|
||||
pub type UserProviderRef = Arc<dyn UserProvider>;
|
||||
@@ -76,6 +82,12 @@ pub enum Error {
|
||||
#[snafu(display("Invalid config value: {}, {}", value, msg))]
|
||||
InvalidConfig { value: String, msg: String },
|
||||
|
||||
#[snafu(display("Illegal runtime param: {}", msg))]
|
||||
IllegalParam { msg: String },
|
||||
|
||||
#[snafu(display("Internal state error: {}", msg))]
|
||||
InternalState { msg: String },
|
||||
|
||||
#[snafu(display("IO error, source: {}", source))]
|
||||
Io {
|
||||
source: std::io::Error,
|
||||
@@ -96,18 +108,33 @@ pub enum Error {
|
||||
|
||||
#[snafu(display("Username and password does not match, username: {}", username))]
|
||||
UserPasswordMismatch { username: String },
|
||||
|
||||
#[snafu(display(
|
||||
"User {} is not allowed to access catalog {} and schema {}",
|
||||
username,
|
||||
catalog,
|
||||
schema
|
||||
))]
|
||||
AccessDenied {
|
||||
catalog: String,
|
||||
schema: String,
|
||||
username: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl ErrorExt for Error {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Error::InvalidConfig { .. } => StatusCode::InvalidArguments,
|
||||
Error::IllegalParam { .. } => StatusCode::InvalidArguments,
|
||||
Error::InternalState { .. } => StatusCode::Unexpected,
|
||||
Error::Io { .. } => StatusCode::Internal,
|
||||
Error::AuthBackend { .. } => StatusCode::Internal,
|
||||
|
||||
Error::UserNotFound { .. } => StatusCode::UserNotFound,
|
||||
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
|
||||
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
|
||||
Error::AccessDenied { .. } => StatusCode::AccessDenied,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,108 +148,3 @@ impl ErrorExt for Error {
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test {
|
||||
use super::{Identity, Password, UserInfo, UserProvider};
|
||||
|
||||
pub struct MockUserProvider {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl UserProvider for MockUserProvider {
|
||||
fn name(&self) -> &str {
|
||||
"mock_user_provider"
|
||||
}
|
||||
|
||||
async fn auth(
|
||||
&self,
|
||||
id: Identity<'_>,
|
||||
password: Password<'_>,
|
||||
) -> Result<UserInfo, super::Error> {
|
||||
match id {
|
||||
Identity::UserId(username, _host) => match password {
|
||||
Password::PlainText(password) => {
|
||||
if username == "greptime" {
|
||||
if password == "greptime" {
|
||||
return Ok(UserInfo::new("greptime"));
|
||||
} else {
|
||||
return super::UserPasswordMismatchSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
} else {
|
||||
return super::UserNotFoundSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
}
|
||||
_ => super::UnsupportedPasswordTypeSnafu {
|
||||
password_type: "mysql_native_password",
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::test::MockUserProvider;
|
||||
use super::{Identity, Password, UserProvider};
|
||||
use crate::auth;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_by_plain_text() {
|
||||
let user_provider = MockUserProvider {};
|
||||
assert_eq!("mock_user_provider", user_provider.name());
|
||||
|
||||
// auth success
|
||||
let auth_result = user_provider
|
||||
.auth(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::PlainText("greptime"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_ok());
|
||||
assert_eq!("greptime", auth_result.unwrap().username());
|
||||
|
||||
// auth failed, unsupported password type
|
||||
let auth_result = user_provider
|
||||
.auth(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::MysqlNativePassword(b"hashed_value", b"salt"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
auth_result.err().unwrap(),
|
||||
auth::Error::UnsupportedPasswordType { .. }
|
||||
);
|
||||
|
||||
// auth failed, err: user not exist.
|
||||
let auth_result = user_provider
|
||||
.auth(
|
||||
Identity::UserId("not_exist_username", None),
|
||||
Password::PlainText("greptime"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(auth_result.err().unwrap(), auth::Error::UserNotFound { .. });
|
||||
|
||||
// auth failed, err: wrong password
|
||||
let auth_result = user_provider
|
||||
.auth(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::PlainText("wrong_password"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
auth_result.err().unwrap(),
|
||||
auth::Error::UserPasswordMismatch { .. }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,7 +99,11 @@ impl UserProvider for StaticUserProvider {
|
||||
STATIC_USER_PROVIDER
|
||||
}
|
||||
|
||||
async fn auth(&self, input_id: Identity<'_>, input_pwd: Password<'_>) -> Result<UserInfo> {
|
||||
async fn authenticate(
|
||||
&self,
|
||||
input_id: Identity<'_>,
|
||||
input_pwd: Password<'_>,
|
||||
) -> Result<UserInfo> {
|
||||
match input_id {
|
||||
Identity::UserId(username, _) => {
|
||||
let save_pwd = self.users.get(username).context(UserNotFoundSnafu {
|
||||
@@ -129,6 +133,11 @@ impl UserProvider for StaticUserProvider {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn authorize(&self, _catalog: &str, _schema: &str, _user_info: &UserInfo) -> Result<()> {
|
||||
// default allow all
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn auth_mysql(
|
||||
@@ -209,7 +218,7 @@ pub mod test {
|
||||
|
||||
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
|
||||
let re = provider
|
||||
.auth(
|
||||
.authenticate(
|
||||
Identity::UserId(username, None),
|
||||
Password::PlainText(password),
|
||||
)
|
||||
|
||||
@@ -353,6 +353,12 @@ impl From<std::io::Error> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<auth::Error> for Error {
|
||||
fn from(e: auth::Error) -> Self {
|
||||
Error::Auth { source: e }
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_message) = match self {
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod authorize;
|
||||
pub mod authorize;
|
||||
pub mod handler;
|
||||
pub mod influxdb;
|
||||
pub mod opentsdb;
|
||||
@@ -92,8 +92,8 @@ pub(crate) fn query_context_from_db(
|
||||
}
|
||||
}
|
||||
|
||||
const HTTP_API_VERSION: &str = "v1";
|
||||
const HTTP_API_PREFIX: &str = "/v1/";
|
||||
pub const HTTP_API_VERSION: &str = "v1";
|
||||
pub const HTTP_API_PREFIX: &str = "/v1/";
|
||||
|
||||
pub struct HttpServer {
|
||||
sql_handler: ServerSqlQueryHandlerRef,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -11,7 +12,6 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use axum::http::{self, Request, StatusCode};
|
||||
@@ -23,7 +23,8 @@ use session::context::UserInfo;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use crate::auth::{Identity, UserProviderRef};
|
||||
use crate::auth::Error::IllegalParam;
|
||||
use crate::auth::{Identity, IllegalParamSnafu, InternalStateSnafu, UserProviderRef};
|
||||
use crate::error::{self, Result};
|
||||
use crate::http::HTTP_API_PREFIX;
|
||||
|
||||
@@ -70,45 +71,84 @@ where
|
||||
return Ok(request);
|
||||
};
|
||||
|
||||
let (scheme, credential) = match auth_header(&request) {
|
||||
Ok(auth_header) => auth_header,
|
||||
// do authenticate
|
||||
match authenticate(&user_provider, &request).await {
|
||||
Ok(user_info) => {
|
||||
request.extensions_mut().insert(user_info);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to get http authorize header, err: {:?}", e);
|
||||
error!("authenticate failed: {}", e);
|
||||
return Err(unauthorized_resp());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
match scheme {
|
||||
AuthScheme::Basic => {
|
||||
let (username, password) = match decode_basic(credential) {
|
||||
Ok(basic_auth) => basic_auth,
|
||||
Err(e) => {
|
||||
error!("failed to decode basic authorize, err: {:?}", e);
|
||||
return Err(unauthorized_resp());
|
||||
}
|
||||
};
|
||||
match user_provider
|
||||
.auth(
|
||||
Identity::UserId(&username, None),
|
||||
crate::auth::Password::PlainText(&password),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(user_info) => {
|
||||
request.extensions_mut().insert(user_info);
|
||||
Ok(request)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to auth, err: {:?}", e);
|
||||
Err(unauthorized_resp())
|
||||
}
|
||||
}
|
||||
match authorize(&user_provider, &request).await {
|
||||
Ok(_) => Ok(request),
|
||||
Err(e) => {
|
||||
error!("authorize failed: {}", e);
|
||||
Err(unauthorized_resp())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn authorize<B: Send + Sync + 'static>(
|
||||
user_provider: &UserProviderRef,
|
||||
request: &Request<B>,
|
||||
) -> crate::auth::Result<()> {
|
||||
// try get database name
|
||||
let query = request.uri().query().unwrap_or_default();
|
||||
let input_database = match serde_urlencoded::from_str::<HashMap<String, String>>(query) {
|
||||
Ok(query_map) => query_map
|
||||
.get("db")
|
||||
.context(IllegalParamSnafu {
|
||||
msg: "fail to get valid database from http query",
|
||||
})?
|
||||
.to_owned(),
|
||||
Err(e) => IllegalParamSnafu {
|
||||
msg: format!("fail to parse http query: {e}"),
|
||||
}
|
||||
.fail()?,
|
||||
};
|
||||
|
||||
let (catalog, database) =
|
||||
crate::parse_catalog_and_schema_from_client_database_name(&input_database);
|
||||
|
||||
let user_info = request
|
||||
.extensions()
|
||||
.get::<UserInfo>()
|
||||
.context(InternalStateSnafu {
|
||||
msg: "no user info provided while authorizing",
|
||||
})?;
|
||||
|
||||
user_provider.authorize(catalog, database, user_info).await
|
||||
}
|
||||
|
||||
async fn authenticate<B: Send + Sync + 'static>(
|
||||
user_provider: &UserProviderRef,
|
||||
request: &Request<B>,
|
||||
) -> crate::auth::Result<UserInfo> {
|
||||
let (scheme, credential) = auth_header(request).map_err(|e| IllegalParam {
|
||||
msg: format!("failed to get http authorize header, err: {e:?}"),
|
||||
})?;
|
||||
|
||||
match scheme {
|
||||
AuthScheme::Basic => {
|
||||
let (username, password) = decode_basic(credential).map_err(|e| IllegalParam {
|
||||
msg: format!("failed to decode basic authorize, err: {e:?}"),
|
||||
})?;
|
||||
|
||||
Ok(user_provider
|
||||
.authenticate(
|
||||
Identity::UserId(&username, None),
|
||||
crate::auth::Password::PlainText(&password),
|
||||
)
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn unauthorized_resp<RespBody>() -> Response<RespBody>
|
||||
where
|
||||
RespBody: Body + Default,
|
||||
@@ -171,79 +211,7 @@ fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::BoxBody;
|
||||
use axum::http;
|
||||
use hyper::Request;
|
||||
use session::context::UserInfo;
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use super::{auth_header, decode_basic, AuthScheme, HttpAuth};
|
||||
use crate::auth::test::MockUserProvider;
|
||||
use crate::auth::UserProvider;
|
||||
use crate::error;
|
||||
use crate::error::Result;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_auth() {
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
|
||||
user_provider: None,
|
||||
_ty: PhantomData,
|
||||
};
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let auth_res = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
|
||||
user_provider: mock_user_provider,
|
||||
_ty: PhantomData,
|
||||
};
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
|
||||
let req = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = req.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
let auth_res = http_auth.authorize(req).await;
|
||||
assert!(auth_res.is_err());
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let auth_res = http_auth.authorize(wrong_req).await;
|
||||
assert!(auth_res.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_whitelist_no_auth() {
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
|
||||
user_provider: mock_user_provider,
|
||||
_ty: PhantomData,
|
||||
};
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
// try auth path first
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
let req = http_auth.authorize(req).await;
|
||||
assert!(req.is_err());
|
||||
|
||||
// try whitelist path
|
||||
let req = mock_http_request(None, Some("http://localhost/health")).unwrap();
|
||||
let req = http_auth.authorize(req).await;
|
||||
assert!(req.is_ok());
|
||||
}
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decode_basic() {
|
||||
|
||||
@@ -28,7 +28,7 @@ use crate::http::{ApiState, JsonResponse};
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SqlQuery {
|
||||
pub database: Option<String>,
|
||||
pub db: Option<String>,
|
||||
pub sql: Option<String>,
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ pub async fn sql(
|
||||
let sql_handler = &state.sql_handler;
|
||||
let start = Instant::now();
|
||||
let resp = if let Some(sql) = ¶ms.sql {
|
||||
match super::query_context_from_db(sql_handler.clone(), params.database) {
|
||||
match super::query_context_from_db(sql_handler.clone(), params.db) {
|
||||
Ok(query_ctx) => {
|
||||
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
return false;
|
||||
}
|
||||
};
|
||||
match user_provider.auth(user_id, password).await {
|
||||
match user_provider.authenticate(user_id, password).await {
|
||||
Ok(userinfo) => {
|
||||
user_info = Some(userinfo);
|
||||
}
|
||||
@@ -190,6 +190,12 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
error::DatabaseNotFoundSnafu { catalog, schema }
|
||||
);
|
||||
|
||||
if let Some(schema_validator) = &self.user_provider {
|
||||
schema_validator
|
||||
.authorize(catalog, schema, &self.session.user_info())
|
||||
.await?;
|
||||
}
|
||||
|
||||
let context = self.session.context();
|
||||
context.set_current_catalog(catalog);
|
||||
context.set_current_schema(database);
|
||||
|
||||
@@ -38,6 +38,15 @@ use crate::tls::TlsOption;
|
||||
// Default size of ResultSet write buffer: 100KB
|
||||
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
|
||||
|
||||
struct MysqlRuntimeOption {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
tls_conf: Option<Arc<ServerConfig>>,
|
||||
force_tls: bool,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
}
|
||||
|
||||
type MysqlRuntimeOptionRef = Arc<MysqlRuntimeOption>;
|
||||
|
||||
pub struct MysqlServer {
|
||||
base_server: BaseTcpServer,
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
@@ -77,19 +86,19 @@ impl MysqlServer {
|
||||
let user_provider = user_provider.clone();
|
||||
let tls_conf = tls_conf.clone();
|
||||
|
||||
let mysql_runtime_option = Arc::new(MysqlRuntimeOption {
|
||||
query_handler,
|
||||
tls_conf,
|
||||
force_tls,
|
||||
user_provider,
|
||||
});
|
||||
|
||||
async move {
|
||||
match tcp_stream {
|
||||
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
|
||||
Ok(io_stream) => {
|
||||
if let Err(error) = Self::handle(
|
||||
io_stream,
|
||||
io_runtime,
|
||||
query_handler,
|
||||
tls_conf,
|
||||
force_tls,
|
||||
user_provider,
|
||||
)
|
||||
.await
|
||||
if let Err(error) =
|
||||
Self::handle(io_stream, io_runtime, mysql_runtime_option).await
|
||||
{
|
||||
error!(error; "Unexpected error when handling TcpStream");
|
||||
};
|
||||
@@ -102,15 +111,12 @@ impl MysqlServer {
|
||||
async fn handle(
|
||||
stream: TcpStream,
|
||||
io_runtime: Arc<Runtime>,
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
tls_conf: Option<Arc<ServerConfig>>,
|
||||
force_tls: bool,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime_opts: MysqlRuntimeOptionRef,
|
||||
) -> Result<()> {
|
||||
info!("MySQL connection coming from: {}", stream.peer_addr()?);
|
||||
io_runtime .spawn(async move {
|
||||
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
|
||||
if let Err(e) = Self::do_handle(stream, query_handler, tls_conf, force_tls, user_provider).await {
|
||||
if let Err(e) = Self::do_handle(stream, runtime_opts).await {
|
||||
// TODO(LFC): Write this error to client as well, in MySQL text protocol.
|
||||
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
|
||||
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
|
||||
@@ -120,28 +126,31 @@ impl MysqlServer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_handle(
|
||||
stream: TcpStream,
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
tls_conf: Option<Arc<ServerConfig>>,
|
||||
force_tls: bool,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
) -> Result<()> {
|
||||
let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?, user_provider);
|
||||
async fn do_handle(stream: TcpStream, runtime_opts: MysqlRuntimeOptionRef) -> Result<()> {
|
||||
let mut shim = MysqlInstanceShim::create(
|
||||
runtime_opts.query_handler.clone(),
|
||||
stream.peer_addr()?,
|
||||
runtime_opts.user_provider.clone(),
|
||||
);
|
||||
let (mut r, w) = stream.into_split();
|
||||
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
|
||||
let ops = IntermediaryOptions::default();
|
||||
|
||||
let (client_tls, init_params) =
|
||||
AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &tls_conf).await?;
|
||||
let (client_tls, init_params) = AsyncMysqlIntermediary::init_before_ssl(
|
||||
&mut shim,
|
||||
&mut r,
|
||||
&mut w,
|
||||
&runtime_opts.tls_conf,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if force_tls && !client_tls {
|
||||
if runtime_opts.force_tls && !client_tls {
|
||||
return Err(Error::TlsRequired {
|
||||
server: "mysql".to_owned(),
|
||||
});
|
||||
}
|
||||
|
||||
match tls_conf {
|
||||
match runtime_opts.tls_conf.clone() {
|
||||
Some(tls_conf) if client_tls => {
|
||||
secure_run_with_options(shim, w, ops, tls_conf, init_params).await
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use pgwire::messages::response::ErrorResponse;
|
||||
use pgwire::messages::startup::Authentication;
|
||||
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
|
||||
use session::context::{UserInfo, DEFAULT_USERNAME};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
@@ -30,14 +31,15 @@ use crate::error;
|
||||
use crate::error::Result;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
|
||||
struct PgPwdVerifier {
|
||||
struct PgLoginVerifier {
|
||||
user_provider: Option<UserProviderRef>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct LoginInfo {
|
||||
user: Option<String>,
|
||||
database: Option<String>,
|
||||
catalog: Option<String>,
|
||||
schema: Option<String>,
|
||||
host: String,
|
||||
}
|
||||
|
||||
@@ -48,27 +50,31 @@ impl LoginInfo {
|
||||
{
|
||||
LoginInfo {
|
||||
user: client.metadata().get(super::METADATA_USER).map(Into::into),
|
||||
database: client
|
||||
catalog: client
|
||||
.metadata()
|
||||
.get(super::METADATA_DATABASE)
|
||||
.get(super::METADATA_CATALOG)
|
||||
.map(Into::into),
|
||||
schema: client
|
||||
.metadata()
|
||||
.get(super::METADATA_SCHEMA)
|
||||
.map(Into::into),
|
||||
host: client.socket_addr().ip().to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PgPwdVerifier {
|
||||
async fn verify_pwd(&self, password: &str, login: LoginInfo) -> Result<bool> {
|
||||
impl PgLoginVerifier {
|
||||
async fn verify_pwd(&self, password: &str, login: &LoginInfo) -> Result<bool> {
|
||||
if let Some(user_provider) = &self.user_provider {
|
||||
let user_name = match login.user {
|
||||
let user_name = match &login.user {
|
||||
Some(name) => name,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
// TODO(fys): pass user_info to context
|
||||
let _user_info = user_provider
|
||||
.auth(
|
||||
Identity::UserId(&user_name, None),
|
||||
.authenticate(
|
||||
Identity::UserId(user_name, None),
|
||||
Password::PlainText(password),
|
||||
)
|
||||
.await
|
||||
@@ -76,6 +82,29 @@ impl PgPwdVerifier {
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn authorize(&self, login: &LoginInfo) -> Result<bool> {
|
||||
// at this time, username in login info should be valid
|
||||
// TODO(shuiyisong): change to use actually user_info from session
|
||||
if let Some(user_provider) = &self.user_provider {
|
||||
let user_name = match &login.user {
|
||||
Some(name) => name,
|
||||
None => return Ok(false),
|
||||
};
|
||||
let catalog = match &login.catalog {
|
||||
Some(name) => name,
|
||||
None => return Ok(false),
|
||||
};
|
||||
let schema = match &login.schema {
|
||||
Some(name) => name,
|
||||
None => return Ok(false),
|
||||
};
|
||||
user_provider
|
||||
.authorize(catalog, schema, &UserInfo::new(user_name))
|
||||
.await?;
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
struct GreptimeDBStartupParameters {
|
||||
@@ -106,7 +135,7 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
|
||||
}
|
||||
|
||||
pub struct PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier,
|
||||
verifier: PgLoginVerifier,
|
||||
param_provider: GreptimeDBStartupParameters,
|
||||
force_tls: bool,
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
@@ -119,7 +148,7 @@ impl PgAuthStartupHandler {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
) -> Self {
|
||||
PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier { user_provider },
|
||||
verifier: PgLoginVerifier { user_provider },
|
||||
param_provider: GreptimeDBStartupParameters::new(),
|
||||
force_tls,
|
||||
query_handler,
|
||||
@@ -173,22 +202,50 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
))
|
||||
.await?;
|
||||
} else {
|
||||
// no user is provided, use default user
|
||||
// and still do authorization
|
||||
let mut login_info = LoginInfo::from_client_info(client);
|
||||
login_info.user = Some(DEFAULT_USERNAME.to_string());
|
||||
|
||||
let authorize_result = self.verifier.authorize(&login_info).await;
|
||||
if !matches!(authorize_result, Ok(true)) {
|
||||
return send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"28P01",
|
||||
"password authorization failed".to_owned(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
auth::finish_authentication(client, &self.param_provider).await;
|
||||
}
|
||||
}
|
||||
PgWireFrontendMessage::Password(ref pwd) => {
|
||||
let login_info = LoginInfo::from_client_info(client);
|
||||
if let Ok(true) = self.verifier.verify_pwd(pwd.password(), login_info).await {
|
||||
auth::finish_authentication(client, &self.param_provider).await
|
||||
} else {
|
||||
send_error(
|
||||
// do authenticate
|
||||
let authenticate_result =
|
||||
self.verifier.verify_pwd(pwd.password(), &login_info).await;
|
||||
if !matches!(authenticate_result, Ok(true)) {
|
||||
return send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"28P01",
|
||||
"Password authentication failed".to_owned(),
|
||||
"password authentication failed".to_owned(),
|
||||
)
|
||||
.await?;
|
||||
.await;
|
||||
}
|
||||
// do authorize
|
||||
let authorize_result = self.verifier.authorize(&login_info).await;
|
||||
if !matches!(authorize_result, Ok(true)) {
|
||||
return send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"28P01",
|
||||
"password authorization failed".to_owned(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
auth::finish_authentication(client, &self.param_provider).await;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
197
src/servers/tests/auth.rs
Normal file
197
src/servers/tests/auth.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use servers::auth::user_provider::auth_mysql;
|
||||
use servers::auth::{
|
||||
AccessDeniedSnafu, Identity, Password, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu,
|
||||
UserPasswordMismatchSnafu, UserProvider,
|
||||
};
|
||||
use session::context::UserInfo;
|
||||
|
||||
pub struct DatabaseAuthInfo<'a> {
|
||||
pub catalog: &'a str,
|
||||
pub schema: &'a str,
|
||||
pub username: &'a str,
|
||||
}
|
||||
|
||||
pub struct MockUserProvider {
|
||||
pub catalog: String,
|
||||
pub schema: String,
|
||||
pub username: String,
|
||||
}
|
||||
|
||||
impl Default for MockUserProvider {
|
||||
fn default() -> Self {
|
||||
MockUserProvider {
|
||||
catalog: "greptime".to_owned(),
|
||||
schema: "public".to_owned(),
|
||||
username: "greptime".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MockUserProvider {
|
||||
pub fn set_authorization_info(&mut self, info: DatabaseAuthInfo) {
|
||||
self.catalog = info.catalog.to_owned();
|
||||
self.schema = info.schema.to_owned();
|
||||
self.username = info.username.to_owned();
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl UserProvider for MockUserProvider {
|
||||
fn name(&self) -> &str {
|
||||
"mock_user_provider"
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
&self,
|
||||
id: Identity<'_>,
|
||||
password: Password<'_>,
|
||||
) -> servers::auth::Result<UserInfo> {
|
||||
match id {
|
||||
Identity::UserId(username, _host) => match password {
|
||||
Password::PlainText(password) => {
|
||||
if username == "greptime" {
|
||||
if password == "greptime" {
|
||||
Ok(UserInfo::new("greptime"))
|
||||
} else {
|
||||
UserPasswordMismatchSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
} else {
|
||||
UserNotFoundSnafu {
|
||||
username: username.to_string(),
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
Password::MysqlNativePassword(auth_data, salt) => {
|
||||
auth_mysql(auth_data, salt, username, "greptime".as_bytes())
|
||||
.map(|_| UserInfo::new(username))
|
||||
}
|
||||
_ => UnsupportedPasswordTypeSnafu {
|
||||
password_type: "mysql_native_password",
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn authorize(
|
||||
&self,
|
||||
catalog: &str,
|
||||
schema: &str,
|
||||
user_info: &UserInfo,
|
||||
) -> servers::auth::Result<()> {
|
||||
if catalog == self.catalog && schema == self.schema && user_info.username() == self.username
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
AccessDeniedSnafu {
|
||||
catalog: catalog.to_string(),
|
||||
schema: schema.to_string(),
|
||||
username: user_info.username().to_string(),
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_by_plain_text() {
|
||||
let user_provider = MockUserProvider::default();
|
||||
assert_eq!("mock_user_provider", user_provider.name());
|
||||
|
||||
// auth success
|
||||
let auth_result = user_provider
|
||||
.authenticate(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::PlainText("greptime"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_ok());
|
||||
assert_eq!("greptime", auth_result.unwrap().username());
|
||||
|
||||
// auth failed, unsupported password type
|
||||
let auth_result = user_provider
|
||||
.authenticate(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::PgMD5(b"hashed_value", b"salt"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UnsupportedPasswordType { .. }
|
||||
);
|
||||
|
||||
// auth failed, err: user not exist.
|
||||
let auth_result = user_provider
|
||||
.authenticate(
|
||||
Identity::UserId("not_exist_username", None),
|
||||
Password::PlainText("greptime"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UserNotFound { .. }
|
||||
);
|
||||
|
||||
// auth failed, err: wrong password
|
||||
let auth_result = user_provider
|
||||
.authenticate(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::PlainText("wrong_password"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UserPasswordMismatch { .. }
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_validate() {
|
||||
let mut validator = MockUserProvider::default();
|
||||
validator.set_authorization_info(DatabaseAuthInfo {
|
||||
catalog: "greptime",
|
||||
schema: "public",
|
||||
username: "test_user",
|
||||
});
|
||||
|
||||
let right_user = UserInfo::new("test_user");
|
||||
let wrong_user = UserInfo::default();
|
||||
|
||||
// check catalog
|
||||
let re = validator
|
||||
.authorize("greptime_wrong", "public", &right_user)
|
||||
.await;
|
||||
assert!(re.is_err());
|
||||
// check schema
|
||||
let re = validator
|
||||
.authorize("greptime", "public_wrong", &right_user)
|
||||
.await;
|
||||
assert!(re.is_err());
|
||||
// check username
|
||||
let re = validator.authorize("greptime", "public", &wrong_user).await;
|
||||
assert!(re.is_err());
|
||||
// check ok
|
||||
let re = validator.authorize("greptime", "public", &right_user).await;
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
120
src/servers/tests/http/authorize.rs
Normal file
120
src/servers/tests/http/authorize.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::BoxBody;
|
||||
use axum::http;
|
||||
use hyper::Request;
|
||||
use servers::auth::UserProvider;
|
||||
use servers::http::authorize::HttpAuth;
|
||||
use session::context::UserInfo;
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use crate::auth::MockUserProvider;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_auth() {
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(None);
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let auth_res = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(mock_user_provider);
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
|
||||
let req = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = req.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
let auth_res = http_auth.authorize(req).await;
|
||||
assert!(auth_res.is_err());
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let auth_res = http_auth.authorize(wrong_req).await;
|
||||
assert!(auth_res.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_validating() {
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let provider = MockUserProvider::default();
|
||||
let mock_user_provider = Some(Arc::new(provider) as Arc<dyn UserProvider>);
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(mock_user_provider);
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
// http://localhost/{http_api_version}/sql?db=greptime
|
||||
let version = servers::http::HTTP_API_VERSION;
|
||||
let req = mock_http_request(
|
||||
Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="),
|
||||
Some(format!("http://localhost/{version}/sql?db=public").as_str()),
|
||||
)
|
||||
.unwrap();
|
||||
let req = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = req.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
// wrong database
|
||||
let req = mock_http_request(
|
||||
Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="),
|
||||
Some(format!("http://localhost/{version}/sql?db=wrong").as_str()),
|
||||
)
|
||||
.unwrap();
|
||||
let result = http_auth.authorize(req).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_whitelist_no_auth() {
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(mock_user_provider);
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
// try auth path first
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
let req = http_auth.authorize(req).await;
|
||||
assert!(req.is_err());
|
||||
|
||||
// try whitelist path
|
||||
let req = mock_http_request(None, Some("http://localhost/health")).unwrap();
|
||||
let req = http_auth.authorize(req).await;
|
||||
assert!(req.is_ok());
|
||||
}
|
||||
|
||||
// copy from http::authorize
|
||||
fn mock_http_request(
|
||||
auth_header: Option<&str>,
|
||||
uri: Option<&str>,
|
||||
) -> servers::error::Result<Request<()>> {
|
||||
let http_api_version = servers::http::HTTP_API_VERSION;
|
||||
let mut req = Request::builder()
|
||||
.uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql?db=public").as_str()));
|
||||
if let Some(auth_header) = auth_header {
|
||||
req = req.header(http::header::AUTHORIZATION, auth_header);
|
||||
}
|
||||
|
||||
Ok(req.body(()).unwrap())
|
||||
}
|
||||
@@ -135,7 +135,7 @@ fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
|
||||
fn create_query() -> Query<http_handler::SqlQuery> {
|
||||
Query(http_handler::SqlQuery {
|
||||
sql: Some("select sum(uint32s) from numbers limit 20".to_string()),
|
||||
database: None,
|
||||
db: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ use async_trait::async_trait;
|
||||
use axum::{http, Router};
|
||||
use axum_test_helper::TestClient;
|
||||
use common_query::Output;
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::error::{Error, Result};
|
||||
use servers::http::{HttpOptions, HttpServer};
|
||||
use servers::influxdb::InfluxdbRequest;
|
||||
@@ -28,8 +27,10 @@ use servers::query_handler::InfluxdbLineProtocolHandler;
|
||||
use session::context::QueryContextRef;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
|
||||
|
||||
struct DummyInstance {
|
||||
tx: mpsc::Sender<(String, String)>,
|
||||
tx: Arc<mpsc::Sender<(String, String)>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -66,11 +67,18 @@ impl SqlQueryHandler for DummyInstance {
|
||||
}
|
||||
}
|
||||
|
||||
fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
|
||||
fn make_test_app(tx: Arc<mpsc::Sender<(String, String)>>, db_name: Option<&str>) -> Router {
|
||||
let instance = Arc::new(DummyInstance { tx });
|
||||
let mut server = HttpServer::new(instance.clone(), HttpOptions::default());
|
||||
let up = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
|
||||
server.set_user_provider(Arc::new(up));
|
||||
let mut user_provider = MockUserProvider::default();
|
||||
if let Some(name) = db_name {
|
||||
user_provider.set_authorization_info(DatabaseAuthInfo {
|
||||
catalog: "greptime",
|
||||
schema: name,
|
||||
username: "greptime",
|
||||
})
|
||||
}
|
||||
server.set_user_provider(Arc::new(user_provider));
|
||||
|
||||
server.set_influxdb_handler(instance);
|
||||
server.make_app()
|
||||
@@ -79,13 +87,14 @@ fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
|
||||
#[tokio::test]
|
||||
async fn test_influxdb_write() {
|
||||
let (tx, mut rx) = mpsc::channel(100);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let app = make_test_app(tx);
|
||||
let app = make_test_app(tx.clone(), None);
|
||||
let client = TestClient::new(app);
|
||||
|
||||
// right request
|
||||
let result = client
|
||||
.post("/v1/influxdb/write")
|
||||
.post("/v1/influxdb/write?db=public")
|
||||
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
@@ -96,6 +105,10 @@ async fn test_influxdb_write() {
|
||||
assert_eq!(result.status(), 204);
|
||||
assert!(result.text().await.is_empty());
|
||||
|
||||
// make new app for db=influxdb
|
||||
let app = make_test_app(tx, Some("influxdb"));
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let result = client
|
||||
.post("/v1/influxdb/write?db=influxdb")
|
||||
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
|
||||
@@ -110,7 +123,7 @@ async fn test_influxdb_write() {
|
||||
|
||||
// bad request
|
||||
let result = client
|
||||
.post("/v1/influxdb/write")
|
||||
.post("/v1/influxdb/write?db=influxdb")
|
||||
.body("monitor, host=host1 cpu=1.2 1664370459457010101")
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod authorize;
|
||||
mod http_handler_test;
|
||||
mod influxdb_test;
|
||||
mod opentsdb_test;
|
||||
|
||||
@@ -22,21 +22,20 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_query::Output;
|
||||
use query::parser::QueryLanguageParser;
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use servers::error::{Error, Result};
|
||||
use servers::query_handler::{ScriptHandler, ScriptHandlerRef};
|
||||
use table::test_util::MemTable;
|
||||
|
||||
mod http;
|
||||
mod mysql;
|
||||
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
|
||||
use script::python::{PyEngine, PyScript};
|
||||
use servers::error::{Error, Result};
|
||||
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
|
||||
use servers::query_handler::{ScriptHandler, ScriptHandlerRef};
|
||||
use session::context::QueryContextRef;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
mod auth;
|
||||
mod http;
|
||||
mod interceptor;
|
||||
mod mysql;
|
||||
mod opentsdb;
|
||||
mod postgres;
|
||||
|
||||
mod py_script;
|
||||
|
||||
struct DummyInstance {
|
||||
|
||||
@@ -24,17 +24,21 @@ use mysql_async::prelude::*;
|
||||
use mysql_async::SslOpts;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::error::Result;
|
||||
use servers::mysql::server::MysqlServer;
|
||||
use servers::server::Server;
|
||||
use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
|
||||
use crate::create_testing_sql_query_handler;
|
||||
use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData};
|
||||
|
||||
fn create_mysql_server(table: MemTable, tls: TlsOption) -> Result<Box<dyn Server>> {
|
||||
fn create_mysql_server(
|
||||
table: MemTable,
|
||||
tls: TlsOption,
|
||||
auth_info: Option<DatabaseAuthInfo>,
|
||||
) -> Result<Box<dyn Server>> {
|
||||
let query_handler = create_testing_sql_query_handler(table);
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
@@ -44,7 +48,10 @@ fn create_mysql_server(table: MemTable, tls: TlsOption) -> Result<Box<dyn Server
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let provider = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
|
||||
let mut provider = MockUserProvider::default();
|
||||
if let Some(auth_info) = auth_info {
|
||||
provider.set_authorization_info(auth_info);
|
||||
}
|
||||
|
||||
Ok(MysqlServer::create_server(
|
||||
query_handler,
|
||||
@@ -58,7 +65,7 @@ fn create_mysql_server(table: MemTable, tls: TlsOption) -> Result<Box<dyn Server
|
||||
async fn test_start_mysql_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let mysql_server = create_mysql_server(table, Default::default(), None)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = mysql_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -71,13 +78,54 @@ async fn test_start_mysql_server() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_validation() -> Result<()> {
|
||||
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
let mysql_server = create_mysql_server(table, Default::default(), Some(auth_info))?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
Ok((mysql_server, server_addr.port()))
|
||||
}
|
||||
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let (mysql_server, server_port) = generate_server(DatabaseAuthInfo {
|
||||
catalog: "greptime",
|
||||
schema: "public",
|
||||
username: "greptime",
|
||||
})
|
||||
.await?;
|
||||
|
||||
//TODO(shuiyisong): mysql conn without dbname rejection is not implemented yet, add test later.
|
||||
|
||||
let pass = create_connection(server_port, Some("public"), false).await;
|
||||
assert!(pass.is_ok());
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// change to another username
|
||||
let (mysql_server, server_port) = generate_server(DatabaseAuthInfo {
|
||||
catalog: "greptime",
|
||||
schema: "public",
|
||||
username: "no_access_user",
|
||||
})
|
||||
.await?;
|
||||
|
||||
let fail = create_connection(server_port, Some("public"), false).await;
|
||||
assert!(fail.is_err());
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let mysql_server = create_mysql_server(table, Default::default(), None)?;
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -193,7 +241,7 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls)?;
|
||||
let mysql_server = create_mysql_server(table, server_tls, None)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
@@ -219,7 +267,7 @@ async fn test_db_name() -> Result<()> {
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls)?;
|
||||
let mysql_server = create_mysql_server(table, server_tls, None)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
@@ -247,7 +295,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls)?;
|
||||
let mysql_server = create_mysql_server(table, server_tls, None)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
@@ -282,7 +330,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let mysql_server = create_mysql_server(table, Default::default(), None)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
@@ -332,7 +380,7 @@ async fn create_connection(
|
||||
.tcp_port(port)
|
||||
.prefer_socket(false)
|
||||
.wait_timeout(Some(1000))
|
||||
.db_name(db_name)
|
||||
.db_name(db_name.or(Some(DEFAULT_SCHEMA_NAME)))
|
||||
.user(Some("greptime".to_string()))
|
||||
.pass(Some("greptime".to_string()));
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::{Certificate, Error, ServerName};
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::error::Result;
|
||||
use servers::postgres::PostgresServer;
|
||||
@@ -31,12 +30,14 @@ use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
|
||||
|
||||
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
|
||||
use crate::create_testing_instance;
|
||||
|
||||
fn create_postgres_server(
|
||||
table: MemTable,
|
||||
check_pwd: bool,
|
||||
tls: TlsOption,
|
||||
auth_info: Option<DatabaseAuthInfo>,
|
||||
) -> Result<Box<dyn Server>> {
|
||||
let instance = Arc::new(create_testing_instance(table));
|
||||
let io_runtime = Arc::new(
|
||||
@@ -47,9 +48,11 @@ fn create_postgres_server(
|
||||
.unwrap(),
|
||||
);
|
||||
let user_provider: Option<UserProviderRef> = if check_pwd {
|
||||
Some(Arc::new(
|
||||
StaticUserProvider::try_from("cmd:test_user=test_pwd").unwrap(),
|
||||
))
|
||||
let mut provider = MockUserProvider::default();
|
||||
if let Some(info) = auth_info {
|
||||
provider.set_authorization_info(info);
|
||||
}
|
||||
Some(Arc::new(provider))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -66,7 +69,7 @@ fn create_postgres_server(
|
||||
pub async fn test_start_postgres_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let pg_server = create_postgres_server(table, false, Default::default())?;
|
||||
let pg_server = create_postgres_server(table, false, Default::default(), None)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = pg_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -86,12 +89,52 @@ async fn test_shutdown_pg_server_range() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_validating() -> Result<()> {
|
||||
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
let postgres_server =
|
||||
create_postgres_server(table, true, Default::default(), Some(auth_info))?;
|
||||
let listening = "127.0.0.1:5432".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = postgres_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
Ok((postgres_server, server_port))
|
||||
}
|
||||
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let (pg_server, server_port) = generate_server(DatabaseAuthInfo {
|
||||
catalog: "greptime",
|
||||
schema: "public",
|
||||
username: "greptime",
|
||||
})
|
||||
.await?;
|
||||
|
||||
let pass = create_plain_connection(server_port, true).await;
|
||||
assert!(pass.is_ok());
|
||||
let result = pg_server.shutdown().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let (pg_server, server_port) = generate_server(DatabaseAuthInfo {
|
||||
catalog: "greptime",
|
||||
schema: "public",
|
||||
username: "no_right_user",
|
||||
})
|
||||
.await?;
|
||||
|
||||
let fail = create_plain_connection(server_port, true).await;
|
||||
assert!(fail.is_err());
|
||||
let result = pg_server.shutdown().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
let postgres_server = create_postgres_server(table, with_pwd, Default::default())?;
|
||||
let postgres_server = create_postgres_server(table, with_pwd, Default::default(), None)?;
|
||||
let result = postgres_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -273,7 +316,7 @@ async fn test_using_db() -> Result<()> {
|
||||
async fn start_test_server(server_tls: TlsOption) -> Result<u16> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let table = MemTable::default_numbers_table();
|
||||
let pg_server = create_postgres_server(table, false, server_tls)?;
|
||||
let pg_server = create_postgres_server(table, false, server_tls, None)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
Ok(server_addr.port())
|
||||
@@ -301,7 +344,7 @@ async fn create_secure_connection(
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"sslmode=require host=127.0.0.1 port={port} user=test_user password=test_pwd connect_timeout=2, dbname={DEFAULT_SCHEMA_NAME}",
|
||||
"sslmode=require host=127.0.0.1 port={port} user=greptime password=greptime connect_timeout=2, dbname={DEFAULT_SCHEMA_NAME}",
|
||||
)
|
||||
} else {
|
||||
format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}")
|
||||
@@ -328,7 +371,7 @@ async fn create_plain_connection(
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"host=127.0.0.1 port={port} user=test_user password=test_pwd connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}",
|
||||
"host=127.0.0.1 port={port} user=greptime password=greptime connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}",
|
||||
)
|
||||
} else {
|
||||
format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}")
|
||||
|
||||
@@ -195,7 +195,7 @@ pub async fn test_sql_api(store_type: StorageType) {
|
||||
|
||||
// test database given
|
||||
let res = client
|
||||
.get("/v1/sql?database=public&sql=select cpu, ts from demo limit 1")
|
||||
.get("/v1/sql?db=public&sql=select cpu, ts from demo limit 1")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
@@ -214,7 +214,7 @@ pub async fn test_sql_api(store_type: StorageType) {
|
||||
|
||||
// test database not found
|
||||
let res = client
|
||||
.get("/v1/sql?database=notfound&sql=select cpu, ts from demo limit 1")
|
||||
.get("/v1/sql?db=notfound&sql=select cpu, ts from demo limit 1")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
@@ -223,7 +223,7 @@ pub async fn test_sql_api(store_type: StorageType) {
|
||||
|
||||
// test catalog-schema given
|
||||
let res = client
|
||||
.get("/v1/sql?database=greptime-public&sql=select cpu, ts from demo limit 1")
|
||||
.get("/v1/sql?db=greptime-public&sql=select cpu, ts from demo limit 1")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
@@ -242,7 +242,7 @@ pub async fn test_sql_api(store_type: StorageType) {
|
||||
|
||||
// test invalid catalog
|
||||
let res = client
|
||||
.get("/v1/sql?database=notfound2-schema&sql=select cpu, ts from demo limit 1")
|
||||
.get("/v1/sql?db=notfound2-schema&sql=select cpu, ts from demo limit 1")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
@@ -251,7 +251,7 @@ pub async fn test_sql_api(store_type: StorageType) {
|
||||
|
||||
// test invalid schema
|
||||
let res = client
|
||||
.get("/v1/sql?database=greptime-schema&sql=select cpu, ts from demo limit 1")
|
||||
.get("/v1/sql?db=greptime-schema&sql=select cpu, ts from demo limit 1")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
Reference in New Issue
Block a user