diff --git a/Cargo.lock b/Cargo.lock index 96b40dbb4e..df25fd5a87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6633,6 +6633,7 @@ dependencies = [ "script", "serde", "serde_json", + "serde_urlencoded", "session", "sha1", "snafu", diff --git a/src/cmd/src/frontend.rs b/src/cmd/src/frontend.rs index a8da249cca..24e115988c 100644 --- a/src/cmd/src/frontend.rs +++ b/src/cmd/src/frontend.rs @@ -287,7 +287,7 @@ mod tests { let provider = provider.unwrap(); let result = provider - .auth(Identity::UserId("test", None), Password::PlainText("test")) + .authenticate(Identity::UserId("test", None), Password::PlainText("test")) .await; assert!(result.is_ok()); } diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs index 4773bfa974..6dfd64f5d0 100644 --- a/src/cmd/src/standalone.rs +++ b/src/cmd/src/standalone.rs @@ -350,7 +350,7 @@ mod tests { assert!(provider.is_some()); let provider = provider.unwrap(); let result = provider - .auth(Identity::UserId("test", None), Password::PlainText("test")) + .authenticate(Identity::UserId("test", None), Password::PlainText("test")) .await; assert!(result.is_ok()); } diff --git a/src/common/error/src/status_code.rs b/src/common/error/src/status_code.rs index 2dfc2ad4a7..d58c1696c2 100644 --- a/src/common/error/src/status_code.rs +++ b/src/common/error/src/status_code.rs @@ -77,6 +77,8 @@ pub enum StatusCode { AuthHeaderNotFound = 7003, /// Invalid http authorization header InvalidAuthHeader = 7004, + /// Illegal request to connect catalog-schema + AccessDenied = 7005, // ====== End of auth related status code ===== } diff --git a/src/frontend/src/frontend.rs b/src/frontend/src/frontend.rs index 00a3b4fbe9..ad92b65dff 100644 --- a/src/frontend/src/frontend.rs +++ b/src/frontend/src/frontend.rs @@ -16,7 +16,6 @@ use std::sync::Arc; use meta_client::MetaClientOpts; use serde::{Deserialize, Serialize}; -use servers::auth::UserProviderRef; use servers::http::HttpOptions; use servers::Mode; use snafu::prelude::*; @@ -92,8 +91,6 @@ impl Frontend { let instance = Arc::new(instance); // TODO(sunng87): merge this into instance - let provider = self.plugins.get::().cloned(); - - Services::start(&self.opts, instance, provider).await + Services::start(&self.opts, instance, self.plugins.clone()).await } } diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index d71c5adb44..d69b6c6fde 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -34,6 +34,7 @@ use crate::frontend::FrontendOptions; use crate::influxdb::InfluxdbOptions; use crate::instance::FrontendInstance; use crate::prometheus::PrometheusOptions; +use crate::Plugins; pub(crate) struct Services; @@ -41,12 +42,14 @@ impl Services { pub(crate) async fn start( opts: &FrontendOptions, instance: Arc, - user_provider: Option, + plugins: Arc, ) -> Result<()> where T: FrontendInstance, { info!("Starting frontend servers"); + let user_provider = plugins.get::().cloned(); + let grpc_server_and_addr = if let Some(opts) = &opts.grpc_options { let grpc_addr = parse_addr(&opts.addr)?; diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index e87b9d2501..52f914ccc0 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -48,6 +48,7 @@ rustls-pemfile = "1.0" schemars = "0.8" serde.workspace = true serde_json = "1.0" +serde_urlencoded = "0.7" session = { path = "../session" } sha1 = "0.10" snafu = { version = "0.7", features = ["backtraces"] } diff --git a/src/servers/src/auth.rs b/src/servers/src/auth.rs index a029a83f5a..d7a4ed990c 100644 --- a/src/servers/src/auth.rs +++ b/src/servers/src/auth.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub mod user_provider; - use std::sync::Arc; use common_error::ext::BoxedError; @@ -24,11 +22,19 @@ use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu}; use crate::auth::user_provider::StaticUserProvider; +pub mod user_provider; + #[async_trait::async_trait] pub trait UserProvider: Send + Sync { fn name(&self) -> &str; - async fn auth(&self, id: Identity<'_>, password: Password<'_>) -> Result; + /// [`authenticate`] checks whether a user is valid and allowed to access the database. + async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result; + + /// [`authorize`] checks whether a connection request + /// from a certain user to a certain catalog/schema is legal. + /// This method should be called after [`authenticate`]. + async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfo) -> Result<()>; } pub type UserProviderRef = Arc; @@ -76,6 +82,12 @@ pub enum Error { #[snafu(display("Invalid config value: {}, {}", value, msg))] InvalidConfig { value: String, msg: String }, + #[snafu(display("Illegal runtime param: {}", msg))] + IllegalParam { msg: String }, + + #[snafu(display("Internal state error: {}", msg))] + InternalState { msg: String }, + #[snafu(display("IO error, source: {}", source))] Io { source: std::io::Error, @@ -96,18 +108,33 @@ pub enum Error { #[snafu(display("Username and password does not match, username: {}", username))] UserPasswordMismatch { username: String }, + + #[snafu(display( + "User {} is not allowed to access catalog {} and schema {}", + username, + catalog, + schema + ))] + AccessDenied { + catalog: String, + schema: String, + username: String, + }, } impl ErrorExt for Error { fn status_code(&self) -> StatusCode { match self { Error::InvalidConfig { .. } => StatusCode::InvalidArguments, + Error::IllegalParam { .. } => StatusCode::InvalidArguments, + Error::InternalState { .. } => StatusCode::Unexpected, Error::Io { .. } => StatusCode::Internal, Error::AuthBackend { .. } => StatusCode::Internal, Error::UserNotFound { .. } => StatusCode::UserNotFound, Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType, Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch, + Error::AccessDenied { .. } => StatusCode::AccessDenied, } } @@ -121,108 +148,3 @@ impl ErrorExt for Error { } pub type Result = std::result::Result; - -#[cfg(test)] -pub mod test { - use super::{Identity, Password, UserInfo, UserProvider}; - - pub struct MockUserProvider {} - - #[async_trait::async_trait] - impl UserProvider for MockUserProvider { - fn name(&self) -> &str { - "mock_user_provider" - } - - async fn auth( - &self, - id: Identity<'_>, - password: Password<'_>, - ) -> Result { - match id { - Identity::UserId(username, _host) => match password { - Password::PlainText(password) => { - if username == "greptime" { - if password == "greptime" { - return Ok(UserInfo::new("greptime")); - } else { - return super::UserPasswordMismatchSnafu { - username: username.to_string(), - } - .fail(); - } - } else { - return super::UserNotFoundSnafu { - username: username.to_string(), - } - .fail(); - } - } - _ => super::UnsupportedPasswordTypeSnafu { - password_type: "mysql_native_password", - } - .fail(), - }, - } - } - } -} - -#[cfg(test)] -mod tests { - use super::test::MockUserProvider; - use super::{Identity, Password, UserProvider}; - use crate::auth; - - #[tokio::test] - async fn test_auth_by_plain_text() { - let user_provider = MockUserProvider {}; - assert_eq!("mock_user_provider", user_provider.name()); - - // auth success - let auth_result = user_provider - .auth( - Identity::UserId("greptime", None), - Password::PlainText("greptime"), - ) - .await; - assert!(auth_result.is_ok()); - assert_eq!("greptime", auth_result.unwrap().username()); - - // auth failed, unsupported password type - let auth_result = user_provider - .auth( - Identity::UserId("greptime", None), - Password::MysqlNativePassword(b"hashed_value", b"salt"), - ) - .await; - assert!(auth_result.is_err()); - matches!( - auth_result.err().unwrap(), - auth::Error::UnsupportedPasswordType { .. } - ); - - // auth failed, err: user not exist. - let auth_result = user_provider - .auth( - Identity::UserId("not_exist_username", None), - Password::PlainText("greptime"), - ) - .await; - assert!(auth_result.is_err()); - matches!(auth_result.err().unwrap(), auth::Error::UserNotFound { .. }); - - // auth failed, err: wrong password - let auth_result = user_provider - .auth( - Identity::UserId("greptime", None), - Password::PlainText("wrong_password"), - ) - .await; - assert!(auth_result.is_err()); - matches!( - auth_result.err().unwrap(), - auth::Error::UserPasswordMismatch { .. } - ); - } -} diff --git a/src/servers/src/auth/user_provider.rs b/src/servers/src/auth/user_provider.rs index c25855e25d..8e9d570aa8 100644 --- a/src/servers/src/auth/user_provider.rs +++ b/src/servers/src/auth/user_provider.rs @@ -99,7 +99,11 @@ impl UserProvider for StaticUserProvider { STATIC_USER_PROVIDER } - async fn auth(&self, input_id: Identity<'_>, input_pwd: Password<'_>) -> Result { + async fn authenticate( + &self, + input_id: Identity<'_>, + input_pwd: Password<'_>, + ) -> Result { match input_id { Identity::UserId(username, _) => { let save_pwd = self.users.get(username).context(UserNotFoundSnafu { @@ -129,6 +133,11 @@ impl UserProvider for StaticUserProvider { } } } + + async fn authorize(&self, _catalog: &str, _schema: &str, _user_info: &UserInfo) -> Result<()> { + // default allow all + Ok(()) + } } pub fn auth_mysql( @@ -209,7 +218,7 @@ pub mod test { async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) { let re = provider - .auth( + .authenticate( Identity::UserId(username, None), Password::PlainText(password), ) diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index e2a79b9763..caab78c6c1 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -353,6 +353,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: auth::Error) -> Self { + Error::Auth { source: e } + } +} + impl IntoResponse for Error { fn into_response(self) -> Response { let (status, error_message) = match self { diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 0ac7fb9c62..6ac52c55e9 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod authorize; +pub mod authorize; pub mod handler; pub mod influxdb; pub mod opentsdb; @@ -92,8 +92,8 @@ pub(crate) fn query_context_from_db( } } -const HTTP_API_VERSION: &str = "v1"; -const HTTP_API_PREFIX: &str = "/v1/"; +pub const HTTP_API_VERSION: &str = "v1"; +pub const HTTP_API_PREFIX: &str = "/v1/"; pub struct HttpServer { sql_handler: ServerSqlQueryHandlerRef, diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index 05acc538f3..a14fbdff45 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; // Copyright 2023 Greptime Team // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,7 +12,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - use std::marker::PhantomData; use axum::http::{self, Request, StatusCode}; @@ -23,7 +23,8 @@ use session::context::UserInfo; use snafu::{OptionExt, ResultExt}; use tower_http::auth::AsyncAuthorizeRequest; -use crate::auth::{Identity, UserProviderRef}; +use crate::auth::Error::IllegalParam; +use crate::auth::{Identity, IllegalParamSnafu, InternalStateSnafu, UserProviderRef}; use crate::error::{self, Result}; use crate::http::HTTP_API_PREFIX; @@ -70,45 +71,84 @@ where return Ok(request); }; - let (scheme, credential) = match auth_header(&request) { - Ok(auth_header) => auth_header, + // do authenticate + match authenticate(&user_provider, &request).await { + Ok(user_info) => { + request.extensions_mut().insert(user_info); + } Err(e) => { - error!("failed to get http authorize header, err: {:?}", e); + error!("authenticate failed: {}", e); return Err(unauthorized_resp()); } - }; + } - match scheme { - AuthScheme::Basic => { - let (username, password) = match decode_basic(credential) { - Ok(basic_auth) => basic_auth, - Err(e) => { - error!("failed to decode basic authorize, err: {:?}", e); - return Err(unauthorized_resp()); - } - }; - match user_provider - .auth( - Identity::UserId(&username, None), - crate::auth::Password::PlainText(&password), - ) - .await - { - Ok(user_info) => { - request.extensions_mut().insert(user_info); - Ok(request) - } - Err(e) => { - error!("failed to auth, err: {:?}", e); - Err(unauthorized_resp()) - } - } + match authorize(&user_provider, &request).await { + Ok(_) => Ok(request), + Err(e) => { + error!("authorize failed: {}", e); + Err(unauthorized_resp()) } } }) } } +async fn authorize( + user_provider: &UserProviderRef, + request: &Request, +) -> crate::auth::Result<()> { + // try get database name + let query = request.uri().query().unwrap_or_default(); + let input_database = match serde_urlencoded::from_str::>(query) { + Ok(query_map) => query_map + .get("db") + .context(IllegalParamSnafu { + msg: "fail to get valid database from http query", + })? + .to_owned(), + Err(e) => IllegalParamSnafu { + msg: format!("fail to parse http query: {e}"), + } + .fail()?, + }; + + let (catalog, database) = + crate::parse_catalog_and_schema_from_client_database_name(&input_database); + + let user_info = request + .extensions() + .get::() + .context(InternalStateSnafu { + msg: "no user info provided while authorizing", + })?; + + user_provider.authorize(catalog, database, user_info).await +} + +async fn authenticate( + user_provider: &UserProviderRef, + request: &Request, +) -> crate::auth::Result { + let (scheme, credential) = auth_header(request).map_err(|e| IllegalParam { + msg: format!("failed to get http authorize header, err: {e:?}"), + })?; + + match scheme { + AuthScheme::Basic => { + let (username, password) = decode_basic(credential).map_err(|e| IllegalParam { + msg: format!("failed to decode basic authorize, err: {e:?}"), + })?; + + Ok(user_provider + .authenticate( + Identity::UserId(&username, None), + crate::auth::Password::PlainText(&password), + ) + .await?) + } + } +} + fn unauthorized_resp() -> Response where RespBody: Body + Default, @@ -171,79 +211,7 @@ fn decode_basic(credential: Credential) -> Result<(Username, Password)> { #[cfg(test)] mod tests { - use std::marker::PhantomData; - use std::sync::Arc; - - use axum::body::BoxBody; - use axum::http; - use hyper::Request; - use session::context::UserInfo; - use tower_http::auth::AsyncAuthorizeRequest; - - use super::{auth_header, decode_basic, AuthScheme, HttpAuth}; - use crate::auth::test::MockUserProvider; - use crate::auth::UserProvider; - use crate::error; - use crate::error::Result; - - #[tokio::test] - async fn test_http_auth() { - let mut http_auth: HttpAuth = HttpAuth { - user_provider: None, - _ty: PhantomData, - }; - - // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" - let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); - let auth_res = http_auth.authorize(req).await.unwrap(); - let user_info: &UserInfo = auth_res.extensions().get().unwrap(); - let default = UserInfo::default(); - assert_eq!(default.username(), user_info.username()); - - // In mock user provider, right username:password == "greptime:greptime" - let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc); - let mut http_auth: HttpAuth = HttpAuth { - user_provider: mock_user_provider, - _ty: PhantomData, - }; - - // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU=" - let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap(); - let req = http_auth.authorize(req).await.unwrap(); - let user_info: &UserInfo = req.extensions().get().unwrap(); - let default = UserInfo::default(); - assert_eq!(default.username(), user_info.username()); - - let req = mock_http_request(None, None).unwrap(); - let auth_res = http_auth.authorize(req).await; - assert!(auth_res.is_err()); - - // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" - let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); - let auth_res = http_auth.authorize(wrong_req).await; - assert!(auth_res.is_err()); - } - - #[tokio::test] - async fn test_whitelist_no_auth() { - // In mock user provider, right username:password == "greptime:greptime" - let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc); - let mut http_auth: HttpAuth = HttpAuth { - user_provider: mock_user_provider, - _ty: PhantomData, - }; - - // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU=" - // try auth path first - let req = mock_http_request(None, None).unwrap(); - let req = http_auth.authorize(req).await; - assert!(req.is_err()); - - // try whitelist path - let req = mock_http_request(None, Some("http://localhost/health")).unwrap(); - let req = http_auth.authorize(req).await; - assert!(req.is_ok()); - } + use super::*; #[test] fn test_decode_basic() { diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 0bd481e1a2..0c7daac6b8 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -28,7 +28,7 @@ use crate::http::{ApiState, JsonResponse}; #[derive(Debug, Default, Serialize, Deserialize, JsonSchema)] pub struct SqlQuery { - pub database: Option, + pub db: Option, pub sql: Option, } @@ -43,7 +43,7 @@ pub async fn sql( let sql_handler = &state.sql_handler; let start = Instant::now(); let resp = if let Some(sql) = ¶ms.sql { - match super::query_context_from_db(sql_handler.clone(), params.database) { + match super::query_context_from_db(sql_handler.clone(), params.db) { Ok(query_ctx) => { JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await } diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index eb367d7c1e..4fb8e9d566 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -123,7 +123,7 @@ impl AsyncMysqlShim for MysqlInstanceShi return false; } }; - match user_provider.auth(user_id, password).await { + match user_provider.authenticate(user_id, password).await { Ok(userinfo) => { user_info = Some(userinfo); } @@ -190,6 +190,12 @@ impl AsyncMysqlShim for MysqlInstanceShi error::DatabaseNotFoundSnafu { catalog, schema } ); + if let Some(schema_validator) = &self.user_provider { + schema_validator + .authorize(catalog, schema, &self.session.user_info()) + .await?; + } + let context = self.session.context(); context.set_current_catalog(catalog); context.set_current_schema(database); diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index 49c4d67c8b..830d3f858e 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -38,6 +38,15 @@ use crate::tls::TlsOption; // Default size of ResultSet write buffer: 100KB const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024; +struct MysqlRuntimeOption { + query_handler: ServerSqlQueryHandlerRef, + tls_conf: Option>, + force_tls: bool, + user_provider: Option, +} + +type MysqlRuntimeOptionRef = Arc; + pub struct MysqlServer { base_server: BaseTcpServer, query_handler: ServerSqlQueryHandlerRef, @@ -77,19 +86,19 @@ impl MysqlServer { let user_provider = user_provider.clone(); let tls_conf = tls_conf.clone(); + let mysql_runtime_option = Arc::new(MysqlRuntimeOption { + query_handler, + tls_conf, + force_tls, + user_provider, + }); + async move { match tcp_stream { Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt. Ok(io_stream) => { - if let Err(error) = Self::handle( - io_stream, - io_runtime, - query_handler, - tls_conf, - force_tls, - user_provider, - ) - .await + if let Err(error) = + Self::handle(io_stream, io_runtime, mysql_runtime_option).await { error!(error; "Unexpected error when handling TcpStream"); }; @@ -102,15 +111,12 @@ impl MysqlServer { async fn handle( stream: TcpStream, io_runtime: Arc, - query_handler: ServerSqlQueryHandlerRef, - tls_conf: Option>, - force_tls: bool, - user_provider: Option, + runtime_opts: MysqlRuntimeOptionRef, ) -> Result<()> { info!("MySQL connection coming from: {}", stream.peer_addr()?); io_runtime .spawn(async move { // TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client. - if let Err(e) = Self::do_handle(stream, query_handler, tls_conf, force_tls, user_provider).await { + if let Err(e) = Self::do_handle(stream, runtime_opts).await { // TODO(LFC): Write this error to client as well, in MySQL text protocol. // Looks like we have to expose opensrv-mysql's `PacketWriter`? error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.") @@ -120,28 +126,31 @@ impl MysqlServer { Ok(()) } - async fn do_handle( - stream: TcpStream, - query_handler: ServerSqlQueryHandlerRef, - tls_conf: Option>, - force_tls: bool, - user_provider: Option, - ) -> Result<()> { - let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?, user_provider); + async fn do_handle(stream: TcpStream, runtime_opts: MysqlRuntimeOptionRef) -> Result<()> { + let mut shim = MysqlInstanceShim::create( + runtime_opts.query_handler.clone(), + stream.peer_addr()?, + runtime_opts.user_provider.clone(), + ); let (mut r, w) = stream.into_split(); let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w); let ops = IntermediaryOptions::default(); - let (client_tls, init_params) = - AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &tls_conf).await?; + let (client_tls, init_params) = AsyncMysqlIntermediary::init_before_ssl( + &mut shim, + &mut r, + &mut w, + &runtime_opts.tls_conf, + ) + .await?; - if force_tls && !client_tls { + if runtime_opts.force_tls && !client_tls { return Err(Error::TlsRequired { server: "mysql".to_owned(), }); } - match tls_conf { + match runtime_opts.tls_conf.clone() { Some(tls_conf) if client_tls => { secure_run_with_options(shim, w, ops, tls_conf, init_params).await } diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index dd7315f9d2..15988589be 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -23,6 +23,7 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::response::ErrorResponse; use pgwire::messages::startup::Authentication; use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; +use session::context::{UserInfo, DEFAULT_USERNAME}; use snafu::ResultExt; use crate::auth::{Identity, Password, UserProviderRef}; @@ -30,14 +31,15 @@ use crate::error; use crate::error::Result; use crate::query_handler::sql::ServerSqlQueryHandlerRef; -struct PgPwdVerifier { +struct PgLoginVerifier { user_provider: Option, } #[allow(dead_code)] struct LoginInfo { user: Option, - database: Option, + catalog: Option, + schema: Option, host: String, } @@ -48,27 +50,31 @@ impl LoginInfo { { LoginInfo { user: client.metadata().get(super::METADATA_USER).map(Into::into), - database: client + catalog: client .metadata() - .get(super::METADATA_DATABASE) + .get(super::METADATA_CATALOG) + .map(Into::into), + schema: client + .metadata() + .get(super::METADATA_SCHEMA) .map(Into::into), host: client.socket_addr().ip().to_string(), } } } -impl PgPwdVerifier { - async fn verify_pwd(&self, password: &str, login: LoginInfo) -> Result { +impl PgLoginVerifier { + async fn verify_pwd(&self, password: &str, login: &LoginInfo) -> Result { if let Some(user_provider) = &self.user_provider { - let user_name = match login.user { + let user_name = match &login.user { Some(name) => name, None => return Ok(false), }; // TODO(fys): pass user_info to context let _user_info = user_provider - .auth( - Identity::UserId(&user_name, None), + .authenticate( + Identity::UserId(user_name, None), Password::PlainText(password), ) .await @@ -76,6 +82,29 @@ impl PgPwdVerifier { } Ok(true) } + + async fn authorize(&self, login: &LoginInfo) -> Result { + // at this time, username in login info should be valid + // TODO(shuiyisong): change to use actually user_info from session + if let Some(user_provider) = &self.user_provider { + let user_name = match &login.user { + Some(name) => name, + None => return Ok(false), + }; + let catalog = match &login.catalog { + Some(name) => name, + None => return Ok(false), + }; + let schema = match &login.schema { + Some(name) => name, + None => return Ok(false), + }; + user_provider + .authorize(catalog, schema, &UserInfo::new(user_name)) + .await?; + } + Ok(true) + } } struct GreptimeDBStartupParameters { @@ -106,7 +135,7 @@ impl ServerParameterProvider for GreptimeDBStartupParameters { } pub struct PgAuthStartupHandler { - verifier: PgPwdVerifier, + verifier: PgLoginVerifier, param_provider: GreptimeDBStartupParameters, force_tls: bool, query_handler: ServerSqlQueryHandlerRef, @@ -119,7 +148,7 @@ impl PgAuthStartupHandler { query_handler: ServerSqlQueryHandlerRef, ) -> Self { PgAuthStartupHandler { - verifier: PgPwdVerifier { user_provider }, + verifier: PgLoginVerifier { user_provider }, param_provider: GreptimeDBStartupParameters::new(), force_tls, query_handler, @@ -173,22 +202,50 @@ impl StartupHandler for PgAuthStartupHandler { )) .await?; } else { + // no user is provided, use default user + // and still do authorization + let mut login_info = LoginInfo::from_client_info(client); + login_info.user = Some(DEFAULT_USERNAME.to_string()); + + let authorize_result = self.verifier.authorize(&login_info).await; + if !matches!(authorize_result, Ok(true)) { + return send_error( + client, + "FATAL", + "28P01", + "password authorization failed".to_owned(), + ) + .await; + } auth::finish_authentication(client, &self.param_provider).await; } } PgWireFrontendMessage::Password(ref pwd) => { let login_info = LoginInfo::from_client_info(client); - if let Ok(true) = self.verifier.verify_pwd(pwd.password(), login_info).await { - auth::finish_authentication(client, &self.param_provider).await - } else { - send_error( + // do authenticate + let authenticate_result = + self.verifier.verify_pwd(pwd.password(), &login_info).await; + if !matches!(authenticate_result, Ok(true)) { + return send_error( client, "FATAL", "28P01", - "Password authentication failed".to_owned(), + "password authentication failed".to_owned(), ) - .await?; + .await; } + // do authorize + let authorize_result = self.verifier.authorize(&login_info).await; + if !matches!(authorize_result, Ok(true)) { + return send_error( + client, + "FATAL", + "28P01", + "password authorization failed".to_owned(), + ) + .await; + } + auth::finish_authentication(client, &self.param_provider).await; } _ => {} } diff --git a/src/servers/tests/auth.rs b/src/servers/tests/auth.rs new file mode 100644 index 0000000000..8e9c9b8546 --- /dev/null +++ b/src/servers/tests/auth.rs @@ -0,0 +1,197 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use servers::auth::user_provider::auth_mysql; +use servers::auth::{ + AccessDeniedSnafu, Identity, Password, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu, + UserPasswordMismatchSnafu, UserProvider, +}; +use session::context::UserInfo; + +pub struct DatabaseAuthInfo<'a> { + pub catalog: &'a str, + pub schema: &'a str, + pub username: &'a str, +} + +pub struct MockUserProvider { + pub catalog: String, + pub schema: String, + pub username: String, +} + +impl Default for MockUserProvider { + fn default() -> Self { + MockUserProvider { + catalog: "greptime".to_owned(), + schema: "public".to_owned(), + username: "greptime".to_owned(), + } + } +} + +impl MockUserProvider { + pub fn set_authorization_info(&mut self, info: DatabaseAuthInfo) { + self.catalog = info.catalog.to_owned(); + self.schema = info.schema.to_owned(); + self.username = info.username.to_owned(); + } +} + +#[async_trait::async_trait] +impl UserProvider for MockUserProvider { + fn name(&self) -> &str { + "mock_user_provider" + } + + async fn authenticate( + &self, + id: Identity<'_>, + password: Password<'_>, + ) -> servers::auth::Result { + match id { + Identity::UserId(username, _host) => match password { + Password::PlainText(password) => { + if username == "greptime" { + if password == "greptime" { + Ok(UserInfo::new("greptime")) + } else { + UserPasswordMismatchSnafu { + username: username.to_string(), + } + .fail() + } + } else { + UserNotFoundSnafu { + username: username.to_string(), + } + .fail() + } + } + Password::MysqlNativePassword(auth_data, salt) => { + auth_mysql(auth_data, salt, username, "greptime".as_bytes()) + .map(|_| UserInfo::new(username)) + } + _ => UnsupportedPasswordTypeSnafu { + password_type: "mysql_native_password", + } + .fail(), + }, + } + } + + async fn authorize( + &self, + catalog: &str, + schema: &str, + user_info: &UserInfo, + ) -> servers::auth::Result<()> { + if catalog == self.catalog && schema == self.schema && user_info.username() == self.username + { + Ok(()) + } else { + AccessDeniedSnafu { + catalog: catalog.to_string(), + schema: schema.to_string(), + username: user_info.username().to_string(), + } + .fail() + } + } +} + +#[tokio::test] +async fn test_auth_by_plain_text() { + let user_provider = MockUserProvider::default(); + assert_eq!("mock_user_provider", user_provider.name()); + + // auth success + let auth_result = user_provider + .authenticate( + Identity::UserId("greptime", None), + Password::PlainText("greptime"), + ) + .await; + assert!(auth_result.is_ok()); + assert_eq!("greptime", auth_result.unwrap().username()); + + // auth failed, unsupported password type + let auth_result = user_provider + .authenticate( + Identity::UserId("greptime", None), + Password::PgMD5(b"hashed_value", b"salt"), + ) + .await; + assert!(auth_result.is_err()); + matches!( + auth_result.err().unwrap(), + servers::auth::Error::UnsupportedPasswordType { .. } + ); + + // auth failed, err: user not exist. + let auth_result = user_provider + .authenticate( + Identity::UserId("not_exist_username", None), + Password::PlainText("greptime"), + ) + .await; + assert!(auth_result.is_err()); + matches!( + auth_result.err().unwrap(), + servers::auth::Error::UserNotFound { .. } + ); + + // auth failed, err: wrong password + let auth_result = user_provider + .authenticate( + Identity::UserId("greptime", None), + Password::PlainText("wrong_password"), + ) + .await; + assert!(auth_result.is_err()); + matches!( + auth_result.err().unwrap(), + servers::auth::Error::UserPasswordMismatch { .. } + ); +} + +#[tokio::test] +async fn test_schema_validate() { + let mut validator = MockUserProvider::default(); + validator.set_authorization_info(DatabaseAuthInfo { + catalog: "greptime", + schema: "public", + username: "test_user", + }); + + let right_user = UserInfo::new("test_user"); + let wrong_user = UserInfo::default(); + + // check catalog + let re = validator + .authorize("greptime_wrong", "public", &right_user) + .await; + assert!(re.is_err()); + // check schema + let re = validator + .authorize("greptime", "public_wrong", &right_user) + .await; + assert!(re.is_err()); + // check username + let re = validator.authorize("greptime", "public", &wrong_user).await; + assert!(re.is_err()); + // check ok + let re = validator.authorize("greptime", "public", &right_user).await; + assert!(re.is_ok()); +} diff --git a/src/servers/tests/http/authorize.rs b/src/servers/tests/http/authorize.rs new file mode 100644 index 0000000000..7f99d84109 --- /dev/null +++ b/src/servers/tests/http/authorize.rs @@ -0,0 +1,120 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use axum::body::BoxBody; +use axum::http; +use hyper::Request; +use servers::auth::UserProvider; +use servers::http::authorize::HttpAuth; +use session::context::UserInfo; +use tower_http::auth::AsyncAuthorizeRequest; + +use crate::auth::MockUserProvider; + +#[tokio::test] +async fn test_http_auth() { + let mut http_auth: HttpAuth = HttpAuth::new(None); + + // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" + let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); + let auth_res = http_auth.authorize(req).await.unwrap(); + let user_info: &UserInfo = auth_res.extensions().get().unwrap(); + let default = UserInfo::default(); + assert_eq!(default.username(), user_info.username()); + + // In mock user provider, right username:password == "greptime:greptime" + let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc); + let mut http_auth: HttpAuth = HttpAuth::new(mock_user_provider); + + // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU=" + let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap(); + let req = http_auth.authorize(req).await.unwrap(); + let user_info: &UserInfo = req.extensions().get().unwrap(); + let default = UserInfo::default(); + assert_eq!(default.username(), user_info.username()); + + let req = mock_http_request(None, None).unwrap(); + let auth_res = http_auth.authorize(req).await; + assert!(auth_res.is_err()); + + // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" + let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); + let auth_res = http_auth.authorize(wrong_req).await; + assert!(auth_res.is_err()); +} + +#[tokio::test] +async fn test_schema_validating() { + // In mock user provider, right username:password == "greptime:greptime" + let provider = MockUserProvider::default(); + let mock_user_provider = Some(Arc::new(provider) as Arc); + let mut http_auth: HttpAuth = HttpAuth::new(mock_user_provider); + + // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU=" + // http://localhost/{http_api_version}/sql?db=greptime + let version = servers::http::HTTP_API_VERSION; + let req = mock_http_request( + Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), + Some(format!("http://localhost/{version}/sql?db=public").as_str()), + ) + .unwrap(); + let req = http_auth.authorize(req).await.unwrap(); + let user_info: &UserInfo = req.extensions().get().unwrap(); + let default = UserInfo::default(); + assert_eq!(default.username(), user_info.username()); + + // wrong database + let req = mock_http_request( + Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), + Some(format!("http://localhost/{version}/sql?db=wrong").as_str()), + ) + .unwrap(); + let result = http_auth.authorize(req).await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_whitelist_no_auth() { + // In mock user provider, right username:password == "greptime:greptime" + let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc); + let mut http_auth: HttpAuth = HttpAuth::new(mock_user_provider); + + // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU=" + // try auth path first + let req = mock_http_request(None, None).unwrap(); + let req = http_auth.authorize(req).await; + assert!(req.is_err()); + + // try whitelist path + let req = mock_http_request(None, Some("http://localhost/health")).unwrap(); + let req = http_auth.authorize(req).await; + assert!(req.is_ok()); +} + +// copy from http::authorize +fn mock_http_request( + auth_header: Option<&str>, + uri: Option<&str>, +) -> servers::error::Result> { + let http_api_version = servers::http::HTTP_API_VERSION; + let mut req = Request::builder() + .uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql?db=public").as_str())); + if let Some(auth_header) = auth_header { + req = req.header(http::header::AUTHORIZATION, auth_header); + } + + Ok(req.body(()).unwrap()) +} diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs index 09b80ef9d6..0ff2701533 100644 --- a/src/servers/tests/http/http_handler_test.rs +++ b/src/servers/tests/http/http_handler_test.rs @@ -135,7 +135,7 @@ fn create_invalid_script_query() -> Query { fn create_query() -> Query { Query(http_handler::SqlQuery { sql: Some("select sum(uint32s) from numbers limit 20".to_string()), - database: None, + db: None, }) } diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 06e73f68bb..979b6b31ee 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -19,7 +19,6 @@ use async_trait::async_trait; use axum::{http, Router}; use axum_test_helper::TestClient; use common_query::Output; -use servers::auth::user_provider::StaticUserProvider; use servers::error::{Error, Result}; use servers::http::{HttpOptions, HttpServer}; use servers::influxdb::InfluxdbRequest; @@ -28,8 +27,10 @@ use servers::query_handler::InfluxdbLineProtocolHandler; use session::context::QueryContextRef; use tokio::sync::mpsc; +use crate::auth::{DatabaseAuthInfo, MockUserProvider}; + struct DummyInstance { - tx: mpsc::Sender<(String, String)>, + tx: Arc>, } #[async_trait] @@ -66,11 +67,18 @@ impl SqlQueryHandler for DummyInstance { } } -fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router { +fn make_test_app(tx: Arc>, db_name: Option<&str>) -> Router { let instance = Arc::new(DummyInstance { tx }); let mut server = HttpServer::new(instance.clone(), HttpOptions::default()); - let up = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap(); - server.set_user_provider(Arc::new(up)); + let mut user_provider = MockUserProvider::default(); + if let Some(name) = db_name { + user_provider.set_authorization_info(DatabaseAuthInfo { + catalog: "greptime", + schema: name, + username: "greptime", + }) + } + server.set_user_provider(Arc::new(user_provider)); server.set_influxdb_handler(instance); server.make_app() @@ -79,13 +87,14 @@ fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router { #[tokio::test] async fn test_influxdb_write() { let (tx, mut rx) = mpsc::channel(100); + let tx = Arc::new(tx); - let app = make_test_app(tx); + let app = make_test_app(tx.clone(), None); let client = TestClient::new(app); // right request let result = client - .post("/v1/influxdb/write") + .post("/v1/influxdb/write?db=public") .body("monitor,host=host1 cpu=1.2 1664370459457010101") .header( http::header::AUTHORIZATION, @@ -96,6 +105,10 @@ async fn test_influxdb_write() { assert_eq!(result.status(), 204); assert!(result.text().await.is_empty()); + // make new app for db=influxdb + let app = make_test_app(tx, Some("influxdb")); + let client = TestClient::new(app); + let result = client .post("/v1/influxdb/write?db=influxdb") .body("monitor,host=host1 cpu=1.2 1664370459457010101") @@ -110,7 +123,7 @@ async fn test_influxdb_write() { // bad request let result = client - .post("/v1/influxdb/write") + .post("/v1/influxdb/write?db=influxdb") .body("monitor, host=host1 cpu=1.2 1664370459457010101") .header( http::header::AUTHORIZATION, diff --git a/src/servers/tests/http/mod.rs b/src/servers/tests/http/mod.rs index 4dea49b9fd..5a39ed348e 100644 --- a/src/servers/tests/http/mod.rs +++ b/src/servers/tests/http/mod.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod authorize; mod http_handler_test; mod influxdb_test; mod opentsdb_test; diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index d4b265b1ea..6e8e833640 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -22,21 +22,20 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; use query::parser::QueryLanguageParser; use query::{QueryEngineFactory, QueryEngineRef}; -use servers::error::{Error, Result}; -use servers::query_handler::{ScriptHandler, ScriptHandlerRef}; -use table::test_util::MemTable; - -mod http; -mod mysql; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; +use servers::error::{Error, Result}; use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler}; +use servers::query_handler::{ScriptHandler, ScriptHandlerRef}; use session::context::QueryContextRef; +use table::test_util::MemTable; +mod auth; +mod http; mod interceptor; +mod mysql; mod opentsdb; mod postgres; - mod py_script; struct DummyInstance { diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 280d0a1dc2..0ddb426162 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -24,17 +24,21 @@ use mysql_async::prelude::*; use mysql_async::SslOpts; use rand::rngs::StdRng; use rand::Rng; -use servers::auth::user_provider::StaticUserProvider; use servers::error::Result; use servers::mysql::server::MysqlServer; use servers::server::Server; use servers::tls::TlsOption; use table::test_util::MemTable; +use crate::auth::{DatabaseAuthInfo, MockUserProvider}; use crate::create_testing_sql_query_handler; use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData}; -fn create_mysql_server(table: MemTable, tls: TlsOption) -> Result> { +fn create_mysql_server( + table: MemTable, + tls: TlsOption, + auth_info: Option, +) -> Result> { let query_handler = create_testing_sql_query_handler(table); let io_runtime = Arc::new( RuntimeBuilder::default() @@ -44,7 +48,10 @@ fn create_mysql_server(table: MemTable, tls: TlsOption) -> Result Result Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default())?; + let mysql_server = create_mysql_server(table, Default::default(), None)?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = mysql_server.start(listening).await; assert!(result.is_ok()); @@ -71,13 +78,54 @@ async fn test_start_mysql_server() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_schema_validation() -> Result<()> { + async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box, u16)> { + let table = MemTable::default_numbers_table(); + let mysql_server = create_mysql_server(table, Default::default(), Some(auth_info))?; + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = mysql_server.start(listening).await.unwrap(); + Ok((mysql_server, server_addr.port())) + } + + common_telemetry::init_default_ut_logging(); + let (mysql_server, server_port) = generate_server(DatabaseAuthInfo { + catalog: "greptime", + schema: "public", + username: "greptime", + }) + .await?; + + //TODO(shuiyisong): mysql conn without dbname rejection is not implemented yet, add test later. + + let pass = create_connection(server_port, Some("public"), false).await; + assert!(pass.is_ok()); + let result = mysql_server.shutdown().await; + assert!(result.is_ok()); + + // change to another username + let (mysql_server, server_port) = generate_server(DatabaseAuthInfo { + catalog: "greptime", + schema: "public", + username: "no_access_user", + }) + .await?; + + let fail = create_connection(server_port, Some("public"), false).await; + assert!(fail.is_err()); + let result = mysql_server.shutdown().await; + assert!(result.is_ok()); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_shutdown_mysql_server() -> Result<()> { common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default())?; + let mysql_server = create_mysql_server(table, Default::default(), None)?; let result = mysql_server.shutdown().await; assert!(result .unwrap_err() @@ -193,7 +241,7 @@ async fn test_server_required_secure_client_plain() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mysql_server = create_mysql_server(table, server_tls)?; + let mysql_server = create_mysql_server(table, server_tls, None)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); @@ -219,7 +267,7 @@ async fn test_db_name() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mysql_server = create_mysql_server(table, server_tls)?; + let mysql_server = create_mysql_server(table, server_tls, None)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); @@ -247,7 +295,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) -> let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mysql_server = create_mysql_server(table, server_tls)?; + let mysql_server = create_mysql_server(table, server_tls, None)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); @@ -282,7 +330,7 @@ async fn test_query_concurrently() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default())?; + let mysql_server = create_mysql_server(table, Default::default(), None)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); let server_port = server_addr.port(); @@ -332,7 +380,7 @@ async fn create_connection( .tcp_port(port) .prefer_socket(false) .wait_timeout(Some(1000)) - .db_name(db_name) + .db_name(db_name.or(Some(DEFAULT_SCHEMA_NAME))) .user(Some("greptime".to_string())) .pass(Some("greptime".to_string())); diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 4dac19c4da..de05fe66bf 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -22,7 +22,6 @@ use rand::rngs::StdRng; use rand::Rng; use rustls::client::{ServerCertVerified, ServerCertVerifier}; use rustls::{Certificate, Error, ServerName}; -use servers::auth::user_provider::StaticUserProvider; use servers::auth::UserProviderRef; use servers::error::Result; use servers::postgres::PostgresServer; @@ -31,12 +30,14 @@ use servers::tls::TlsOption; use table::test_util::MemTable; use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage}; +use crate::auth::{DatabaseAuthInfo, MockUserProvider}; use crate::create_testing_instance; fn create_postgres_server( table: MemTable, check_pwd: bool, tls: TlsOption, + auth_info: Option, ) -> Result> { let instance = Arc::new(create_testing_instance(table)); let io_runtime = Arc::new( @@ -47,9 +48,11 @@ fn create_postgres_server( .unwrap(), ); let user_provider: Option = if check_pwd { - Some(Arc::new( - StaticUserProvider::try_from("cmd:test_user=test_pwd").unwrap(), - )) + let mut provider = MockUserProvider::default(); + if let Some(info) = auth_info { + provider.set_authorization_info(info); + } + Some(Arc::new(provider)) } else { None }; @@ -66,7 +69,7 @@ fn create_postgres_server( pub async fn test_start_postgres_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let pg_server = create_postgres_server(table, false, Default::default())?; + let pg_server = create_postgres_server(table, false, Default::default(), None)?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = pg_server.start(listening).await; assert!(result.is_ok()); @@ -86,12 +89,52 @@ async fn test_shutdown_pg_server_range() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_schema_validating() -> Result<()> { + async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box, u16)> { + let table = MemTable::default_numbers_table(); + let postgres_server = + create_postgres_server(table, true, Default::default(), Some(auth_info))?; + let listening = "127.0.0.1:5432".parse::().unwrap(); + let server_addr = postgres_server.start(listening).await.unwrap(); + let server_port = server_addr.port(); + Ok((postgres_server, server_port)) + } + + common_telemetry::init_default_ut_logging(); + let (pg_server, server_port) = generate_server(DatabaseAuthInfo { + catalog: "greptime", + schema: "public", + username: "greptime", + }) + .await?; + + let pass = create_plain_connection(server_port, true).await; + assert!(pass.is_ok()); + let result = pg_server.shutdown().await; + assert!(result.is_ok()); + + let (pg_server, server_port) = generate_server(DatabaseAuthInfo { + catalog: "greptime", + schema: "public", + username: "no_right_user", + }) + .await?; + + let fail = create_plain_connection(server_port, true).await; + assert!(fail.is_err()); + let result = pg_server.shutdown().await; + assert!(result.is_ok()); + + Ok(()) +} + // #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); - let postgres_server = create_postgres_server(table, with_pwd, Default::default())?; + let postgres_server = create_postgres_server(table, with_pwd, Default::default(), None)?; let result = postgres_server.shutdown().await; assert!(result .unwrap_err() @@ -273,7 +316,7 @@ async fn test_using_db() -> Result<()> { async fn start_test_server(server_tls: TlsOption) -> Result { common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); - let pg_server = create_postgres_server(table, false, server_tls)?; + let pg_server = create_postgres_server(table, false, server_tls, None)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = pg_server.start(listening).await.unwrap(); Ok(server_addr.port()) @@ -301,7 +344,7 @@ async fn create_secure_connection( ) -> std::result::Result { let url = if with_pwd { format!( - "sslmode=require host=127.0.0.1 port={port} user=test_user password=test_pwd connect_timeout=2, dbname={DEFAULT_SCHEMA_NAME}", + "sslmode=require host=127.0.0.1 port={port} user=greptime password=greptime connect_timeout=2, dbname={DEFAULT_SCHEMA_NAME}", ) } else { format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}") @@ -328,7 +371,7 @@ async fn create_plain_connection( ) -> std::result::Result { let url = if with_pwd { format!( - "host=127.0.0.1 port={port} user=test_user password=test_pwd connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}", + "host=127.0.0.1 port={port} user=greptime password=greptime connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}", ) } else { format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={DEFAULT_SCHEMA_NAME}") diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 7c34810d89..7634f9b928 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -195,7 +195,7 @@ pub async fn test_sql_api(store_type: StorageType) { // test database given let res = client - .get("/v1/sql?database=public&sql=select cpu, ts from demo limit 1") + .get("/v1/sql?db=public&sql=select cpu, ts from demo limit 1") .send() .await; assert_eq!(res.status(), StatusCode::OK); @@ -214,7 +214,7 @@ pub async fn test_sql_api(store_type: StorageType) { // test database not found let res = client - .get("/v1/sql?database=notfound&sql=select cpu, ts from demo limit 1") + .get("/v1/sql?db=notfound&sql=select cpu, ts from demo limit 1") .send() .await; assert_eq!(res.status(), StatusCode::OK); @@ -223,7 +223,7 @@ pub async fn test_sql_api(store_type: StorageType) { // test catalog-schema given let res = client - .get("/v1/sql?database=greptime-public&sql=select cpu, ts from demo limit 1") + .get("/v1/sql?db=greptime-public&sql=select cpu, ts from demo limit 1") .send() .await; assert_eq!(res.status(), StatusCode::OK); @@ -242,7 +242,7 @@ pub async fn test_sql_api(store_type: StorageType) { // test invalid catalog let res = client - .get("/v1/sql?database=notfound2-schema&sql=select cpu, ts from demo limit 1") + .get("/v1/sql?db=notfound2-schema&sql=select cpu, ts from demo limit 1") .send() .await; assert_eq!(res.status(), StatusCode::OK); @@ -251,7 +251,7 @@ pub async fn test_sql_api(store_type: StorageType) { // test invalid schema let res = client - .get("/v1/sql?database=greptime-schema&sql=select cpu, ts from demo limit 1") + .get("/v1/sql?db=greptime-schema&sql=select cpu, ts from demo limit 1") .send() .await; assert_eq!(res.status(), StatusCode::OK);