refactor: auth crate (#2148)

* chore: move user_info to auth crate

* chore: temp commit before resolving tests compile error

* chore: fix compile issue

* chore: minor fix

* chore: tmp save

* chore: change user_info to trait

* chore: minor change & use auth result user info in pg session setup

* chore: add as_any to user_info

* chore: rename user_info

* chore: remove ice file

* chore: add permission checker

* chore: add grpc permission check

* chore: add session spawn user_info to query_ctx

* chore: minor update

* chore: add permission checker to sql handler & temp save

* chore: add permission checker to prometheus handler

* chore: add permission checker to opentsdb handler

* chore: add permission checker to other handlers

* chore: add test

* chore: add user_info setting on http entrance

* chore: fix toml

* chore: remove box in permission req

* chore: cr issue

* chore: cr issue
This commit is contained in:
shuiyisong
2023-08-14 10:51:26 +08:00
committed by GitHub
parent 6d64e1c296
commit 7f51141ed0
58 changed files with 690 additions and 301 deletions

22
Cargo.lock generated
View File

@@ -672,6 +672,23 @@ dependencies = [
"winapi",
]
[[package]]
name = "auth"
version = "0.3.2"
dependencies = [
"api",
"async-trait",
"common-error",
"common-test-util",
"digest",
"hex",
"secrecy",
"sha1",
"snafu",
"sql",
"tokio",
]
[[package]]
name = "auto_ops"
version = "0.3.0"
@@ -1559,6 +1576,7 @@ version = "0.3.2"
dependencies = [
"anymap",
"async-trait",
"auth",
"catalog",
"chrono",
"clap 3.2.25",
@@ -3242,6 +3260,7 @@ dependencies = [
"async-compat",
"async-stream",
"async-trait",
"auth",
"catalog",
"chrono",
"client",
@@ -8824,6 +8843,7 @@ dependencies = [
"api",
"arrow-flight",
"async-trait",
"auth",
"axum",
"axum-macros",
"axum-test-helper",
@@ -8912,6 +8932,7 @@ name = "session"
version = "0.3.2"
dependencies = [
"arc-swap",
"auth",
"common-catalog",
"common-telemetry",
"common-time",
@@ -9933,6 +9954,7 @@ version = "0.3.2"
dependencies = [
"api",
"async-trait",
"auth",
"axum",
"axum-test-helper",
"catalog",

View File

@@ -2,6 +2,7 @@
members = [
"benchmarks",
"src/api",
"src/auth",
"src/catalog",
"src/client",
"src/cmd",
@@ -102,6 +103,7 @@ metrics = "0.20"
meter-core = { git = "https://github.com/GreptimeTeam/greptime-meter.git", rev = "abbd357c1e193cd270ea65ee7652334a150b628f" }
## workspaces members
api = { path = "src/api" }
auth = { path = "src/auth" }
catalog = { path = "src/catalog" }
client = { path = "src/client" }
cmd = { path = "src/cmd" }

26
src/auth/Cargo.toml Normal file
View File

@@ -0,0 +1,26 @@
[package]
name = "auth"
version.workspace = true
edition.workspace = true
license.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = []
testing = []
[dependencies]
api.workspace = true
async-trait.workspace = true
common-error.workspace = true
digest = "0.10"
hex = { version = "0.4" }
secrecy = { version = "0.8", features = ["serde", "alloc"] }
sha1 = "0.10"
snafu.workspace = true
sql.workspace = true
tokio.workspace = true
[dev-dependencies]
common-test-util.workspace = true

68
src/auth/src/common.rs Normal file
View File

@@ -0,0 +1,68 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use secrecy::SecretString;
use snafu::OptionExt;
use crate::error::{InvalidConfigSnafu, Result};
use crate::user_info::DefaultUserInfo;
use crate::user_provider::static_user_provider::{StaticUserProvider, STATIC_USER_PROVIDER};
use crate::{UserInfoRef, UserProviderRef};
pub(crate) const DEFAULT_USERNAME: &str = "greptime";
/// construct a [`UserInfo`] impl with name
/// use default username `greptime` if None is provided
pub fn userinfo_by_name(username: Option<String>) -> UserInfoRef {
DefaultUserInfo::with_name(username.unwrap_or_else(|| DEFAULT_USERNAME.to_string()))
}
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
value: opt.to_string(),
msg: "UserProviderOption must be in format `<option>:<value>`",
})?;
match name {
STATIC_USER_PROVIDER => {
let provider =
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
Ok(provider)
}
_ => InvalidConfigSnafu {
value: name.to_string(),
msg: "Invalid UserProviderOption",
}
.fail(),
}
}
type Username<'a> = &'a str;
type HostOrIp<'a> = &'a str;
#[derive(Debug, Clone)]
pub enum Identity<'a> {
UserId(Username<'a>, Option<HostOrIp<'a>>),
}
pub type HashedPassword<'a> = &'a [u8];
pub type Salt<'a> = &'a [u8];
/// Authentication information sent by the client.
pub enum Password<'a> {
PlainText(SecretString),
MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
PgMD5(HashedPassword<'a>, Salt<'a>),
}

View File

@@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,83 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use common_error::ext::{BoxedError, ErrorExt};
use common_error::status_code::StatusCode;
use secrecy::SecretString;
use session::context::UserInfo;
use snafu::{Location, OptionExt, Snafu};
use crate::auth::user_provider::StaticUserProvider;
pub mod user_provider;
#[async_trait::async_trait]
pub trait UserProvider: Send + Sync {
fn name(&self) -> &str;
/// [`authenticate`] checks whether a user is valid and allowed to access the database.
async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo>;
/// [`authorize`] checks whether a connection request
/// from a certain user to a certain catalog/schema is legal.
/// This method should be called after [`authenticate`].
async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfo) -> Result<()>;
/// [`auth`] is a combination of [`authenticate`] and [`authorize`].
/// In most cases it's preferred for both convenience and performance.
async fn auth(
&self,
id: Identity<'_>,
password: Password<'_>,
catalog: &str,
schema: &str,
) -> Result<UserInfo> {
let user_info = self.authenticate(id, password).await?;
self.authorize(catalog, schema, &user_info).await?;
Ok(user_info)
}
}
pub type UserProviderRef = Arc<dyn UserProvider>;
type Username<'a> = &'a str;
type HostOrIp<'a> = &'a str;
#[derive(Debug, Clone)]
pub enum Identity<'a> {
UserId(Username<'a>, Option<HostOrIp<'a>>),
}
pub type HashedPassword<'a> = &'a [u8];
pub type Salt<'a> = &'a [u8];
/// Authentication information sent by the client.
pub enum Password<'a> {
PlainText(SecretString),
MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
PgMD5(HashedPassword<'a>, Salt<'a>),
}
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
value: opt.to_string(),
msg: "UserProviderOption must be in format `<option>:<value>`",
})?;
match name {
user_provider::STATIC_USER_PROVIDER => {
let provider =
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
Ok(provider)
}
_ => InvalidConfigSnafu {
value: name.to_string(),
msg: "Invalid UserProviderOption",
}
.fail(),
}
}
use snafu::{Location, Snafu};
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]

33
src/auth/src/lib.rs Normal file
View File

@@ -0,0 +1,33 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod common;
mod error;
mod permission;
mod user_info;
mod user_provider;
#[cfg(feature = "testing")]
pub mod tests;
pub use common::{user_provider_from_option, userinfo_by_name, HashedPassword, Identity, Password};
pub use error::{Error, Result};
pub use permission::{PermissionChecker, PermissionReq, PermissionResp};
pub use user_info::UserInfo;
pub use user_provider::UserProvider;
/// pub type alias
pub type UserInfoRef = std::sync::Arc<dyn UserInfo>;
pub type UserProviderRef = std::sync::Arc<dyn UserProvider>;
pub type PermissionCheckerRef = std::sync::Arc<dyn PermissionChecker>;

View File

@@ -0,0 +1,60 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Debug;
use api::v1::greptime_request::Request;
use sql::statements::statement::Statement;
use crate::error::Result;
use crate::{PermissionCheckerRef, UserInfoRef};
#[derive(Debug, Clone)]
pub enum PermissionReq<'a> {
GrpcRequest(&'a Request),
SqlStatement(&'a Statement),
PromQuery,
Opentsdb,
LineProtocol,
PromStoreWrite,
PromStoreRead,
Otlp,
}
#[derive(Debug)]
pub enum PermissionResp {
Allow,
Reject,
}
pub trait PermissionChecker: Send + Sync {
fn check_permission(
&self,
user_info: Option<UserInfoRef>,
req: PermissionReq,
) -> Result<PermissionResp>;
}
impl PermissionChecker for Option<&PermissionCheckerRef> {
fn check_permission(
&self,
user_info: Option<UserInfoRef>,
req: PermissionReq,
) -> Result<PermissionResp> {
match self {
Some(checker) => checker.check_permission(user_info, req),
None => Ok(PermissionResp::Allow),
}
}
}

View File

@@ -11,14 +11,17 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use secrecy::ExposeSecret;
use servers::auth::user_provider::auth_mysql;
use servers::auth::{
AccessDeniedSnafu, Identity, Password, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu,
UserPasswordMismatchSnafu, UserProvider,
use crate::error::{
AccessDeniedSnafu, Result, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu,
UserPasswordMismatchSnafu,
};
use session::context::UserInfo;
use crate::user_info::DefaultUserInfo;
use crate::user_provider::static_user_provider::auth_mysql;
#[allow(unused_imports)]
use crate::Error;
use crate::{Identity, Password, UserInfoRef, UserProvider};
pub struct DatabaseAuthInfo<'a> {
pub catalog: &'a str,
@@ -56,17 +59,13 @@ impl UserProvider for MockUserProvider {
"mock_user_provider"
}
async fn authenticate(
&self,
id: Identity<'_>,
password: Password<'_>,
) -> servers::auth::Result<UserInfo> {
async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef> {
match id {
Identity::UserId(username, _host) => match password {
Password::PlainText(password) => {
if username == "greptime" {
if password.expose_secret() == "greptime" {
Ok(UserInfo::new("greptime"))
Ok(DefaultUserInfo::with_name("greptime"))
} else {
UserPasswordMismatchSnafu {
username: username.to_string(),
@@ -82,7 +81,7 @@ impl UserProvider for MockUserProvider {
}
Password::MysqlNativePassword(auth_data, salt) => {
auth_mysql(auth_data, salt, username, "greptime".as_bytes())
.map(|_| UserInfo::new(username))
.map(|_| DefaultUserInfo::with_name(username))
}
_ => UnsupportedPasswordTypeSnafu {
password_type: "mysql_native_password",
@@ -92,12 +91,7 @@ impl UserProvider for MockUserProvider {
}
}
async fn authorize(
&self,
catalog: &str,
schema: &str,
user_info: &UserInfo,
) -> servers::auth::Result<()> {
async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfoRef) -> Result<()> {
if catalog == self.catalog && schema == self.schema && user_info.username() == self.username
{
Ok(())
@@ -137,7 +131,7 @@ async fn test_auth_by_plain_text() {
assert!(auth_result.is_err());
assert!(matches!(
auth_result.err().unwrap(),
servers::auth::Error::UnsupportedPasswordType { .. }
Error::UnsupportedPasswordType { .. }
));
// auth failed, err: user not exist.
@@ -150,7 +144,7 @@ async fn test_auth_by_plain_text() {
assert!(auth_result.is_err());
assert!(matches!(
auth_result.err().unwrap(),
servers::auth::Error::UserNotFound { .. }
Error::UserNotFound { .. }
));
// auth failed, err: wrong password
@@ -163,7 +157,7 @@ async fn test_auth_by_plain_text() {
assert!(auth_result.is_err());
assert!(matches!(
auth_result.err().unwrap(),
servers::auth::Error::UserPasswordMismatch { .. }
Error::UserPasswordMismatch { .. }
))
}
@@ -176,8 +170,8 @@ async fn test_schema_validate() {
username: "test_user",
});
let right_user = UserInfo::new("test_user");
let wrong_user = UserInfo::default();
let right_user = DefaultUserInfo::with_name("test_user");
let wrong_user = DefaultUserInfo::with_name("greptime");
// check catalog
let re = validator

47
src/auth/src/user_info.rs Normal file
View File

@@ -0,0 +1,47 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
use crate::UserInfoRef;
pub trait UserInfo: Debug + Sync + Send {
fn as_any(&self) -> &dyn Any;
fn username(&self) -> &str;
}
#[derive(Debug)]
pub(crate) struct DefaultUserInfo {
username: String,
}
impl DefaultUserInfo {
pub(crate) fn with_name(username: impl Into<String>) -> UserInfoRef {
Arc::new(Self {
username: username.into(),
})
}
}
impl UserInfo for DefaultUserInfo {
fn as_any(&self) -> &dyn Any {
self
}
fn username(&self) -> &str {
self.username.as_str()
}
}

View File

@@ -0,0 +1,46 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub(crate) mod static_user_provider;
use crate::common::{Identity, Password};
use crate::error::Result;
use crate::UserInfoRef;
#[async_trait::async_trait]
pub trait UserProvider: Send + Sync {
fn name(&self) -> &str;
/// [`authenticate`] checks whether a user is valid and allowed to access the database.
async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef>;
/// [`authorize`] checks whether a connection request
/// from a certain user to a certain catalog/schema is legal.
/// This method should be called after [`authenticate`].
async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfoRef) -> Result<()>;
/// [`auth`] is a combination of [`authenticate`] and [`authorize`].
/// In most cases it's preferred for both convenience and performance.
async fn auth(
&self,
id: Identity<'_>,
password: Password<'_>,
catalog: &str,
schema: &str,
) -> Result<UserInfoRef> {
let user_info = self.authenticate(id, password).await?;
self.authorize(catalog, schema, &user_info).await?;
Ok(user_info)
}
}

View File

@@ -19,20 +19,20 @@ use std::io::BufRead;
use std::path::Path;
use async_trait::async_trait;
use digest;
use digest::Digest;
use secrecy::ExposeSecret;
use session::context::UserInfo;
use sha1::Sha1;
use snafu::{ensure, OptionExt, ResultExt};
use crate::auth::{
Error, HashedPassword, Identity, IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Password,
Result, Salt, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu, UserPasswordMismatchSnafu,
UserProvider,
use crate::common::Salt;
use crate::error::{
Error, IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Result, UnsupportedPasswordTypeSnafu,
UserNotFoundSnafu, UserPasswordMismatchSnafu,
};
use crate::user_info::DefaultUserInfo;
use crate::{HashedPassword, Identity, Password, UserInfoRef, UserProvider};
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";
pub(crate) const STATIC_USER_PROVIDER: &str = "static_user_provider";
impl TryFrom<&str> for StaticUserProvider {
type Error = Error;
@@ -91,7 +91,7 @@ impl TryFrom<&str> for StaticUserProvider {
}
}
pub struct StaticUserProvider {
pub(crate) struct StaticUserProvider {
users: HashMap<String, Vec<u8>>,
}
@@ -105,7 +105,7 @@ impl UserProvider for StaticUserProvider {
&self,
input_id: Identity<'_>,
input_pwd: Password<'_>,
) -> Result<UserInfo> {
) -> Result<UserInfoRef> {
match input_id {
Identity::UserId(username, _) => {
ensure!(
@@ -127,7 +127,7 @@ impl UserProvider for StaticUserProvider {
}
);
return if save_pwd == pwd.expose_secret().as_bytes() {
Ok(UserInfo::new(username))
Ok(DefaultUserInfo::with_name(username))
} else {
UserPasswordMismatchSnafu {
username: username.to_string(),
@@ -143,7 +143,7 @@ impl UserProvider for StaticUserProvider {
}
);
auth_mysql(auth_data, salt, username, save_pwd)
.map(|_| UserInfo::new(username))
.map(|_| DefaultUserInfo::with_name(username))
}
Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu {
password_type: "pg_md5",
@@ -154,7 +154,12 @@ impl UserProvider for StaticUserProvider {
}
}
async fn authorize(&self, _catalog: &str, _schema: &str, _user_info: &UserInfo) -> Result<()> {
async fn authorize(
&self,
_catalog: &str,
_schema: &str,
_user_info: &UserInfoRef,
) -> Result<()> {
// default allow all
Ok(())
}
@@ -208,10 +213,13 @@ pub mod test {
use std::io::{LineWriter, Write};
use common_test_util::temp_dir::create_temp_dir;
use session::context::UserInfo;
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
use crate::auth::{Identity, Password, UserProvider};
use crate::user_info::DefaultUserInfo;
use crate::user_provider::static_user_provider::{
double_sha1, sha1_one, sha1_two, StaticUserProvider,
};
use crate::user_provider::{Identity, Password};
use crate::UserProvider;
#[test]
fn test_sha() {
@@ -249,9 +257,10 @@ pub mod test {
#[tokio::test]
async fn test_authorize() {
let user_info = DefaultUserInfo::with_name("root");
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
provider
.authorize("catalog", "schema", &UserInfo::new("root"))
.authorize("catalog", "schema", &user_info)
.await
.unwrap();
}

61
src/auth/tests/mod.rs Normal file
View File

@@ -0,0 +1,61 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![feature(assert_matches)]
use std::assert_matches::assert_matches;
use std::sync::Arc;
use api::v1::greptime_request::Request;
use auth::Error::InternalState;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq, PermissionResp, UserInfoRef};
use sql::statements::show::{ShowDatabases, ShowKind};
use sql::statements::statement::Statement;
struct DummyPermissionChecker;
impl PermissionChecker for DummyPermissionChecker {
fn check_permission(
&self,
_user_info: Option<UserInfoRef>,
req: PermissionReq,
) -> auth::Result<PermissionResp> {
match req {
PermissionReq::GrpcRequest(_) => Ok(PermissionResp::Allow),
PermissionReq::SqlStatement(_) => Ok(PermissionResp::Reject),
_ => Err(InternalState {
msg: "testing".to_string(),
}),
}
}
}
#[test]
fn test_permission_checker() {
let checker: PermissionCheckerRef = Arc::new(DummyPermissionChecker);
let grpc_result = checker.check_permission(
None,
PermissionReq::GrpcRequest(&Request::Query(Default::default())),
);
assert_matches!(grpc_result, Ok(PermissionResp::Allow));
let sql_result = checker.check_permission(
None,
PermissionReq::SqlStatement(&Statement::ShowDatabases(ShowDatabases::new(ShowKind::All))),
);
assert_matches!(sql_result, Ok(PermissionResp::Reject));
let err_result = checker.check_permission(None, PermissionReq::Opentsdb);
assert_matches!(err_result, Err(InternalState { msg }) if msg == "testing");
}

View File

@@ -17,6 +17,7 @@ metrics-process = ["servers/metrics-process"]
[dependencies]
anymap = "1.0.0-beta.2"
async-trait.workspace = true
auth.workspace = true
catalog = { workspace = true }
chrono.workspace = true
clap = { version = "3.1", features = ["derive"] }

View File

@@ -80,7 +80,7 @@ pub enum Error {
#[snafu(display("Illegal auth config: {}", source))]
IllegalAuthConfig {
location: Location,
source: servers::auth::Error,
source: auth::Error,
},
#[snafu(display("Unsupported selector type, {} source: {}", selector_type, source))]

View File

@@ -14,6 +14,7 @@
use std::sync::Arc;
use auth::UserProviderRef;
use clap::Parser;
use common_base::Plugins;
use common_telemetry::logging;
@@ -21,9 +22,8 @@ use frontend::frontend::FrontendOptions;
use frontend::instance::{FrontendInstance, Instance as FeInstance};
use frontend::service_config::{InfluxdbOptions, PrometheusOptions};
use meta_client::MetaClientOptions;
use servers::auth::UserProviderRef;
use servers::tls::{TlsMode, TlsOption};
use servers::{auth, Mode};
use servers::Mode;
use snafu::ResultExt;
use crate::error::{self, IllegalAuthConfigSnafu, Result, StartCatalogManagerSnafu};
@@ -236,10 +236,10 @@ mod tests {
use std::io::Write;
use std::time::Duration;
use auth::{Identity, Password, UserProviderRef};
use common_base::readable_size::ReadableSize;
use common_test_util::temp_dir::create_named_temp_file;
use frontend::service_config::GrpcOptions;
use servers::auth::{Identity, Password, UserProviderRef};
use super::*;
use crate::options::ENV_VAR_SEP;

View File

@@ -345,9 +345,9 @@ mod tests {
use std::io::Write;
use std::time::Duration;
use auth::{Identity, Password, UserProviderRef};
use common_base::readable_size::ReadableSize;
use common_test_util::temp_dir::create_named_temp_file;
use servers::auth::{Identity, Password, UserProviderRef};
use servers::Mode;
use super::*;

View File

@@ -14,6 +14,7 @@ api = { workspace = true }
async-compat = "0.2"
async-stream.workspace = true
async-trait = "0.1"
auth.workspace = true
catalog = { workspace = true }
chrono.workspace = true
client = { workspace = true }

View File

@@ -578,6 +578,12 @@ pub enum Error {
source: common_meta::error::Error,
location: Location,
},
#[snafu(display("Failed to pass permission check, source: {}", source))]
Permission {
source: auth::Error,
location: Location,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -603,6 +609,8 @@ impl ErrorExt for Error {
Error::NotSupported { .. } => StatusCode::Unsupported,
Error::Permission { source, .. } => source.status_code(),
Error::HandleHeartbeatResponse { source, .. }
| Error::TableMetadataManager { source, .. } => source.status_code(),

View File

@@ -31,6 +31,7 @@ use api::v1::greptime_request::Request;
use api::v1::meta::Role;
use api::v1::{AddColumns, AlterExpr, Column, DdlRequest, InsertRequest, InsertRequests};
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use catalog::remote::CachedMetaKvBackend;
use catalog::CatalogManagerRef;
use client::client_manager::DatanodeClients;
@@ -57,7 +58,7 @@ use query::query_engine::options::{validate_catalog_and_schema, QueryOptions};
use query::query_engine::DescribeResult;
use query::{QueryEngineFactory, QueryEngineRef};
use servers::error as server_error;
use servers::error::{ExecuteQuerySnafu, ParsePromQLSnafu};
use servers::error::{AuthSnafu, ExecuteQuerySnafu, ParsePromQLSnafu};
use servers::interceptor::{
PromQueryInterceptor, PromQueryInterceptorRef, SqlQueryInterceptor, SqlQueryInterceptorRef,
};
@@ -79,8 +80,8 @@ use sqlparser::ast::ObjectName;
use crate::catalog::FrontendCatalogManager;
use crate::error::{
self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExternalSnafu,
InvalidInsertRequestSnafu, MissingMetasrvOptsSnafu, ParseSqlSnafu, PlanStatementSnafu, Result,
SqlExecInterceptedSnafu,
InvalidInsertRequestSnafu, MissingMetasrvOptsSnafu, ParseSqlSnafu, PermissionSnafu,
PlanStatementSnafu, Result, SqlExecInterceptedSnafu,
};
use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory};
use crate::frontend::FrontendOptions;
@@ -488,6 +489,9 @@ impl SqlQueryHandler for Instance {
Err(e) => return vec![Err(e)],
};
let checker_ref = self.plugins.get::<PermissionCheckerRef>();
let checker = checker_ref.as_ref();
match parse_stmt(query.as_ref(), query_ctx.sql_dialect())
.and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
{
@@ -501,6 +505,18 @@ impl SqlQueryHandler for Instance {
results.push(Err(e));
break;
}
if let Err(e) = checker
.check_permission(
query_ctx.current_user(),
PermissionReq::SqlStatement(&stmt),
)
.context(PermissionSnafu)
{
results.push(Err(e));
break;
}
match self.query_statement(stmt, query_ctx.clone()).await {
Ok(output) => {
let output_result =
@@ -523,6 +539,8 @@ impl SqlQueryHandler for Instance {
async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
let _timer = timer!(metrics::METRIC_EXEC_PLAN_ELAPSED);
// plan should be prepared before exec
// we'll do check there
self.query_engine
.execute(plan, query_ctx)
.await
@@ -534,6 +552,7 @@ impl SqlQueryHandler for Instance {
query: &PromQuery,
query_ctx: QueryContextRef,
) -> Vec<Result<Output>> {
// check will be done in prometheus handler's do_query
let result = PrometheusHandler::do_query(self, query, query_ctx)
.await
.with_context(|_| ExecutePromqlSnafu {
@@ -551,6 +570,12 @@ impl SqlQueryHandler for Instance {
stmt,
Statement::Insert(_) | Statement::Query(_) | Statement::Delete(_)
) {
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(query_ctx.current_user(), PermissionReq::SqlStatement(&stmt))
.context(PermissionSnafu)?;
let plan = self
.query_engine
.planner()
@@ -588,6 +613,12 @@ impl PrometheusHandler for Instance {
.get::<PromQueryInterceptorRef<server_error::Error>>();
interceptor.pre_execute(query, query_ctx.clone())?;
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(query_ctx.current_user(), PermissionReq::PromQuery)
.context(AuthSnafu)?;
let stmt = QueryLanguageParser::parse_promql(query).with_context(|_| ParsePromQLSnafu {
query: query.clone(),
})?;

View File

@@ -15,15 +15,16 @@
use api::v1::greptime_request::Request;
use api::v1::query_request::Query;
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use common_query::Output;
use query::parser::PromQuery;
use servers::interceptor::{GrpcQueryInterceptor, GrpcQueryInterceptorRef};
use servers::query_handler::grpc::GrpcQueryHandler;
use servers::query_handler::sql::SqlQueryHandler;
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt};
use snafu::{ensure, OptionExt, ResultExt};
use crate::error::{Error, IncompleteGrpcResultSnafu, NotSupportedSnafu, Result};
use crate::error::{Error, IncompleteGrpcResultSnafu, NotSupportedSnafu, PermissionSnafu, Result};
use crate::instance::Instance;
#[async_trait]
@@ -35,6 +36,12 @@ impl GrpcQueryHandler for Instance {
let interceptor = interceptor_ref.as_ref();
interceptor.pre_execute(&request, ctx.clone())?;
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(ctx.current_user(), PermissionReq::GrpcRequest(&request))
.context(PermissionSnafu)?;
let output = match request {
Request::Inserts(requests) => self.handle_inserts(requests, ctx.clone()).await?,
Request::Query(query_request) => {

View File

@@ -13,7 +13,9 @@
// limitations under the License.
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use common_error::ext::BoxedError;
use servers::error::AuthSnafu;
use servers::influxdb::InfluxdbRequest;
use servers::query_handler::InfluxdbLineProtocolHandler;
use session::context::QueryContextRef;
@@ -28,6 +30,12 @@ impl InfluxdbLineProtocolHandler for Instance {
request: &InfluxdbRequest,
ctx: QueryContextRef,
) -> servers::error::Result<()> {
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(ctx.current_user(), PermissionReq::LineProtocol)
.context(AuthSnafu)?;
let requests = request.try_into()?;
let _ = self
.handle_inserts(requests, ctx)

View File

@@ -14,8 +14,10 @@
use api::v1::InsertRequests;
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use common_error::ext::BoxedError;
use servers::error as server_error;
use servers::error::AuthSnafu;
use servers::opentsdb::codec::DataPoint;
use servers::query_handler::OpentsdbProtocolHandler;
use session::context::QueryContextRef;
@@ -26,6 +28,12 @@ use crate::instance::Instance;
#[async_trait]
impl OpentsdbProtocolHandler for Instance {
async fn exec(&self, data_point: &DataPoint, ctx: QueryContextRef) -> server_error::Result<()> {
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(ctx.current_user(), PermissionReq::Opentsdb)
.context(AuthSnafu)?;
let requests = InsertRequests {
inserts: vec![data_point.as_grpc_insert()],
};

View File

@@ -13,12 +13,13 @@
// limitations under the License.
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use common_error::ext::BoxedError;
use metrics::counter;
use opentelemetry_proto::tonic::collector::metrics::v1::{
ExportMetricsServiceRequest, ExportMetricsServiceResponse,
};
use servers::error::{self, Result as ServerResult};
use servers::error::{self, AuthSnafu, Result as ServerResult};
use servers::otlp;
use servers::query_handler::OpenTelemetryProtocolHandler;
use session::context::QueryContextRef;
@@ -34,6 +35,11 @@ impl OpenTelemetryProtocolHandler for Instance {
request: ExportMetricsServiceRequest,
ctx: QueryContextRef,
) -> ServerResult<ExportMetricsServiceResponse> {
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(ctx.current_user(), PermissionReq::Otlp)
.context(AuthSnafu)?;
let (requests, rows) = otlp::to_grpc_insert_requests(request)?;
let _ = self
.handle_inserts(requests, ctx)

View File

@@ -15,6 +15,7 @@
use api::prom_store::remote::read_request::ResponseType;
use api::prom_store::remote::{Query, QueryResult, ReadRequest, ReadResponse, WriteRequest};
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use common_catalog::format_full_table_name;
use common_error::ext::BoxedError;
use common_query::Output;
@@ -22,7 +23,7 @@ use common_recordbatch::RecordBatches;
use common_telemetry::logging;
use metrics::counter;
use prost::Message;
use servers::error::{self, Result as ServerResult};
use servers::error::{self, AuthSnafu, Result as ServerResult};
use servers::prom_store::{self, Metrics};
use servers::query_handler::{PromStoreProtocolHandler, PromStoreResponse};
use session::context::QueryContextRef;
@@ -148,6 +149,11 @@ impl Instance {
#[async_trait]
impl PromStoreProtocolHandler for Instance {
async fn write(&self, request: WriteRequest, ctx: QueryContextRef) -> ServerResult<()> {
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(ctx.current_user(), PermissionReq::PromStoreWrite)
.context(AuthSnafu)?;
let (requests, samples) = prom_store::to_grpc_insert_requests(request)?;
let _ = self
.handle_inserts(requests, ctx)
@@ -164,6 +170,12 @@ impl PromStoreProtocolHandler for Instance {
request: ReadRequest,
ctx: QueryContextRef,
) -> ServerResult<PromStoreResponse> {
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(ctx.current_user(), PermissionReq::PromStoreRead)
.context(AuthSnafu)?;
let response_type = negotiate_response_type(&request.accepted_response_types)?;
// TODO(dennis): use read_hints to speedup query if possible

View File

@@ -16,10 +16,10 @@ use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use auth::UserProviderRef;
use common_base::Plugins;
use common_runtime::Builder as RuntimeBuilder;
use common_telemetry::info;
use servers::auth::UserProviderRef;
use servers::configurator::ConfiguratorRef;
use servers::error::Error::InternalIo;
use servers::grpc::GrpcServer;

View File

@@ -11,51 +11,51 @@ arrow-schema.workspace = true
async-recursion = "1.0"
async-stream.workspace = true
async-trait = "0.1"
catalog = { workspace = true }
catalog.workspace = true
chrono.workspace = true
client = { workspace = true }
common-base = { workspace = true }
common-catalog = { workspace = true }
common-datasource = { workspace = true }
common-error = { workspace = true }
common-function = { workspace = true }
common-meta = { workspace = true }
common-query = { workspace = true }
common-recordbatch = { workspace = true }
common-telemetry = { workspace = true }
common-time = { workspace = true }
client.workspace = true
common-base.workspace = true
common-catalog.workspace = true
common-datasource.workspace = true
common-error.workspace = true
common-function.workspace = true
common-meta.workspace = true
common-query.workspace = true
common-recordbatch.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-optimizer.workspace = true
datafusion-physical-expr.workspace = true
datafusion-sql.workspace = true
datafusion.workspace = true
datatypes = { workspace = true }
datatypes.workspace = true
futures = "0.3"
futures-util.workspace = true
greptime-proto.workspace = true
humantime = "2.1"
metrics.workspace = true
object-store = { workspace = true }
object-store.workspace = true
once_cell.workspace = true
partition = { workspace = true }
promql = { workspace = true }
partition.workspace = true
promql-parser = "0.1.1"
promql.workspace = true
regex.workspace = true
serde.workspace = true
serde_json = "1.0"
session = { workspace = true }
session.workspace = true
snafu = { version = "0.7", features = ["backtraces"] }
sql = { workspace = true }
substrait = { workspace = true }
table = { workspace = true }
sql.workspace = true
substrait.workspace = true
table.workspace = true
tokio.workspace = true
[dev-dependencies]
approx_eq = "0.1"
arrow.workspace = true
catalog = { workspace = true, features = ["testing"] }
common-function-macro = { workspace = true }
common-function-macro.workspace = true
format_num = "0.1"
num = "0.4"
num-traits = "0.2"
@@ -64,7 +64,7 @@ rand.workspace = true
session = { workspace = true, features = ["testing"] }
statrs = "0.16"
stats-cli = "3.0"
store-api = { workspace = true }
store-api.workspace = true
streaming-stats = "0.2"
table = { workspace = true, features = ["testing"] }
tokio-stream = "0.1"

View File

@@ -14,6 +14,7 @@ aide = { version = "0.9", features = ["axum"] }
api = { workspace = true }
arrow-flight.workspace = true
async-trait = "0.1"
auth.workspace = true
axum = { version = "0.6", features = ["headers"] }
axum-macros = "0.3.8"
base64 = "0.13"
@@ -34,7 +35,6 @@ common-time = { workspace = true }
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion.workspace = true
datatypes = { workspace = true }
derive_builder.workspace = true
digest = "0.10"
@@ -79,7 +79,7 @@ serde.workspace = true
serde_json = "1.0"
session = { workspace = true }
sha1 = "0.10"
snafu = { version = "0.7", features = ["backtraces"] }
snafu.workspace = true
snap = "1"
sql = { workspace = true }
strum = { version = "0.24", features = ["derive"] }
@@ -96,6 +96,7 @@ tower-http = { version = "0.3", features = ["full"] }
tikv-jemalloc-ctl = { version = "0.5", features = ["use_std"] }
[dev-dependencies]
auth = { workspace = true, features = ["testing"] }
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
catalog = { workspace = true, features = ["testing"] }
client = { workspace = true }

View File

@@ -32,8 +32,6 @@ use tonic::codegen::http::{HeaderMap, HeaderValue};
use tonic::metadata::MetadataMap;
use tonic::Code;
use crate::auth;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
@@ -115,6 +113,9 @@ pub enum Error {
#[snafu(display("Not supported: {}", feat))]
NotSupported { feat: String },
#[snafu(display("Invalid request parameter: {}", reason))]
InvalidParameter { reason: String, location: Location },
#[snafu(display("Invalid query: {}", reason))]
InvalidQuery { reason: String, location: Location },
@@ -359,6 +360,7 @@ impl ErrorExt for Error {
| CheckDatabaseValidity { source, .. } => source.status_code(),
NotSupported { .. }
| InvalidParameter { .. }
| InvalidQuery { .. }
| InfluxdbLineProtocol { .. }
| ConnResetByPeer { .. }

View File

@@ -26,6 +26,7 @@ use api::v1::prometheus_gateway_server::{PrometheusGateway, PrometheusGatewaySer
use api::v1::{HealthCheckRequest, HealthCheckResponse};
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
use async_trait::async_trait;
use auth::UserProviderRef;
use common_runtime::Runtime;
use common_telemetry::logging::info;
use common_telemetry::{error, warn};
@@ -38,7 +39,6 @@ use tokio_stream::wrappers::TcpListenerStream;
use tonic::{Request, Response, Status};
use self::prom_query_gateway::PrometheusGatewayService;
use crate::auth::UserProviderRef;
use crate::error::{
AlreadyStartedSnafu, GrpcReflectionServiceSnafu, InternalSnafu, Result, StartGrpcSnafu,
TcpBindSnafu,

View File

@@ -18,6 +18,7 @@ use std::time::Instant;
use api::helper::request_type;
use api::v1::auth_header::AuthScheme;
use api::v1::{Basic, GreptimeRequest, RequestHeader};
use auth::{Identity, Password, UserInfoRef, UserProviderRef};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_error::ext::ErrorExt;
@@ -29,7 +30,6 @@ use metrics::{histogram, increment_counter};
use session::context::{QueryContextBuilder, QueryContextRef};
use snafu::{OptionExt, ResultExt};
use crate::auth::{Identity, Password, UserProviderRef};
use crate::error::Error::UnsupportedAuthScheme;
use crate::error::{
AuthSnafu, InvalidQuerySnafu, JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result as InternalResult,
@@ -71,9 +71,12 @@ impl GreptimeRequestHandler {
let header = request.header.as_ref();
let query_ctx = create_query_context(header);
if let Err(e) = self.auth(header, &query_ctx).await? {
return Ok(Err(e));
}
match self.auth(header, &query_ctx).await? {
Err(e) => return Ok(Err(e)),
Ok(user_info) => {
query_ctx.set_current_user(user_info);
}
};
let handler = self.handler.clone();
let request_type = request_type(&query);
@@ -110,9 +113,9 @@ impl GreptimeRequestHandler {
&self,
header: Option<&RequestHeader>,
query_ctx: &QueryContextRef,
) -> TonicResult<InternalResult<()>> {
) -> TonicResult<InternalResult<Option<UserInfoRef>>> {
let Some(user_provider) = self.user_provider.as_ref() else {
return Ok(Ok(()));
return Ok(Ok(None));
};
let auth_scheme = header
@@ -138,7 +141,7 @@ impl GreptimeRequestHandler {
name: "Token AuthScheme".to_string(),
}),
}
.map(|_| ())
.map(Some)
.map_err(|e| {
increment_counter!(
METRIC_AUTH_FAILURE,

View File

@@ -34,6 +34,7 @@ use std::time::{Duration, Instant};
use aide::axum::{routing as apirouting, ApiRouter, IntoApiResponse};
use aide::openapi::{Info, OpenApi, Server as OpenAPIServer};
use async_trait::async_trait;
use auth::UserProviderRef;
use axum::body::BoxBody;
use axum::error_handling::HandleErrorLayer;
use axum::extract::{DefaultBodyLimit, MatchedPath};
@@ -65,7 +66,6 @@ use tower_http::trace::TraceLayer;
use self::authorize::HttpAuth;
use self::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
use crate::auth::UserProviderRef;
use crate::configurator::ConfiguratorRef;
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
use crate::http::admin::{compact, flush};

View File

@@ -14,6 +14,7 @@
use std::marker::PhantomData;
use ::auth::UserProviderRef;
use axum::http::{self, Request, StatusCode};
use axum::response::Response;
use common_catalog::parse_catalog_and_schema_from_db_string;
@@ -24,17 +25,14 @@ use headers::Header;
use http_body::Body;
use metrics::increment_counter;
use secrecy::SecretString;
use session::context::UserInfo;
use snafu::{ensure, IntoError, OptionExt, ResultExt};
use snafu::{ensure, OptionExt, ResultExt};
use tower_http::auth::AsyncAuthorizeRequest;
use super::header::GreptimeDbName;
use super::PUBLIC_APIS;
use crate::auth::Error::IllegalParam;
use crate::auth::{Identity, IllegalParamSnafu, UserProviderRef};
use crate::error::{
self, AuthSnafu, InvalidAuthorizationHeaderSnafu, InvisibleASCIISnafu, NotFoundInfluxAuthSnafu,
Result, UnsupportedAuthSchemeSnafu,
self, InvalidAuthorizationHeaderSnafu, InvalidParameterSnafu, InvisibleASCIISnafu,
NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu,
};
use crate::http::HTTP_API_PREFIX;
@@ -78,7 +76,9 @@ where
let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
user_provider
} else {
let _ = request.extensions_mut().insert(UserInfo::default());
let _ = request
.extensions_mut()
.insert(auth::userinfo_by_name(None));
return Ok(request);
};
@@ -114,8 +114,8 @@ where
match user_provider
.auth(
Identity::UserId(username.as_str(), None),
crate::auth::Password::PlainText(password),
::auth::Identity::UserId(username.as_str(), None),
::auth::Password::PlainText(password),
catalog,
schema,
)
@@ -143,7 +143,7 @@ where
fn extract_catalog_and_schema<B: Send + Sync + 'static>(
request: &Request<B>,
) -> crate::auth::Result<(&str, &str)> {
) -> Result<(&str, &str)> {
// parse database from header
let dbname = request
.headers()
@@ -154,8 +154,8 @@ fn extract_catalog_and_schema<B: Send + Sync + 'static>(
let query = request.uri().query().unwrap_or_default();
extract_db_from_query(query)
})
.context(IllegalParamSnafu {
msg: "db not provided or corrupted",
.context(InvalidParameterSnafu {
reason: "`db` must be provided in query string",
})?;
Ok(parse_catalog_and_schema_from_db_string(dbname))
@@ -193,9 +193,11 @@ fn get_influxdb_credentials<B: Send + Sync + 'static>(
(Some(username), Some(password)) => {
Ok(Some((username.to_string(), password.to_string().into())))
}
_ => Err(AuthSnafu.into_error(IllegalParam {
msg: "influxdb auth: username and password must be provided together".to_string(),
})),
_ => InvalidParameterSnafu {
reason: "influxdb auth: username and password must be provided together"
.to_string(),
}
.fail(),
}
}
}

View File

@@ -17,6 +17,7 @@ use std::env;
use std::time::Instant;
use aide::transform::TransformOperation;
use auth::UserInfoRef;
use axum::extract::{Json, Query, State};
use axum::response::{IntoResponse, Response};
use axum::{Extension, Form};
@@ -25,7 +26,6 @@ use common_telemetry::timer;
use query::parser::PromQuery;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::UserInfo;
use crate::http::{ApiState, GreptimeOptionsConfigState, JsonResponse};
use crate::metrics_handler::MetricsHandler;
@@ -42,7 +42,7 @@ pub async fn sql(
State(state): State<ApiState>,
Query(query_params): Query<SqlQuery>,
// TODO(fys): pass _user_info into query context
_user_info: Extension<UserInfo>,
user_info: Extension<UserInfoRef>,
Form(form_params): Form<SqlQuery>,
) -> Json<JsonResponse> {
let sql_handler = &state.sql_handler;
@@ -61,6 +61,7 @@ pub async fn sql(
let resp = if let Some(sql) = &sql {
match crate::http::query_context_from_db(sql_handler.clone(), db).await {
Ok(query_ctx) => {
query_ctx.set_current_user(Some(user_info.0));
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
}
Err(resp) => resp,
@@ -101,7 +102,7 @@ pub async fn promql(
State(state): State<ApiState>,
Query(params): Query<PromqlQuery>,
// TODO(fys): pass _user_info into query context
_user_info: Extension<UserInfo>,
user_info: Extension<UserInfoRef>,
) -> Json<JsonResponse> {
let sql_handler = &state.sql_handler;
let exec_start = Instant::now();
@@ -117,6 +118,7 @@ pub async fn promql(
let prom_query = params.into();
let resp = match super::query_context_from_db(sql_handler.clone(), db).await {
Ok(query_ctx) => {
query_ctx.set_current_user(Some(user_info.0));
JsonResponse::from_output(sql_handler.do_promql_query(&prom_query, query_ctx).await)
.await
}

View File

@@ -14,9 +14,11 @@
use std::collections::HashMap;
use auth::UserInfoRef;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Extension;
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_grpc::writer::Precision;
@@ -43,6 +45,7 @@ pub async fn influxdb_health() -> Result<impl IntoResponse> {
pub async fn influxdb_write_v1(
State(handler): State<InfluxdbLineProtocolHandlerRef>,
Query(mut params): Query<HashMap<String, String>>,
user_info: Extension<UserInfoRef>,
lines: String,
) -> Result<impl IntoResponse> {
let db = params
@@ -54,13 +57,14 @@ pub async fn influxdb_write_v1(
.map(|val| parse_time_precision(val))
.transpose()?;
influxdb_write(&db, precision, lines, handler).await
influxdb_write(&db, precision, lines, handler, user_info.0).await
}
#[axum_macros::debug_handler]
pub async fn influxdb_write_v2(
State(handler): State<InfluxdbLineProtocolHandlerRef>,
Query(mut params): Query<HashMap<String, String>>,
user_info: Extension<UserInfoRef>,
lines: String,
) -> Result<impl IntoResponse> {
let db = params
@@ -72,7 +76,7 @@ pub async fn influxdb_write_v2(
.map(|val| parse_time_precision(val))
.transpose()?;
influxdb_write(&db, precision, lines, handler).await
influxdb_write(&db, precision, lines, handler, user_info.0).await
}
pub async fn influxdb_write(
@@ -80,6 +84,7 @@ pub async fn influxdb_write(
precision: Option<Precision>,
lines: String,
handler: InfluxdbLineProtocolHandlerRef,
user_info: UserInfoRef,
) -> Result<impl IntoResponse> {
let _timer = timer!(
crate::metrics::METRIC_HTTP_INFLUXDB_WRITE_ELAPSED,
@@ -88,6 +93,7 @@ pub async fn influxdb_write(
let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
let ctx = QueryContext::with(catalog, schema);
ctx.set_current_user(Some(user_info));
let request = InfluxdbRequest { precision, lines };

View File

@@ -14,9 +14,10 @@
use std::collections::HashMap;
use auth::UserInfoRef;
use axum::extract::{Query, RawBody, State};
use axum::http::StatusCode as HttpStatusCode;
use axum::Json;
use axum::{Extension, Json};
use hyper::Body;
use serde::{Deserialize, Serialize};
use session::context::QueryContext;
@@ -77,12 +78,14 @@ pub enum OpentsdbPutResponse {
pub async fn put(
State(opentsdb_handler): State<OpentsdbProtocolHandlerRef>,
Query(params): Query<HashMap<String, String>>,
user_info: Extension<UserInfoRef>,
RawBody(body): RawBody,
) -> Result<(HttpStatusCode, Json<OpentsdbPutResponse>)> {
let summary = params.contains_key("summary");
let details = params.contains_key("details");
let ctx = QueryContext::with_db_name(params.get("db"));
ctx.set_current_user(Some(user_info.0));
let data_points = parse_data_points(body).await?;

View File

@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use auth::UserInfoRef;
use axum::extract::{RawBody, State};
use axum::http::header;
use axum::response::IntoResponse;
use axum::TypedHeader;
use axum::{Extension, TypedHeader};
use common_telemetry::timer;
use hyper::Body;
use opentelemetry_proto::tonic::collector::metrics::v1::{
@@ -33,9 +34,11 @@ use crate::query_handler::OpenTelemetryProtocolHandlerRef;
pub async fn metrics(
State(handler): State<OpenTelemetryProtocolHandlerRef>,
TypedHeader(db): TypedHeader<GreptimeDbName>,
user_info: Extension<UserInfoRef>,
RawBody(body): RawBody,
) -> Result<OtlpResponse> {
let ctx = QueryContext::with_db_name(db.value());
ctx.set_current_user(Some(user_info.0));
let _timer = timer!(
crate::metrics::METRIC_HTTP_OPENTELEMETRY_ELAPSED,
&[(crate::metrics::METRIC_DB_LABEL, ctx.get_db_string())]

View File

@@ -20,7 +20,6 @@ use datatypes::schema::Schema;
use query::plan::LogicalPlan;
use serde::{Deserialize, Serialize};
pub mod auth;
pub mod configurator;
pub mod error;
pub mod grpc;

View File

@@ -17,6 +17,7 @@ use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use ::auth::{Identity, Password, UserProviderRef};
use async_trait::async_trait;
use chrono::{NaiveDate, NaiveDateTime};
use common_catalog::parse_catalog_and_schema_from_db_string;
@@ -41,7 +42,6 @@ use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use tokio::io::AsyncWrite;
use crate::auth::{Identity, Password, UserProviderRef};
use crate::error::{self, InvalidPrepareStatementSnafu, Result};
use crate::mysql::helper::{
self, format_placeholder, replace_placeholders, transform_placeholders,
@@ -197,7 +197,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
}
};
}
let user_info = user_info.unwrap_or_default();
let user_info = user_info.unwrap_or_else(|| auth::userinfo_by_name(None));
self.session.set_user_info(user_info);

View File

@@ -17,6 +17,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use auth::UserProviderRef;
use common_runtime::Runtime;
use common_telemetry::logging::{info, warn};
use futures::StreamExt;
@@ -29,7 +30,6 @@ use tokio::io::BufWriter;
use tokio::net::TcpStream;
use tokio_rustls::rustls::ServerConfig;
use crate::auth::UserProviderRef;
use crate::error::{Error, Result};
use crate::mysql::handler::MysqlInstanceShim;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;

View File

@@ -28,6 +28,7 @@ use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use ::auth::UserProviderRef;
use derive_builder::Builder;
use pgwire::api::auth::ServerParameterProvider;
use pgwire::api::store::MemPortalStore;
@@ -38,7 +39,6 @@ use session::Session;
use self::auth_handler::PgLoginVerifier;
use self::handler::DefaultQueryParser;
use crate::auth::UserProviderRef;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::SqlPlan;

View File

@@ -15,6 +15,7 @@
use std::fmt::Debug;
use std::sync::Exclusive;
use ::auth::{userinfo_by_name, Identity, Password, UserInfoRef, UserProviderRef};
use async_trait::async_trait;
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_error::ext::ErrorExt;
@@ -26,12 +27,10 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use session::context::UserInfo;
use session::Session;
use snafu::IntoError;
use super::PostgresServerHandler;
use crate::auth::{Identity, Password, UserProviderRef};
use crate::error::{AuthSnafu, Result};
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
@@ -74,26 +73,26 @@ impl LoginInfo {
}
impl PgLoginVerifier {
async fn auth(&self, login: &LoginInfo, password: &str) -> Result<bool> {
async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
let user_provider = match &self.user_provider {
Some(provider) => provider,
None => return Ok(false),
None => return Ok(None),
};
let user_name = match &login.user {
Some(name) => name,
None => return Ok(false),
None => return Ok(None),
};
let catalog = match &login.catalog {
Some(name) => name,
None => return Ok(false),
None => return Ok(None),
};
let schema = match &login.schema {
Some(name) => name,
None => return Ok(false),
None => return Ok(None),
};
if let Err(e) = user_provider
match user_provider
.auth(
Identity::UserId(user_name, None),
Password::PlainText(password.to_string().into()),
@@ -102,16 +101,17 @@ impl PgLoginVerifier {
)
.await
{
increment_counter!(
crate::metrics::METRIC_AUTH_FAILURE,
&[(
crate::metrics::METRIC_CODE_LABEL,
format!("{}", e.status_code())
)]
);
Err(AuthSnafu.into_error(e))
} else {
Ok(true)
Err(e) => {
increment_counter!(
crate::metrics::METRIC_AUTH_FAILURE,
&[(
crate::metrics::METRIC_CODE_LABEL,
format!("{}", e.status_code())
)]
);
Err(AuthSnafu.into_error(e))
}
Ok(user_info) => Ok(Some(user_info)),
}
}
}
@@ -126,9 +126,7 @@ where
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
session.set_schema(current_schema.clone());
}
if let Some(username) = client.metadata().get(super::METADATA_USER) {
session.set_user_info(UserInfo::new(username));
}
// set userinfo outside
}
#[async_trait]
@@ -174,6 +172,9 @@ impl StartupHandler for PostgresServerHandler {
))
.await?;
} else {
self.session.set_user_info(userinfo_by_name(
client.metadata().get(super::METADATA_USER).cloned(),
));
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
@@ -188,7 +189,12 @@ impl StartupHandler for PostgresServerHandler {
// do authenticate
let auth_result = self.login_verifier.auth(&login_info, pwd.password()).await;
if !matches!(auth_result, Ok(true)) {
if let Ok(Some(user_info)) = auth_result {
self.session.set_user_info(user_info);
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
} else {
return send_error(
client,
"FATAL",
@@ -197,8 +203,6 @@ impl StartupHandler for PostgresServerHandler {
)
.await;
}
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
_ => {}
}

View File

@@ -16,6 +16,7 @@ use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use ::auth::UserProviderRef;
use async_trait::async_trait;
use common_runtime::Runtime;
use common_telemetry::logging::error;
@@ -26,7 +27,6 @@ use pgwire::tokio::process_socket;
use tokio_rustls::TlsAcceptor;
use super::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
use crate::auth::UserProviderRef;
use crate::error::Result;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};

View File

@@ -17,6 +17,7 @@ use std::collections::{BTreeMap, HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use ::auth::UserProviderRef;
use async_trait::async_trait;
use axum::body::BoxBody;
use axum::extract::{Path, Query, State};
@@ -52,7 +53,6 @@ use tower_http::auth::AsyncRequireAuthorizationLayer;
use tower_http::compression::CompressionLayer;
use tower_http::trace::TraceLayer;
use crate::auth::UserProviderRef;
use crate::error::{
AlreadyStartedSnafu, CollectRecordbatchSnafu, Error, InternalSnafu, InvalidQuerySnafu, Result,
StartHttpSnafu, UnexpectedResultSnafu,

View File

@@ -19,9 +19,10 @@ use api::v1::auth_header::AuthScheme;
use api::v1::Basic;
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
use async_trait::async_trait;
use auth::tests::MockUserProvider;
use auth::UserProviderRef;
use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_runtime::{Builder as RuntimeBuilder, Runtime};
use servers::auth::UserProviderRef;
use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu};
use servers::grpc::flight::FlightHandler;
use servers::grpc::handler::GreptimeRequestHandler;
@@ -32,7 +33,6 @@ use table::test_util::MemTable;
use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;
use crate::auth::MockUserProvider;
use crate::{create_testing_grpc_query_handler, LOCALHOST_WITH_0};
struct MockGrpcServer {

View File

@@ -14,16 +14,14 @@
use std::sync::Arc;
use auth::tests::MockUserProvider;
use auth::{UserInfoRef, UserProvider};
use axum::body::BoxBody;
use axum::http;
use hyper::Request;
use servers::auth::UserProvider;
use servers::http::authorize::HttpAuth;
use session::context::UserInfo;
use tower_http::auth::AsyncAuthorizeRequest;
use crate::auth::MockUserProvider;
#[tokio::test]
async fn test_http_auth() {
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(None);
@@ -31,8 +29,8 @@ async fn test_http_auth() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_res = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
let default = UserInfo::default();
let user_info: &UserInfoRef = auth_res.extensions().get().unwrap();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());
// In mock user provider, right username:password == "greptime:greptime"
@@ -42,8 +40,8 @@ async fn test_http_auth() {
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
let user_info: &UserInfoRef = req.extensions().get().unwrap();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());
let req = mock_http_request(None, None).unwrap();
@@ -72,8 +70,8 @@ async fn test_schema_validating() {
)
.unwrap();
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
let user_info: &UserInfoRef = req.extensions().get().unwrap();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());
// wrong database

View File

@@ -26,7 +26,6 @@ use servers::http::{
JsonOutput,
};
use servers::metrics_handler::MetricsHandler;
use session::context::UserInfo;
use table::test_util::MemTable;
use crate::{
@@ -43,7 +42,7 @@ async fn test_sql_not_provided() {
script_handler: None,
}),
Query(http_handler::SqlQuery::default()),
axum::Extension(UserInfo::default()),
axum::Extension(auth::userinfo_by_name(None)),
Form(http_handler::SqlQuery::default()),
)
.await;
@@ -68,7 +67,7 @@ async fn test_sql_output_rows() {
script_handler: None,
}),
query,
axum::Extension(UserInfo::default()),
axum::Extension(auth::userinfo_by_name(None)),
Form(http_handler::SqlQuery::default()),
)
.await;
@@ -114,7 +113,7 @@ async fn test_sql_form() {
script_handler: None,
}),
Query(http_handler::SqlQuery::default()),
axum::Extension(UserInfo::default()),
axum::Extension(auth::userinfo_by_name(None)),
form,
)
.await;

View File

@@ -17,6 +17,7 @@ use std::sync::Arc;
use api::v1::greptime_request::Request;
use api::v1::InsertRequests;
use async_trait::async_trait;
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
use axum::{http, Router};
use axum_test_helper::TestClient;
use common_query::Output;
@@ -33,8 +34,6 @@ use servers::query_handler::InfluxdbLineProtocolHandler;
use session::context::QueryContextRef;
use tokio::sync::mpsc;
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
struct DummyInstance {
tx: Arc<mpsc::Sender<(String, String)>>,
}

View File

@@ -36,7 +36,6 @@ use snafu::ensure;
use sql::statements::statement::Statement;
use table::test_util::MemTable;
mod auth;
mod grpc;
mod http;
mod interceptor;

View File

@@ -16,6 +16,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_recordbatch::RecordBatch;
use common_runtime::Builder as RuntimeBuilder;
@@ -32,7 +33,6 @@ use servers::server::Server;
use servers::tls::TlsOption;
use table::test_util::MemTable;
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
use crate::create_testing_sql_query_handler;
use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData};

View File

@@ -16,6 +16,8 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
use auth::UserProviderRef;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_runtime::Builder as RuntimeBuilder;
use pgwire::api::Type;
@@ -23,7 +25,6 @@ use rand::rngs::StdRng;
use rand::Rng;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, Error, ServerName};
use servers::auth::UserProviderRef;
use servers::error::Result;
use servers::postgres::PostgresServer;
use servers::server::Server;
@@ -31,7 +32,6 @@ use servers::tls::TlsOption;
use table::test_util::MemTable;
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
use crate::create_testing_instance;
fn create_postgres_server(

View File

@@ -9,6 +9,7 @@ testing = []
[dependencies]
arc-swap = "1.5"
auth.workspace = true
common-catalog = { workspace = true }
common-telemetry = { workspace = true }
common-time = { workspace = true }

View File

@@ -17,6 +17,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwap;
use auth::UserInfoRef;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
use common_time::TimeZone;
@@ -32,6 +33,7 @@ pub type ConnInfoRef = Arc<ConnInfo>;
pub struct QueryContext {
current_catalog: String,
current_schema: String,
current_user: ArcSwap<Option<UserInfoRef>>,
time_zone: ArcSwap<Option<TimeZone>>,
sql_dialect: Box<dyn Dialect + Send + Sync>,
trace_id: u64,
@@ -109,6 +111,16 @@ impl QueryContext {
let _ = self.time_zone.swap(Arc::new(tz));
}
#[inline]
pub fn current_user(&self) -> Option<UserInfoRef> {
self.current_user.load().as_ref().clone()
}
#[inline]
pub fn set_current_user(&self, user: Option<UserInfoRef>) {
let _ = self.current_user.swap(Arc::new(user));
}
#[inline]
pub fn trace_id(&self) -> u64 {
self.trace_id
@@ -124,6 +136,9 @@ impl QueryContextBuilder {
current_schema: self
.current_schema
.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()),
current_user: self
.current_user
.unwrap_or_else(|| ArcSwap::new(Arc::new(None))),
time_zone: self
.time_zone
.unwrap_or_else(|| ArcSwap::new(Arc::new(None))),
@@ -140,33 +155,6 @@ impl QueryContextBuilder {
}
}
pub const DEFAULT_USERNAME: &str = "greptime";
#[derive(Clone, Debug)]
pub struct UserInfo {
username: String,
}
impl Default for UserInfo {
fn default() -> Self {
Self {
username: DEFAULT_USERNAME.to_string(),
}
}
}
impl UserInfo {
pub fn username(&self) -> &str {
self.username.as_str()
}
pub fn new(username: impl Into<String>) -> Self {
Self {
username: username.into(),
}
}
}
#[derive(Debug)]
pub struct ConnInfo {
pub client_addr: Option<SocketAddr>,
@@ -225,7 +213,7 @@ mod test {
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use super::*;
use crate::context::{Channel, UserInfo};
use crate::context::Channel;
use crate::Session;
#[test]
@@ -233,8 +221,6 @@ mod test {
let session = Session::new(Some("127.0.0.1:9000".parse().unwrap()), Channel::Mysql);
// test user_info
assert_eq!(session.user_info().username(), "greptime");
session.set_user_info(UserInfo::new("root"));
assert_eq!(session.user_info().username(), "root");
// test channel
assert_eq!(session.conn_info().channel, Channel::Mysql);

View File

@@ -18,18 +18,19 @@ use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwap;
use auth::UserInfoRef;
use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use context::QueryContextBuilder;
use crate::context::{Channel, ConnInfo, QueryContextRef, UserInfo};
use crate::context::{Channel, ConnInfo, QueryContextRef};
/// Session for persistent connection such as MySQL, PostgreSQL etc.
#[derive(Debug)]
pub struct Session {
catalog: ArcSwap<String>,
schema: ArcSwap<String>,
user_info: ArcSwap<UserInfo>,
user_info: ArcSwap<UserInfoRef>,
conn_info: ConnInfo,
}
@@ -40,7 +41,7 @@ impl Session {
Session {
catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.into())),
schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())),
user_info: ArcSwap::new(Arc::new(UserInfo::default())),
user_info: ArcSwap::new(Arc::new(auth::userinfo_by_name(None))),
conn_info: ConnInfo::new(addr, channel),
}
}
@@ -48,6 +49,9 @@ impl Session {
#[inline]
pub fn new_query_context(&self) -> QueryContextRef {
QueryContextBuilder::default()
.current_user(ArcSwap::new(Arc::new(Some(
self.user_info.load().as_ref().clone(),
))))
.current_catalog(self.catalog.load().to_string())
.current_schema(self.schema.load().to_string())
.sql_dialect(self.conn_info.channel.dialect())
@@ -65,12 +69,12 @@ impl Session {
}
#[inline]
pub fn user_info(&self) -> Arc<UserInfo> {
self.user_info.load().clone()
pub fn user_info(&self) -> UserInfoRef {
self.user_info.load().clone().as_ref().clone()
}
#[inline]
pub fn set_user_info(&self, user_info: UserInfo) {
pub fn set_user_info(&self, user_info: UserInfoRef) {
self.user_info.store(Arc::new(user_info));
}

View File

@@ -10,6 +10,7 @@ dashboard = []
[dependencies]
api = { workspace = true }
async-trait = "0.1"
auth.workspace = true
axum = "0.6"
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
catalog = { workspace = true }

View File

@@ -19,6 +19,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use auth::UserProviderRef;
use axum::Router;
use catalog::{CatalogManagerRef, RegisterTableRequest};
use common_catalog::consts::{
@@ -48,7 +49,6 @@ use object_store::services::{Azblob, Gcs, Oss, S3};
use object_store::test_util::TempFolder;
use object_store::ObjectStore;
use secrecy::ExposeSecret;
use servers::auth::UserProviderRef;
use servers::grpc::GrpcServer;
use servers::http::{HttpOptions, HttpServerBuilder};
use servers::metrics_handler::MetricsHandler;

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use api::v1::alter_expr::Kind;
use api::v1::promql_request::Promql;
use api::v1::{
@@ -21,10 +19,10 @@ use api::v1::{
CreateTableExpr, InsertRequest, InsertRequests, PromInstantQuery, PromRangeQuery,
PromqlRequest, RequestHeader, SemanticType, TableId,
};
use auth::user_provider_from_option;
use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::consts::{MIN_USER_TABLE_ID, MITO_ENGINE};
use common_query::Output;
use servers::auth::user_provider::StaticUserProvider;
use servers::prometheus::{PromData, PromSeries, PrometheusJsonResponse, PrometheusResponse};
use servers::server::Server;
use tests_integration::test_util::{
@@ -118,14 +116,13 @@ pub async fn test_dbname(store_type: StorageType) {
}
pub async fn test_grpc_auth(store_type: StorageType) {
let user_provider = StaticUserProvider::try_from("cmd:greptime_user=greptime_pwd").unwrap();
let (addr, mut guard, fe_grpc_server) = setup_grpc_server_with_user_provider(
store_type,
"auto_create_table",
Some(Arc::new(user_provider)),
let user_provider = user_provider_from_option(
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
)
.await;
.unwrap();
let (addr, mut guard, fe_grpc_server) =
setup_grpc_server_with_user_provider(store_type, "auto_create_table", Some(user_provider))
.await;
let grpc_client = Client::with_urls(vec![addr]);
let mut db = Database::new_with_dbname(

View File

@@ -13,13 +13,12 @@
// limitations under the License.
use std::collections::BTreeMap;
use std::sync::Arc;
use auth::user_provider_from_option;
use axum::http::StatusCode;
use axum_test_helper::TestClient;
use common_error::status_code::StatusCode as ErrorCode;
use serde_json::json;
use servers::auth::user_provider::StaticUserProvider;
use servers::http::handler::HealthResponse;
use servers::http::{JsonOutput, JsonResponse};
use servers::prometheus::{PrometheusJsonResponse, PrometheusResponse};
@@ -76,12 +75,15 @@ macro_rules! http_tests {
pub async fn test_http_auth(store_type: StorageType) {
common_telemetry::init_default_ut_logging();
let user_provider = StaticUserProvider::try_from("cmd:greptime_user=greptime_pwd").unwrap();
let user_provider = user_provider_from_option(
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
)
.unwrap();
let (app, mut guard) = setup_test_http_app_with_frontend_and_user_provider(
store_type,
"sql_api",
Some(Arc::new(user_provider)),
Some(user_provider),
)
.await;
let client = TestClient::new(app);

View File

@@ -11,10 +11,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use auth::user_provider_from_option;
use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
use servers::auth::user_provider::StaticUserProvider;
use sqlx::mysql::{MySqlDatabaseError, MySqlPoolOptions};
use sqlx::postgres::{PgDatabaseError, PgPoolOptions};
use sqlx::Row;
@@ -65,13 +64,13 @@ macro_rules! sql_tests {
}
pub async fn test_mysql_auth(store_type: StorageType) {
let user_provider = StaticUserProvider::try_from("cmd:greptime_user=greptime_pwd").unwrap();
let (addr, mut guard, fe_mysql_server) = setup_mysql_server_with_user_provider(
store_type,
"sql_crud",
Some(Arc::new(user_provider)),
let user_provider = user_provider_from_option(
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
)
.await;
.unwrap();
let (addr, mut guard, fe_mysql_server) =
setup_mysql_server_with_user_provider(store_type, "sql_crud", Some(user_provider)).await;
// 1. no auth
let conn_re = MySqlPoolOptions::new()
@@ -204,10 +203,13 @@ pub async fn test_mysql_crud(store_type: StorageType) {
}
pub async fn test_postgres_auth(store_type: StorageType) {
let user_provider = StaticUserProvider::try_from("cmd:greptime_user=greptime_pwd").unwrap();
let user_provider = user_provider_from_option(
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
)
.unwrap();
let (addr, mut guard, fe_pg_server) =
setup_pg_server_with_user_provider(store_type, "sql_crud", Some(Arc::new(user_provider)))
.await;
setup_pg_server_with_user_provider(store_type, "sql_crud", Some(user_provider)).await;
// 1. no auth
let conn_re = PgPoolOptions::new()