mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-05 21:02:58 +00:00
refactor: auth crate (#2148)
* chore: move user_info to auth crate * chore: temp commit before resolving tests compile error * chore: fix compile issue * chore: minor fix * chore: tmp save * chore: change user_info to trait * chore: minor change & use auth result user info in pg session setup * chore: add as_any to user_info * chore: rename user_info * chore: remove ice file * chore: add permission checker * chore: add grpc permission check * chore: add session spawn user_info to query_ctx * chore: minor update * chore: add permission checker to sql handler & temp save * chore: add permission checker to prometheus handler * chore: add permission checker to opentsdb handler * chore: add permission checker to other handlers * chore: add test * chore: add user_info setting on http entrance * chore: fix toml * chore: remove box in permission req * chore: cr issue * chore: cr issue
This commit is contained in:
22
Cargo.lock
generated
22
Cargo.lock
generated
@@ -672,6 +672,23 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "auth"
|
||||
version = "0.3.2"
|
||||
dependencies = [
|
||||
"api",
|
||||
"async-trait",
|
||||
"common-error",
|
||||
"common-test-util",
|
||||
"digest",
|
||||
"hex",
|
||||
"secrecy",
|
||||
"sha1",
|
||||
"snafu",
|
||||
"sql",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "auto_ops"
|
||||
version = "0.3.0"
|
||||
@@ -1559,6 +1576,7 @@ version = "0.3.2"
|
||||
dependencies = [
|
||||
"anymap",
|
||||
"async-trait",
|
||||
"auth",
|
||||
"catalog",
|
||||
"chrono",
|
||||
"clap 3.2.25",
|
||||
@@ -3242,6 +3260,7 @@ dependencies = [
|
||||
"async-compat",
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"auth",
|
||||
"catalog",
|
||||
"chrono",
|
||||
"client",
|
||||
@@ -8824,6 +8843,7 @@ dependencies = [
|
||||
"api",
|
||||
"arrow-flight",
|
||||
"async-trait",
|
||||
"auth",
|
||||
"axum",
|
||||
"axum-macros",
|
||||
"axum-test-helper",
|
||||
@@ -8912,6 +8932,7 @@ name = "session"
|
||||
version = "0.3.2"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"auth",
|
||||
"common-catalog",
|
||||
"common-telemetry",
|
||||
"common-time",
|
||||
@@ -9933,6 +9954,7 @@ version = "0.3.2"
|
||||
dependencies = [
|
||||
"api",
|
||||
"async-trait",
|
||||
"auth",
|
||||
"axum",
|
||||
"axum-test-helper",
|
||||
"catalog",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
members = [
|
||||
"benchmarks",
|
||||
"src/api",
|
||||
"src/auth",
|
||||
"src/catalog",
|
||||
"src/client",
|
||||
"src/cmd",
|
||||
@@ -102,6 +103,7 @@ metrics = "0.20"
|
||||
meter-core = { git = "https://github.com/GreptimeTeam/greptime-meter.git", rev = "abbd357c1e193cd270ea65ee7652334a150b628f" }
|
||||
## workspaces members
|
||||
api = { path = "src/api" }
|
||||
auth = { path = "src/auth" }
|
||||
catalog = { path = "src/catalog" }
|
||||
client = { path = "src/client" }
|
||||
cmd = { path = "src/cmd" }
|
||||
|
||||
26
src/auth/Cargo.toml
Normal file
26
src/auth/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[package]
|
||||
name = "auth"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
default = []
|
||||
testing = []
|
||||
|
||||
[dependencies]
|
||||
api.workspace = true
|
||||
async-trait.workspace = true
|
||||
common-error.workspace = true
|
||||
digest = "0.10"
|
||||
hex = { version = "0.4" }
|
||||
secrecy = { version = "0.8", features = ["serde", "alloc"] }
|
||||
sha1 = "0.10"
|
||||
snafu.workspace = true
|
||||
sql.workspace = true
|
||||
tokio.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
common-test-util.workspace = true
|
||||
68
src/auth/src/common.rs
Normal file
68
src/auth/src/common.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use secrecy::SecretString;
|
||||
use snafu::OptionExt;
|
||||
|
||||
use crate::error::{InvalidConfigSnafu, Result};
|
||||
use crate::user_info::DefaultUserInfo;
|
||||
use crate::user_provider::static_user_provider::{StaticUserProvider, STATIC_USER_PROVIDER};
|
||||
use crate::{UserInfoRef, UserProviderRef};
|
||||
|
||||
pub(crate) const DEFAULT_USERNAME: &str = "greptime";
|
||||
|
||||
/// construct a [`UserInfo`] impl with name
|
||||
/// use default username `greptime` if None is provided
|
||||
pub fn userinfo_by_name(username: Option<String>) -> UserInfoRef {
|
||||
DefaultUserInfo::with_name(username.unwrap_or_else(|| DEFAULT_USERNAME.to_string()))
|
||||
}
|
||||
|
||||
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
|
||||
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
|
||||
value: opt.to_string(),
|
||||
msg: "UserProviderOption must be in format `<option>:<value>`",
|
||||
})?;
|
||||
match name {
|
||||
STATIC_USER_PROVIDER => {
|
||||
let provider =
|
||||
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
|
||||
Ok(provider)
|
||||
}
|
||||
_ => InvalidConfigSnafu {
|
||||
value: name.to_string(),
|
||||
msg: "Invalid UserProviderOption",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
type Username<'a> = &'a str;
|
||||
type HostOrIp<'a> = &'a str;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Identity<'a> {
|
||||
UserId(Username<'a>, Option<HostOrIp<'a>>),
|
||||
}
|
||||
|
||||
pub type HashedPassword<'a> = &'a [u8];
|
||||
pub type Salt<'a> = &'a [u8];
|
||||
|
||||
/// Authentication information sent by the client.
|
||||
pub enum Password<'a> {
|
||||
PlainText(SecretString),
|
||||
MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
|
||||
PgMD5(HashedPassword<'a>, Salt<'a>),
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,83 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::ext::{BoxedError, ErrorExt};
|
||||
use common_error::status_code::StatusCode;
|
||||
use secrecy::SecretString;
|
||||
use session::context::UserInfo;
|
||||
use snafu::{Location, OptionExt, Snafu};
|
||||
|
||||
use crate::auth::user_provider::StaticUserProvider;
|
||||
|
||||
pub mod user_provider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait UserProvider: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// [`authenticate`] checks whether a user is valid and allowed to access the database.
|
||||
async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo>;
|
||||
|
||||
/// [`authorize`] checks whether a connection request
|
||||
/// from a certain user to a certain catalog/schema is legal.
|
||||
/// This method should be called after [`authenticate`].
|
||||
async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfo) -> Result<()>;
|
||||
|
||||
/// [`auth`] is a combination of [`authenticate`] and [`authorize`].
|
||||
/// In most cases it's preferred for both convenience and performance.
|
||||
async fn auth(
|
||||
&self,
|
||||
id: Identity<'_>,
|
||||
password: Password<'_>,
|
||||
catalog: &str,
|
||||
schema: &str,
|
||||
) -> Result<UserInfo> {
|
||||
let user_info = self.authenticate(id, password).await?;
|
||||
self.authorize(catalog, schema, &user_info).await?;
|
||||
Ok(user_info)
|
||||
}
|
||||
}
|
||||
|
||||
pub type UserProviderRef = Arc<dyn UserProvider>;
|
||||
|
||||
type Username<'a> = &'a str;
|
||||
type HostOrIp<'a> = &'a str;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Identity<'a> {
|
||||
UserId(Username<'a>, Option<HostOrIp<'a>>),
|
||||
}
|
||||
|
||||
pub type HashedPassword<'a> = &'a [u8];
|
||||
pub type Salt<'a> = &'a [u8];
|
||||
|
||||
/// Authentication information sent by the client.
|
||||
pub enum Password<'a> {
|
||||
PlainText(SecretString),
|
||||
MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
|
||||
PgMD5(HashedPassword<'a>, Salt<'a>),
|
||||
}
|
||||
|
||||
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
|
||||
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
|
||||
value: opt.to_string(),
|
||||
msg: "UserProviderOption must be in format `<option>:<value>`",
|
||||
})?;
|
||||
match name {
|
||||
user_provider::STATIC_USER_PROVIDER => {
|
||||
let provider =
|
||||
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
|
||||
Ok(provider)
|
||||
}
|
||||
_ => InvalidConfigSnafu {
|
||||
value: name.to_string(),
|
||||
msg: "Invalid UserProviderOption",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
use snafu::{Location, Snafu};
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
33
src/auth/src/lib.rs
Normal file
33
src/auth/src/lib.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod common;
|
||||
mod error;
|
||||
mod permission;
|
||||
mod user_info;
|
||||
mod user_provider;
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
pub mod tests;
|
||||
|
||||
pub use common::{user_provider_from_option, userinfo_by_name, HashedPassword, Identity, Password};
|
||||
pub use error::{Error, Result};
|
||||
pub use permission::{PermissionChecker, PermissionReq, PermissionResp};
|
||||
pub use user_info::UserInfo;
|
||||
pub use user_provider::UserProvider;
|
||||
|
||||
/// pub type alias
|
||||
pub type UserInfoRef = std::sync::Arc<dyn UserInfo>;
|
||||
pub type UserProviderRef = std::sync::Arc<dyn UserProvider>;
|
||||
pub type PermissionCheckerRef = std::sync::Arc<dyn PermissionChecker>;
|
||||
60
src/auth/src/permission.rs
Normal file
60
src/auth/src/permission.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use api::v1::greptime_request::Request;
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::{PermissionCheckerRef, UserInfoRef};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PermissionReq<'a> {
|
||||
GrpcRequest(&'a Request),
|
||||
SqlStatement(&'a Statement),
|
||||
PromQuery,
|
||||
Opentsdb,
|
||||
LineProtocol,
|
||||
PromStoreWrite,
|
||||
PromStoreRead,
|
||||
Otlp,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PermissionResp {
|
||||
Allow,
|
||||
Reject,
|
||||
}
|
||||
|
||||
pub trait PermissionChecker: Send + Sync {
|
||||
fn check_permission(
|
||||
&self,
|
||||
user_info: Option<UserInfoRef>,
|
||||
req: PermissionReq,
|
||||
) -> Result<PermissionResp>;
|
||||
}
|
||||
|
||||
impl PermissionChecker for Option<&PermissionCheckerRef> {
|
||||
fn check_permission(
|
||||
&self,
|
||||
user_info: Option<UserInfoRef>,
|
||||
req: PermissionReq,
|
||||
) -> Result<PermissionResp> {
|
||||
match self {
|
||||
Some(checker) => checker.check_permission(user_info, req),
|
||||
None => Ok(PermissionResp::Allow),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,14 +11,17 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use secrecy::ExposeSecret;
|
||||
use servers::auth::user_provider::auth_mysql;
|
||||
use servers::auth::{
|
||||
AccessDeniedSnafu, Identity, Password, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu,
|
||||
UserPasswordMismatchSnafu, UserProvider,
|
||||
|
||||
use crate::error::{
|
||||
AccessDeniedSnafu, Result, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu,
|
||||
UserPasswordMismatchSnafu,
|
||||
};
|
||||
use session::context::UserInfo;
|
||||
use crate::user_info::DefaultUserInfo;
|
||||
use crate::user_provider::static_user_provider::auth_mysql;
|
||||
#[allow(unused_imports)]
|
||||
use crate::Error;
|
||||
use crate::{Identity, Password, UserInfoRef, UserProvider};
|
||||
|
||||
pub struct DatabaseAuthInfo<'a> {
|
||||
pub catalog: &'a str,
|
||||
@@ -56,17 +59,13 @@ impl UserProvider for MockUserProvider {
|
||||
"mock_user_provider"
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
&self,
|
||||
id: Identity<'_>,
|
||||
password: Password<'_>,
|
||||
) -> servers::auth::Result<UserInfo> {
|
||||
async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef> {
|
||||
match id {
|
||||
Identity::UserId(username, _host) => match password {
|
||||
Password::PlainText(password) => {
|
||||
if username == "greptime" {
|
||||
if password.expose_secret() == "greptime" {
|
||||
Ok(UserInfo::new("greptime"))
|
||||
Ok(DefaultUserInfo::with_name("greptime"))
|
||||
} else {
|
||||
UserPasswordMismatchSnafu {
|
||||
username: username.to_string(),
|
||||
@@ -82,7 +81,7 @@ impl UserProvider for MockUserProvider {
|
||||
}
|
||||
Password::MysqlNativePassword(auth_data, salt) => {
|
||||
auth_mysql(auth_data, salt, username, "greptime".as_bytes())
|
||||
.map(|_| UserInfo::new(username))
|
||||
.map(|_| DefaultUserInfo::with_name(username))
|
||||
}
|
||||
_ => UnsupportedPasswordTypeSnafu {
|
||||
password_type: "mysql_native_password",
|
||||
@@ -92,12 +91,7 @@ impl UserProvider for MockUserProvider {
|
||||
}
|
||||
}
|
||||
|
||||
async fn authorize(
|
||||
&self,
|
||||
catalog: &str,
|
||||
schema: &str,
|
||||
user_info: &UserInfo,
|
||||
) -> servers::auth::Result<()> {
|
||||
async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfoRef) -> Result<()> {
|
||||
if catalog == self.catalog && schema == self.schema && user_info.username() == self.username
|
||||
{
|
||||
Ok(())
|
||||
@@ -137,7 +131,7 @@ async fn test_auth_by_plain_text() {
|
||||
assert!(auth_result.is_err());
|
||||
assert!(matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UnsupportedPasswordType { .. }
|
||||
Error::UnsupportedPasswordType { .. }
|
||||
));
|
||||
|
||||
// auth failed, err: user not exist.
|
||||
@@ -150,7 +144,7 @@ async fn test_auth_by_plain_text() {
|
||||
assert!(auth_result.is_err());
|
||||
assert!(matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UserNotFound { .. }
|
||||
Error::UserNotFound { .. }
|
||||
));
|
||||
|
||||
// auth failed, err: wrong password
|
||||
@@ -163,7 +157,7 @@ async fn test_auth_by_plain_text() {
|
||||
assert!(auth_result.is_err());
|
||||
assert!(matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UserPasswordMismatch { .. }
|
||||
Error::UserPasswordMismatch { .. }
|
||||
))
|
||||
}
|
||||
|
||||
@@ -176,8 +170,8 @@ async fn test_schema_validate() {
|
||||
username: "test_user",
|
||||
});
|
||||
|
||||
let right_user = UserInfo::new("test_user");
|
||||
let wrong_user = UserInfo::default();
|
||||
let right_user = DefaultUserInfo::with_name("test_user");
|
||||
let wrong_user = DefaultUserInfo::with_name("greptime");
|
||||
|
||||
// check catalog
|
||||
let re = validator
|
||||
47
src/auth/src/user_info.rs
Normal file
47
src/auth/src/user_info.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::any::Any;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::UserInfoRef;
|
||||
|
||||
pub trait UserInfo: Debug + Sync + Send {
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
fn username(&self) -> &str;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct DefaultUserInfo {
|
||||
username: String,
|
||||
}
|
||||
|
||||
impl DefaultUserInfo {
|
||||
pub(crate) fn with_name(username: impl Into<String>) -> UserInfoRef {
|
||||
Arc::new(Self {
|
||||
username: username.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl UserInfo for DefaultUserInfo {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn username(&self) -> &str {
|
||||
self.username.as_str()
|
||||
}
|
||||
}
|
||||
46
src/auth/src/user_provider.rs
Normal file
46
src/auth/src/user_provider.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub(crate) mod static_user_provider;
|
||||
|
||||
use crate::common::{Identity, Password};
|
||||
use crate::error::Result;
|
||||
use crate::UserInfoRef;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait UserProvider: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// [`authenticate`] checks whether a user is valid and allowed to access the database.
|
||||
async fn authenticate(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfoRef>;
|
||||
|
||||
/// [`authorize`] checks whether a connection request
|
||||
/// from a certain user to a certain catalog/schema is legal.
|
||||
/// This method should be called after [`authenticate`].
|
||||
async fn authorize(&self, catalog: &str, schema: &str, user_info: &UserInfoRef) -> Result<()>;
|
||||
|
||||
/// [`auth`] is a combination of [`authenticate`] and [`authorize`].
|
||||
/// In most cases it's preferred for both convenience and performance.
|
||||
async fn auth(
|
||||
&self,
|
||||
id: Identity<'_>,
|
||||
password: Password<'_>,
|
||||
catalog: &str,
|
||||
schema: &str,
|
||||
) -> Result<UserInfoRef> {
|
||||
let user_info = self.authenticate(id, password).await?;
|
||||
self.authorize(catalog, schema, &user_info).await?;
|
||||
Ok(user_info)
|
||||
}
|
||||
}
|
||||
@@ -19,20 +19,20 @@ use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use digest;
|
||||
use digest::Digest;
|
||||
use secrecy::ExposeSecret;
|
||||
use session::context::UserInfo;
|
||||
use sha1::Sha1;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
use crate::auth::{
|
||||
Error, HashedPassword, Identity, IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Password,
|
||||
Result, Salt, UnsupportedPasswordTypeSnafu, UserNotFoundSnafu, UserPasswordMismatchSnafu,
|
||||
UserProvider,
|
||||
use crate::common::Salt;
|
||||
use crate::error::{
|
||||
Error, IllegalParamSnafu, InvalidConfigSnafu, IoSnafu, Result, UnsupportedPasswordTypeSnafu,
|
||||
UserNotFoundSnafu, UserPasswordMismatchSnafu,
|
||||
};
|
||||
use crate::user_info::DefaultUserInfo;
|
||||
use crate::{HashedPassword, Identity, Password, UserInfoRef, UserProvider};
|
||||
|
||||
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";
|
||||
pub(crate) const STATIC_USER_PROVIDER: &str = "static_user_provider";
|
||||
|
||||
impl TryFrom<&str> for StaticUserProvider {
|
||||
type Error = Error;
|
||||
@@ -91,7 +91,7 @@ impl TryFrom<&str> for StaticUserProvider {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StaticUserProvider {
|
||||
pub(crate) struct StaticUserProvider {
|
||||
users: HashMap<String, Vec<u8>>,
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ impl UserProvider for StaticUserProvider {
|
||||
&self,
|
||||
input_id: Identity<'_>,
|
||||
input_pwd: Password<'_>,
|
||||
) -> Result<UserInfo> {
|
||||
) -> Result<UserInfoRef> {
|
||||
match input_id {
|
||||
Identity::UserId(username, _) => {
|
||||
ensure!(
|
||||
@@ -127,7 +127,7 @@ impl UserProvider for StaticUserProvider {
|
||||
}
|
||||
);
|
||||
return if save_pwd == pwd.expose_secret().as_bytes() {
|
||||
Ok(UserInfo::new(username))
|
||||
Ok(DefaultUserInfo::with_name(username))
|
||||
} else {
|
||||
UserPasswordMismatchSnafu {
|
||||
username: username.to_string(),
|
||||
@@ -143,7 +143,7 @@ impl UserProvider for StaticUserProvider {
|
||||
}
|
||||
);
|
||||
auth_mysql(auth_data, salt, username, save_pwd)
|
||||
.map(|_| UserInfo::new(username))
|
||||
.map(|_| DefaultUserInfo::with_name(username))
|
||||
}
|
||||
Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu {
|
||||
password_type: "pg_md5",
|
||||
@@ -154,7 +154,12 @@ impl UserProvider for StaticUserProvider {
|
||||
}
|
||||
}
|
||||
|
||||
async fn authorize(&self, _catalog: &str, _schema: &str, _user_info: &UserInfo) -> Result<()> {
|
||||
async fn authorize(
|
||||
&self,
|
||||
_catalog: &str,
|
||||
_schema: &str,
|
||||
_user_info: &UserInfoRef,
|
||||
) -> Result<()> {
|
||||
// default allow all
|
||||
Ok(())
|
||||
}
|
||||
@@ -208,10 +213,13 @@ pub mod test {
|
||||
use std::io::{LineWriter, Write};
|
||||
|
||||
use common_test_util::temp_dir::create_temp_dir;
|
||||
use session::context::UserInfo;
|
||||
|
||||
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
|
||||
use crate::auth::{Identity, Password, UserProvider};
|
||||
use crate::user_info::DefaultUserInfo;
|
||||
use crate::user_provider::static_user_provider::{
|
||||
double_sha1, sha1_one, sha1_two, StaticUserProvider,
|
||||
};
|
||||
use crate::user_provider::{Identity, Password};
|
||||
use crate::UserProvider;
|
||||
|
||||
#[test]
|
||||
fn test_sha() {
|
||||
@@ -249,9 +257,10 @@ pub mod test {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorize() {
|
||||
let user_info = DefaultUserInfo::with_name("root");
|
||||
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
|
||||
provider
|
||||
.authorize("catalog", "schema", &UserInfo::new("root"))
|
||||
.authorize("catalog", "schema", &user_info)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
61
src/auth/tests/mod.rs
Normal file
61
src/auth/tests/mod.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![feature(assert_matches)]
|
||||
use std::assert_matches::assert_matches;
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::greptime_request::Request;
|
||||
use auth::Error::InternalState;
|
||||
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq, PermissionResp, UserInfoRef};
|
||||
use sql::statements::show::{ShowDatabases, ShowKind};
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
struct DummyPermissionChecker;
|
||||
|
||||
impl PermissionChecker for DummyPermissionChecker {
|
||||
fn check_permission(
|
||||
&self,
|
||||
_user_info: Option<UserInfoRef>,
|
||||
req: PermissionReq,
|
||||
) -> auth::Result<PermissionResp> {
|
||||
match req {
|
||||
PermissionReq::GrpcRequest(_) => Ok(PermissionResp::Allow),
|
||||
PermissionReq::SqlStatement(_) => Ok(PermissionResp::Reject),
|
||||
_ => Err(InternalState {
|
||||
msg: "testing".to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permission_checker() {
|
||||
let checker: PermissionCheckerRef = Arc::new(DummyPermissionChecker);
|
||||
|
||||
let grpc_result = checker.check_permission(
|
||||
None,
|
||||
PermissionReq::GrpcRequest(&Request::Query(Default::default())),
|
||||
);
|
||||
assert_matches!(grpc_result, Ok(PermissionResp::Allow));
|
||||
|
||||
let sql_result = checker.check_permission(
|
||||
None,
|
||||
PermissionReq::SqlStatement(&Statement::ShowDatabases(ShowDatabases::new(ShowKind::All))),
|
||||
);
|
||||
assert_matches!(sql_result, Ok(PermissionResp::Reject));
|
||||
|
||||
let err_result = checker.check_permission(None, PermissionReq::Opentsdb);
|
||||
assert_matches!(err_result, Err(InternalState { msg }) if msg == "testing");
|
||||
}
|
||||
@@ -17,6 +17,7 @@ metrics-process = ["servers/metrics-process"]
|
||||
[dependencies]
|
||||
anymap = "1.0.0-beta.2"
|
||||
async-trait.workspace = true
|
||||
auth.workspace = true
|
||||
catalog = { workspace = true }
|
||||
chrono.workspace = true
|
||||
clap = { version = "3.1", features = ["derive"] }
|
||||
|
||||
@@ -80,7 +80,7 @@ pub enum Error {
|
||||
#[snafu(display("Illegal auth config: {}", source))]
|
||||
IllegalAuthConfig {
|
||||
location: Location,
|
||||
source: servers::auth::Error,
|
||||
source: auth::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("Unsupported selector type, {} source: {}", selector_type, source))]
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use auth::UserProviderRef;
|
||||
use clap::Parser;
|
||||
use common_base::Plugins;
|
||||
use common_telemetry::logging;
|
||||
@@ -21,9 +22,8 @@ use frontend::frontend::FrontendOptions;
|
||||
use frontend::instance::{FrontendInstance, Instance as FeInstance};
|
||||
use frontend::service_config::{InfluxdbOptions, PrometheusOptions};
|
||||
use meta_client::MetaClientOptions;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::tls::{TlsMode, TlsOption};
|
||||
use servers::{auth, Mode};
|
||||
use servers::Mode;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{self, IllegalAuthConfigSnafu, Result, StartCatalogManagerSnafu};
|
||||
@@ -236,10 +236,10 @@ mod tests {
|
||||
use std::io::Write;
|
||||
use std::time::Duration;
|
||||
|
||||
use auth::{Identity, Password, UserProviderRef};
|
||||
use common_base::readable_size::ReadableSize;
|
||||
use common_test_util::temp_dir::create_named_temp_file;
|
||||
use frontend::service_config::GrpcOptions;
|
||||
use servers::auth::{Identity, Password, UserProviderRef};
|
||||
|
||||
use super::*;
|
||||
use crate::options::ENV_VAR_SEP;
|
||||
|
||||
@@ -345,9 +345,9 @@ mod tests {
|
||||
use std::io::Write;
|
||||
use std::time::Duration;
|
||||
|
||||
use auth::{Identity, Password, UserProviderRef};
|
||||
use common_base::readable_size::ReadableSize;
|
||||
use common_test_util::temp_dir::create_named_temp_file;
|
||||
use servers::auth::{Identity, Password, UserProviderRef};
|
||||
use servers::Mode;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -14,6 +14,7 @@ api = { workspace = true }
|
||||
async-compat = "0.2"
|
||||
async-stream.workspace = true
|
||||
async-trait = "0.1"
|
||||
auth.workspace = true
|
||||
catalog = { workspace = true }
|
||||
chrono.workspace = true
|
||||
client = { workspace = true }
|
||||
|
||||
@@ -578,6 +578,12 @@ pub enum Error {
|
||||
source: common_meta::error::Error,
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to pass permission check, source: {}", source))]
|
||||
Permission {
|
||||
source: auth::Error,
|
||||
location: Location,
|
||||
},
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -603,6 +609,8 @@ impl ErrorExt for Error {
|
||||
|
||||
Error::NotSupported { .. } => StatusCode::Unsupported,
|
||||
|
||||
Error::Permission { source, .. } => source.status_code(),
|
||||
|
||||
Error::HandleHeartbeatResponse { source, .. }
|
||||
| Error::TableMetadataManager { source, .. } => source.status_code(),
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ use api::v1::greptime_request::Request;
|
||||
use api::v1::meta::Role;
|
||||
use api::v1::{AddColumns, AlterExpr, Column, DdlRequest, InsertRequest, InsertRequests};
|
||||
use async_trait::async_trait;
|
||||
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
|
||||
use catalog::remote::CachedMetaKvBackend;
|
||||
use catalog::CatalogManagerRef;
|
||||
use client::client_manager::DatanodeClients;
|
||||
@@ -57,7 +58,7 @@ use query::query_engine::options::{validate_catalog_and_schema, QueryOptions};
|
||||
use query::query_engine::DescribeResult;
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use servers::error as server_error;
|
||||
use servers::error::{ExecuteQuerySnafu, ParsePromQLSnafu};
|
||||
use servers::error::{AuthSnafu, ExecuteQuerySnafu, ParsePromQLSnafu};
|
||||
use servers::interceptor::{
|
||||
PromQueryInterceptor, PromQueryInterceptorRef, SqlQueryInterceptor, SqlQueryInterceptorRef,
|
||||
};
|
||||
@@ -79,8 +80,8 @@ use sqlparser::ast::ObjectName;
|
||||
use crate::catalog::FrontendCatalogManager;
|
||||
use crate::error::{
|
||||
self, Error, ExecLogicalPlanSnafu, ExecutePromqlSnafu, ExternalSnafu,
|
||||
InvalidInsertRequestSnafu, MissingMetasrvOptsSnafu, ParseSqlSnafu, PlanStatementSnafu, Result,
|
||||
SqlExecInterceptedSnafu,
|
||||
InvalidInsertRequestSnafu, MissingMetasrvOptsSnafu, ParseSqlSnafu, PermissionSnafu,
|
||||
PlanStatementSnafu, Result, SqlExecInterceptedSnafu,
|
||||
};
|
||||
use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory};
|
||||
use crate::frontend::FrontendOptions;
|
||||
@@ -488,6 +489,9 @@ impl SqlQueryHandler for Instance {
|
||||
Err(e) => return vec![Err(e)],
|
||||
};
|
||||
|
||||
let checker_ref = self.plugins.get::<PermissionCheckerRef>();
|
||||
let checker = checker_ref.as_ref();
|
||||
|
||||
match parse_stmt(query.as_ref(), query_ctx.sql_dialect())
|
||||
.and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
|
||||
{
|
||||
@@ -501,6 +505,18 @@ impl SqlQueryHandler for Instance {
|
||||
results.push(Err(e));
|
||||
break;
|
||||
}
|
||||
|
||||
if let Err(e) = checker
|
||||
.check_permission(
|
||||
query_ctx.current_user(),
|
||||
PermissionReq::SqlStatement(&stmt),
|
||||
)
|
||||
.context(PermissionSnafu)
|
||||
{
|
||||
results.push(Err(e));
|
||||
break;
|
||||
}
|
||||
|
||||
match self.query_statement(stmt, query_ctx.clone()).await {
|
||||
Ok(output) => {
|
||||
let output_result =
|
||||
@@ -523,6 +539,8 @@ impl SqlQueryHandler for Instance {
|
||||
|
||||
async fn do_exec_plan(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
let _timer = timer!(metrics::METRIC_EXEC_PLAN_ELAPSED);
|
||||
// plan should be prepared before exec
|
||||
// we'll do check there
|
||||
self.query_engine
|
||||
.execute(plan, query_ctx)
|
||||
.await
|
||||
@@ -534,6 +552,7 @@ impl SqlQueryHandler for Instance {
|
||||
query: &PromQuery,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Vec<Result<Output>> {
|
||||
// check will be done in prometheus handler's do_query
|
||||
let result = PrometheusHandler::do_query(self, query, query_ctx)
|
||||
.await
|
||||
.with_context(|_| ExecutePromqlSnafu {
|
||||
@@ -551,6 +570,12 @@ impl SqlQueryHandler for Instance {
|
||||
stmt,
|
||||
Statement::Insert(_) | Statement::Query(_) | Statement::Delete(_)
|
||||
) {
|
||||
self.plugins
|
||||
.get::<PermissionCheckerRef>()
|
||||
.as_ref()
|
||||
.check_permission(query_ctx.current_user(), PermissionReq::SqlStatement(&stmt))
|
||||
.context(PermissionSnafu)?;
|
||||
|
||||
let plan = self
|
||||
.query_engine
|
||||
.planner()
|
||||
@@ -588,6 +613,12 @@ impl PrometheusHandler for Instance {
|
||||
.get::<PromQueryInterceptorRef<server_error::Error>>();
|
||||
interceptor.pre_execute(query, query_ctx.clone())?;
|
||||
|
||||
self.plugins
|
||||
.get::<PermissionCheckerRef>()
|
||||
.as_ref()
|
||||
.check_permission(query_ctx.current_user(), PermissionReq::PromQuery)
|
||||
.context(AuthSnafu)?;
|
||||
|
||||
let stmt = QueryLanguageParser::parse_promql(query).with_context(|_| ParsePromQLSnafu {
|
||||
query: query.clone(),
|
||||
})?;
|
||||
|
||||
@@ -15,15 +15,16 @@
|
||||
use api::v1::greptime_request::Request;
|
||||
use api::v1::query_request::Query;
|
||||
use async_trait::async_trait;
|
||||
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
|
||||
use common_query::Output;
|
||||
use query::parser::PromQuery;
|
||||
use servers::interceptor::{GrpcQueryInterceptor, GrpcQueryInterceptorRef};
|
||||
use servers::query_handler::grpc::GrpcQueryHandler;
|
||||
use servers::query_handler::sql::SqlQueryHandler;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::{ensure, OptionExt};
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
use crate::error::{Error, IncompleteGrpcResultSnafu, NotSupportedSnafu, Result};
|
||||
use crate::error::{Error, IncompleteGrpcResultSnafu, NotSupportedSnafu, PermissionSnafu, Result};
|
||||
use crate::instance::Instance;
|
||||
|
||||
#[async_trait]
|
||||
@@ -35,6 +36,12 @@ impl GrpcQueryHandler for Instance {
|
||||
let interceptor = interceptor_ref.as_ref();
|
||||
interceptor.pre_execute(&request, ctx.clone())?;
|
||||
|
||||
self.plugins
|
||||
.get::<PermissionCheckerRef>()
|
||||
.as_ref()
|
||||
.check_permission(ctx.current_user(), PermissionReq::GrpcRequest(&request))
|
||||
.context(PermissionSnafu)?;
|
||||
|
||||
let output = match request {
|
||||
Request::Inserts(requests) => self.handle_inserts(requests, ctx.clone()).await?,
|
||||
Request::Query(query_request) => {
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
|
||||
use common_error::ext::BoxedError;
|
||||
use servers::error::AuthSnafu;
|
||||
use servers::influxdb::InfluxdbRequest;
|
||||
use servers::query_handler::InfluxdbLineProtocolHandler;
|
||||
use session::context::QueryContextRef;
|
||||
@@ -28,6 +30,12 @@ impl InfluxdbLineProtocolHandler for Instance {
|
||||
request: &InfluxdbRequest,
|
||||
ctx: QueryContextRef,
|
||||
) -> servers::error::Result<()> {
|
||||
self.plugins
|
||||
.get::<PermissionCheckerRef>()
|
||||
.as_ref()
|
||||
.check_permission(ctx.current_user(), PermissionReq::LineProtocol)
|
||||
.context(AuthSnafu)?;
|
||||
|
||||
let requests = request.try_into()?;
|
||||
let _ = self
|
||||
.handle_inserts(requests, ctx)
|
||||
|
||||
@@ -14,8 +14,10 @@
|
||||
|
||||
use api::v1::InsertRequests;
|
||||
use async_trait::async_trait;
|
||||
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
|
||||
use common_error::ext::BoxedError;
|
||||
use servers::error as server_error;
|
||||
use servers::error::AuthSnafu;
|
||||
use servers::opentsdb::codec::DataPoint;
|
||||
use servers::query_handler::OpentsdbProtocolHandler;
|
||||
use session::context::QueryContextRef;
|
||||
@@ -26,6 +28,12 @@ use crate::instance::Instance;
|
||||
#[async_trait]
|
||||
impl OpentsdbProtocolHandler for Instance {
|
||||
async fn exec(&self, data_point: &DataPoint, ctx: QueryContextRef) -> server_error::Result<()> {
|
||||
self.plugins
|
||||
.get::<PermissionCheckerRef>()
|
||||
.as_ref()
|
||||
.check_permission(ctx.current_user(), PermissionReq::Opentsdb)
|
||||
.context(AuthSnafu)?;
|
||||
|
||||
let requests = InsertRequests {
|
||||
inserts: vec![data_point.as_grpc_insert()],
|
||||
};
|
||||
|
||||
@@ -13,12 +13,13 @@
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
|
||||
use common_error::ext::BoxedError;
|
||||
use metrics::counter;
|
||||
use opentelemetry_proto::tonic::collector::metrics::v1::{
|
||||
ExportMetricsServiceRequest, ExportMetricsServiceResponse,
|
||||
};
|
||||
use servers::error::{self, Result as ServerResult};
|
||||
use servers::error::{self, AuthSnafu, Result as ServerResult};
|
||||
use servers::otlp;
|
||||
use servers::query_handler::OpenTelemetryProtocolHandler;
|
||||
use session::context::QueryContextRef;
|
||||
@@ -34,6 +35,11 @@ impl OpenTelemetryProtocolHandler for Instance {
|
||||
request: ExportMetricsServiceRequest,
|
||||
ctx: QueryContextRef,
|
||||
) -> ServerResult<ExportMetricsServiceResponse> {
|
||||
self.plugins
|
||||
.get::<PermissionCheckerRef>()
|
||||
.as_ref()
|
||||
.check_permission(ctx.current_user(), PermissionReq::Otlp)
|
||||
.context(AuthSnafu)?;
|
||||
let (requests, rows) = otlp::to_grpc_insert_requests(request)?;
|
||||
let _ = self
|
||||
.handle_inserts(requests, ctx)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
use api::prom_store::remote::read_request::ResponseType;
|
||||
use api::prom_store::remote::{Query, QueryResult, ReadRequest, ReadResponse, WriteRequest};
|
||||
use async_trait::async_trait;
|
||||
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
|
||||
use common_catalog::format_full_table_name;
|
||||
use common_error::ext::BoxedError;
|
||||
use common_query::Output;
|
||||
@@ -22,7 +23,7 @@ use common_recordbatch::RecordBatches;
|
||||
use common_telemetry::logging;
|
||||
use metrics::counter;
|
||||
use prost::Message;
|
||||
use servers::error::{self, Result as ServerResult};
|
||||
use servers::error::{self, AuthSnafu, Result as ServerResult};
|
||||
use servers::prom_store::{self, Metrics};
|
||||
use servers::query_handler::{PromStoreProtocolHandler, PromStoreResponse};
|
||||
use session::context::QueryContextRef;
|
||||
@@ -148,6 +149,11 @@ impl Instance {
|
||||
#[async_trait]
|
||||
impl PromStoreProtocolHandler for Instance {
|
||||
async fn write(&self, request: WriteRequest, ctx: QueryContextRef) -> ServerResult<()> {
|
||||
self.plugins
|
||||
.get::<PermissionCheckerRef>()
|
||||
.as_ref()
|
||||
.check_permission(ctx.current_user(), PermissionReq::PromStoreWrite)
|
||||
.context(AuthSnafu)?;
|
||||
let (requests, samples) = prom_store::to_grpc_insert_requests(request)?;
|
||||
let _ = self
|
||||
.handle_inserts(requests, ctx)
|
||||
@@ -164,6 +170,12 @@ impl PromStoreProtocolHandler for Instance {
|
||||
request: ReadRequest,
|
||||
ctx: QueryContextRef,
|
||||
) -> ServerResult<PromStoreResponse> {
|
||||
self.plugins
|
||||
.get::<PermissionCheckerRef>()
|
||||
.as_ref()
|
||||
.check_permission(ctx.current_user(), PermissionReq::PromStoreRead)
|
||||
.context(AuthSnafu)?;
|
||||
|
||||
let response_type = negotiate_response_type(&request.accepted_response_types)?;
|
||||
|
||||
// TODO(dennis): use read_hints to speedup query if possible
|
||||
|
||||
@@ -16,10 +16,10 @@ use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use auth::UserProviderRef;
|
||||
use common_base::Plugins;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use common_telemetry::info;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::configurator::ConfiguratorRef;
|
||||
use servers::error::Error::InternalIo;
|
||||
use servers::grpc::GrpcServer;
|
||||
|
||||
@@ -11,51 +11,51 @@ arrow-schema.workspace = true
|
||||
async-recursion = "1.0"
|
||||
async-stream.workspace = true
|
||||
async-trait = "0.1"
|
||||
catalog = { workspace = true }
|
||||
catalog.workspace = true
|
||||
chrono.workspace = true
|
||||
client = { workspace = true }
|
||||
common-base = { workspace = true }
|
||||
common-catalog = { workspace = true }
|
||||
common-datasource = { workspace = true }
|
||||
common-error = { workspace = true }
|
||||
common-function = { workspace = true }
|
||||
common-meta = { workspace = true }
|
||||
common-query = { workspace = true }
|
||||
common-recordbatch = { workspace = true }
|
||||
common-telemetry = { workspace = true }
|
||||
common-time = { workspace = true }
|
||||
client.workspace = true
|
||||
common-base.workspace = true
|
||||
common-catalog.workspace = true
|
||||
common-datasource.workspace = true
|
||||
common-error.workspace = true
|
||||
common-function.workspace = true
|
||||
common-meta.workspace = true
|
||||
common-query.workspace = true
|
||||
common-recordbatch.workspace = true
|
||||
common-telemetry.workspace = true
|
||||
common-time.workspace = true
|
||||
datafusion-common.workspace = true
|
||||
datafusion-expr.workspace = true
|
||||
datafusion-optimizer.workspace = true
|
||||
datafusion-physical-expr.workspace = true
|
||||
datafusion-sql.workspace = true
|
||||
datafusion.workspace = true
|
||||
datatypes = { workspace = true }
|
||||
datatypes.workspace = true
|
||||
futures = "0.3"
|
||||
futures-util.workspace = true
|
||||
greptime-proto.workspace = true
|
||||
humantime = "2.1"
|
||||
metrics.workspace = true
|
||||
object-store = { workspace = true }
|
||||
object-store.workspace = true
|
||||
once_cell.workspace = true
|
||||
partition = { workspace = true }
|
||||
promql = { workspace = true }
|
||||
partition.workspace = true
|
||||
promql-parser = "0.1.1"
|
||||
promql.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json = "1.0"
|
||||
session = { workspace = true }
|
||||
session.workspace = true
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
sql = { workspace = true }
|
||||
substrait = { workspace = true }
|
||||
table = { workspace = true }
|
||||
sql.workspace = true
|
||||
substrait.workspace = true
|
||||
table.workspace = true
|
||||
tokio.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
approx_eq = "0.1"
|
||||
arrow.workspace = true
|
||||
catalog = { workspace = true, features = ["testing"] }
|
||||
common-function-macro = { workspace = true }
|
||||
common-function-macro.workspace = true
|
||||
format_num = "0.1"
|
||||
num = "0.4"
|
||||
num-traits = "0.2"
|
||||
@@ -64,7 +64,7 @@ rand.workspace = true
|
||||
session = { workspace = true, features = ["testing"] }
|
||||
statrs = "0.16"
|
||||
stats-cli = "3.0"
|
||||
store-api = { workspace = true }
|
||||
store-api.workspace = true
|
||||
streaming-stats = "0.2"
|
||||
table = { workspace = true, features = ["testing"] }
|
||||
tokio-stream = "0.1"
|
||||
|
||||
@@ -14,6 +14,7 @@ aide = { version = "0.9", features = ["axum"] }
|
||||
api = { workspace = true }
|
||||
arrow-flight.workspace = true
|
||||
async-trait = "0.1"
|
||||
auth.workspace = true
|
||||
axum = { version = "0.6", features = ["headers"] }
|
||||
axum-macros = "0.3.8"
|
||||
base64 = "0.13"
|
||||
@@ -34,7 +35,6 @@ common-time = { workspace = true }
|
||||
datafusion-common.workspace = true
|
||||
datafusion-expr.workspace = true
|
||||
datafusion.workspace = true
|
||||
|
||||
datatypes = { workspace = true }
|
||||
derive_builder.workspace = true
|
||||
digest = "0.10"
|
||||
@@ -79,7 +79,7 @@ serde.workspace = true
|
||||
serde_json = "1.0"
|
||||
session = { workspace = true }
|
||||
sha1 = "0.10"
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
snafu.workspace = true
|
||||
snap = "1"
|
||||
sql = { workspace = true }
|
||||
strum = { version = "0.24", features = ["derive"] }
|
||||
@@ -96,6 +96,7 @@ tower-http = { version = "0.3", features = ["full"] }
|
||||
tikv-jemalloc-ctl = { version = "0.5", features = ["use_std"] }
|
||||
|
||||
[dev-dependencies]
|
||||
auth = { workspace = true, features = ["testing"] }
|
||||
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
|
||||
catalog = { workspace = true, features = ["testing"] }
|
||||
client = { workspace = true }
|
||||
|
||||
@@ -32,8 +32,6 @@ use tonic::codegen::http::{HeaderMap, HeaderValue};
|
||||
use tonic::metadata::MetadataMap;
|
||||
use tonic::Code;
|
||||
|
||||
use crate::auth;
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
pub enum Error {
|
||||
@@ -115,6 +113,9 @@ pub enum Error {
|
||||
#[snafu(display("Not supported: {}", feat))]
|
||||
NotSupported { feat: String },
|
||||
|
||||
#[snafu(display("Invalid request parameter: {}", reason))]
|
||||
InvalidParameter { reason: String, location: Location },
|
||||
|
||||
#[snafu(display("Invalid query: {}", reason))]
|
||||
InvalidQuery { reason: String, location: Location },
|
||||
|
||||
@@ -359,6 +360,7 @@ impl ErrorExt for Error {
|
||||
| CheckDatabaseValidity { source, .. } => source.status_code(),
|
||||
|
||||
NotSupported { .. }
|
||||
| InvalidParameter { .. }
|
||||
| InvalidQuery { .. }
|
||||
| InfluxdbLineProtocol { .. }
|
||||
| ConnResetByPeer { .. }
|
||||
|
||||
@@ -26,6 +26,7 @@ use api::v1::prometheus_gateway_server::{PrometheusGateway, PrometheusGatewaySer
|
||||
use api::v1::{HealthCheckRequest, HealthCheckResponse};
|
||||
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
|
||||
use async_trait::async_trait;
|
||||
use auth::UserProviderRef;
|
||||
use common_runtime::Runtime;
|
||||
use common_telemetry::logging::info;
|
||||
use common_telemetry::{error, warn};
|
||||
@@ -38,7 +39,6 @@ use tokio_stream::wrappers::TcpListenerStream;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use self::prom_query_gateway::PrometheusGatewayService;
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::error::{
|
||||
AlreadyStartedSnafu, GrpcReflectionServiceSnafu, InternalSnafu, Result, StartGrpcSnafu,
|
||||
TcpBindSnafu,
|
||||
|
||||
@@ -18,6 +18,7 @@ use std::time::Instant;
|
||||
use api::helper::request_type;
|
||||
use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::{Basic, GreptimeRequest, RequestHeader};
|
||||
use auth::{Identity, Password, UserInfoRef, UserProviderRef};
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_catalog::parse_catalog_and_schema_from_db_string;
|
||||
use common_error::ext::ErrorExt;
|
||||
@@ -29,7 +30,6 @@ use metrics::{histogram, increment_counter};
|
||||
use session::context::{QueryContextBuilder, QueryContextRef};
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::error::Error::UnsupportedAuthScheme;
|
||||
use crate::error::{
|
||||
AuthSnafu, InvalidQuerySnafu, JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result as InternalResult,
|
||||
@@ -71,9 +71,12 @@ impl GreptimeRequestHandler {
|
||||
let header = request.header.as_ref();
|
||||
let query_ctx = create_query_context(header);
|
||||
|
||||
if let Err(e) = self.auth(header, &query_ctx).await? {
|
||||
return Ok(Err(e));
|
||||
}
|
||||
match self.auth(header, &query_ctx).await? {
|
||||
Err(e) => return Ok(Err(e)),
|
||||
Ok(user_info) => {
|
||||
query_ctx.set_current_user(user_info);
|
||||
}
|
||||
};
|
||||
|
||||
let handler = self.handler.clone();
|
||||
let request_type = request_type(&query);
|
||||
@@ -110,9 +113,9 @@ impl GreptimeRequestHandler {
|
||||
&self,
|
||||
header: Option<&RequestHeader>,
|
||||
query_ctx: &QueryContextRef,
|
||||
) -> TonicResult<InternalResult<()>> {
|
||||
) -> TonicResult<InternalResult<Option<UserInfoRef>>> {
|
||||
let Some(user_provider) = self.user_provider.as_ref() else {
|
||||
return Ok(Ok(()));
|
||||
return Ok(Ok(None));
|
||||
};
|
||||
|
||||
let auth_scheme = header
|
||||
@@ -138,7 +141,7 @@ impl GreptimeRequestHandler {
|
||||
name: "Token AuthScheme".to_string(),
|
||||
}),
|
||||
}
|
||||
.map(|_| ())
|
||||
.map(Some)
|
||||
.map_err(|e| {
|
||||
increment_counter!(
|
||||
METRIC_AUTH_FAILURE,
|
||||
|
||||
@@ -34,6 +34,7 @@ use std::time::{Duration, Instant};
|
||||
use aide::axum::{routing as apirouting, ApiRouter, IntoApiResponse};
|
||||
use aide::openapi::{Info, OpenApi, Server as OpenAPIServer};
|
||||
use async_trait::async_trait;
|
||||
use auth::UserProviderRef;
|
||||
use axum::body::BoxBody;
|
||||
use axum::error_handling::HandleErrorLayer;
|
||||
use axum::extract::{DefaultBodyLimit, MatchedPath};
|
||||
@@ -65,7 +66,6 @@ use tower_http::trace::TraceLayer;
|
||||
|
||||
use self::authorize::HttpAuth;
|
||||
use self::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::configurator::ConfiguratorRef;
|
||||
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
|
||||
use crate::http::admin::{compact, flush};
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use ::auth::UserProviderRef;
|
||||
use axum::http::{self, Request, StatusCode};
|
||||
use axum::response::Response;
|
||||
use common_catalog::parse_catalog_and_schema_from_db_string;
|
||||
@@ -24,17 +25,14 @@ use headers::Header;
|
||||
use http_body::Body;
|
||||
use metrics::increment_counter;
|
||||
use secrecy::SecretString;
|
||||
use session::context::UserInfo;
|
||||
use snafu::{ensure, IntoError, OptionExt, ResultExt};
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use super::header::GreptimeDbName;
|
||||
use super::PUBLIC_APIS;
|
||||
use crate::auth::Error::IllegalParam;
|
||||
use crate::auth::{Identity, IllegalParamSnafu, UserProviderRef};
|
||||
use crate::error::{
|
||||
self, AuthSnafu, InvalidAuthorizationHeaderSnafu, InvisibleASCIISnafu, NotFoundInfluxAuthSnafu,
|
||||
Result, UnsupportedAuthSchemeSnafu,
|
||||
self, InvalidAuthorizationHeaderSnafu, InvalidParameterSnafu, InvisibleASCIISnafu,
|
||||
NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu,
|
||||
};
|
||||
use crate::http::HTTP_API_PREFIX;
|
||||
|
||||
@@ -78,7 +76,9 @@ where
|
||||
let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
|
||||
user_provider
|
||||
} else {
|
||||
let _ = request.extensions_mut().insert(UserInfo::default());
|
||||
let _ = request
|
||||
.extensions_mut()
|
||||
.insert(auth::userinfo_by_name(None));
|
||||
return Ok(request);
|
||||
};
|
||||
|
||||
@@ -114,8 +114,8 @@ where
|
||||
|
||||
match user_provider
|
||||
.auth(
|
||||
Identity::UserId(username.as_str(), None),
|
||||
crate::auth::Password::PlainText(password),
|
||||
::auth::Identity::UserId(username.as_str(), None),
|
||||
::auth::Password::PlainText(password),
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
@@ -143,7 +143,7 @@ where
|
||||
|
||||
fn extract_catalog_and_schema<B: Send + Sync + 'static>(
|
||||
request: &Request<B>,
|
||||
) -> crate::auth::Result<(&str, &str)> {
|
||||
) -> Result<(&str, &str)> {
|
||||
// parse database from header
|
||||
let dbname = request
|
||||
.headers()
|
||||
@@ -154,8 +154,8 @@ fn extract_catalog_and_schema<B: Send + Sync + 'static>(
|
||||
let query = request.uri().query().unwrap_or_default();
|
||||
extract_db_from_query(query)
|
||||
})
|
||||
.context(IllegalParamSnafu {
|
||||
msg: "db not provided or corrupted",
|
||||
.context(InvalidParameterSnafu {
|
||||
reason: "`db` must be provided in query string",
|
||||
})?;
|
||||
|
||||
Ok(parse_catalog_and_schema_from_db_string(dbname))
|
||||
@@ -193,9 +193,11 @@ fn get_influxdb_credentials<B: Send + Sync + 'static>(
|
||||
(Some(username), Some(password)) => {
|
||||
Ok(Some((username.to_string(), password.to_string().into())))
|
||||
}
|
||||
_ => Err(AuthSnafu.into_error(IllegalParam {
|
||||
msg: "influxdb auth: username and password must be provided together".to_string(),
|
||||
})),
|
||||
_ => InvalidParameterSnafu {
|
||||
reason: "influxdb auth: username and password must be provided together"
|
||||
.to_string(),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::env;
|
||||
use std::time::Instant;
|
||||
|
||||
use aide::transform::TransformOperation;
|
||||
use auth::UserInfoRef;
|
||||
use axum::extract::{Json, Query, State};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::{Extension, Form};
|
||||
@@ -25,7 +26,6 @@ use common_telemetry::timer;
|
||||
use query::parser::PromQuery;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use session::context::UserInfo;
|
||||
|
||||
use crate::http::{ApiState, GreptimeOptionsConfigState, JsonResponse};
|
||||
use crate::metrics_handler::MetricsHandler;
|
||||
@@ -42,7 +42,7 @@ pub async fn sql(
|
||||
State(state): State<ApiState>,
|
||||
Query(query_params): Query<SqlQuery>,
|
||||
// TODO(fys): pass _user_info into query context
|
||||
_user_info: Extension<UserInfo>,
|
||||
user_info: Extension<UserInfoRef>,
|
||||
Form(form_params): Form<SqlQuery>,
|
||||
) -> Json<JsonResponse> {
|
||||
let sql_handler = &state.sql_handler;
|
||||
@@ -61,6 +61,7 @@ pub async fn sql(
|
||||
let resp = if let Some(sql) = &sql {
|
||||
match crate::http::query_context_from_db(sql_handler.clone(), db).await {
|
||||
Ok(query_ctx) => {
|
||||
query_ctx.set_current_user(Some(user_info.0));
|
||||
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
|
||||
}
|
||||
Err(resp) => resp,
|
||||
@@ -101,7 +102,7 @@ pub async fn promql(
|
||||
State(state): State<ApiState>,
|
||||
Query(params): Query<PromqlQuery>,
|
||||
// TODO(fys): pass _user_info into query context
|
||||
_user_info: Extension<UserInfo>,
|
||||
user_info: Extension<UserInfoRef>,
|
||||
) -> Json<JsonResponse> {
|
||||
let sql_handler = &state.sql_handler;
|
||||
let exec_start = Instant::now();
|
||||
@@ -117,6 +118,7 @@ pub async fn promql(
|
||||
let prom_query = params.into();
|
||||
let resp = match super::query_context_from_db(sql_handler.clone(), db).await {
|
||||
Ok(query_ctx) => {
|
||||
query_ctx.set_current_user(Some(user_info.0));
|
||||
JsonResponse::from_output(sql_handler.do_promql_query(&prom_query, query_ctx).await)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use auth::UserInfoRef;
|
||||
use axum::extract::{Query, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Extension;
|
||||
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
|
||||
use common_catalog::parse_catalog_and_schema_from_db_string;
|
||||
use common_grpc::writer::Precision;
|
||||
@@ -43,6 +45,7 @@ pub async fn influxdb_health() -> Result<impl IntoResponse> {
|
||||
pub async fn influxdb_write_v1(
|
||||
State(handler): State<InfluxdbLineProtocolHandlerRef>,
|
||||
Query(mut params): Query<HashMap<String, String>>,
|
||||
user_info: Extension<UserInfoRef>,
|
||||
lines: String,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let db = params
|
||||
@@ -54,13 +57,14 @@ pub async fn influxdb_write_v1(
|
||||
.map(|val| parse_time_precision(val))
|
||||
.transpose()?;
|
||||
|
||||
influxdb_write(&db, precision, lines, handler).await
|
||||
influxdb_write(&db, precision, lines, handler, user_info.0).await
|
||||
}
|
||||
|
||||
#[axum_macros::debug_handler]
|
||||
pub async fn influxdb_write_v2(
|
||||
State(handler): State<InfluxdbLineProtocolHandlerRef>,
|
||||
Query(mut params): Query<HashMap<String, String>>,
|
||||
user_info: Extension<UserInfoRef>,
|
||||
lines: String,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let db = params
|
||||
@@ -72,7 +76,7 @@ pub async fn influxdb_write_v2(
|
||||
.map(|val| parse_time_precision(val))
|
||||
.transpose()?;
|
||||
|
||||
influxdb_write(&db, precision, lines, handler).await
|
||||
influxdb_write(&db, precision, lines, handler, user_info.0).await
|
||||
}
|
||||
|
||||
pub async fn influxdb_write(
|
||||
@@ -80,6 +84,7 @@ pub async fn influxdb_write(
|
||||
precision: Option<Precision>,
|
||||
lines: String,
|
||||
handler: InfluxdbLineProtocolHandlerRef,
|
||||
user_info: UserInfoRef,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let _timer = timer!(
|
||||
crate::metrics::METRIC_HTTP_INFLUXDB_WRITE_ELAPSED,
|
||||
@@ -88,6 +93,7 @@ pub async fn influxdb_write(
|
||||
|
||||
let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
|
||||
let ctx = QueryContext::with(catalog, schema);
|
||||
ctx.set_current_user(Some(user_info));
|
||||
|
||||
let request = InfluxdbRequest { precision, lines };
|
||||
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use auth::UserInfoRef;
|
||||
use axum::extract::{Query, RawBody, State};
|
||||
use axum::http::StatusCode as HttpStatusCode;
|
||||
use axum::Json;
|
||||
use axum::{Extension, Json};
|
||||
use hyper::Body;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use session::context::QueryContext;
|
||||
@@ -77,12 +78,14 @@ pub enum OpentsdbPutResponse {
|
||||
pub async fn put(
|
||||
State(opentsdb_handler): State<OpentsdbProtocolHandlerRef>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
user_info: Extension<UserInfoRef>,
|
||||
RawBody(body): RawBody,
|
||||
) -> Result<(HttpStatusCode, Json<OpentsdbPutResponse>)> {
|
||||
let summary = params.contains_key("summary");
|
||||
let details = params.contains_key("details");
|
||||
|
||||
let ctx = QueryContext::with_db_name(params.get("db"));
|
||||
ctx.set_current_user(Some(user_info.0));
|
||||
|
||||
let data_points = parse_data_points(body).await?;
|
||||
|
||||
|
||||
@@ -12,10 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use auth::UserInfoRef;
|
||||
use axum::extract::{RawBody, State};
|
||||
use axum::http::header;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::TypedHeader;
|
||||
use axum::{Extension, TypedHeader};
|
||||
use common_telemetry::timer;
|
||||
use hyper::Body;
|
||||
use opentelemetry_proto::tonic::collector::metrics::v1::{
|
||||
@@ -33,9 +34,11 @@ use crate::query_handler::OpenTelemetryProtocolHandlerRef;
|
||||
pub async fn metrics(
|
||||
State(handler): State<OpenTelemetryProtocolHandlerRef>,
|
||||
TypedHeader(db): TypedHeader<GreptimeDbName>,
|
||||
user_info: Extension<UserInfoRef>,
|
||||
RawBody(body): RawBody,
|
||||
) -> Result<OtlpResponse> {
|
||||
let ctx = QueryContext::with_db_name(db.value());
|
||||
ctx.set_current_user(Some(user_info.0));
|
||||
let _timer = timer!(
|
||||
crate::metrics::METRIC_HTTP_OPENTELEMETRY_ELAPSED,
|
||||
&[(crate::metrics::METRIC_DB_LABEL, ctx.get_db_string())]
|
||||
|
||||
@@ -20,7 +20,6 @@ use datatypes::schema::Schema;
|
||||
use query::plan::LogicalPlan;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub mod auth;
|
||||
pub mod configurator;
|
||||
pub mod error;
|
||||
pub mod grpc;
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use ::auth::{Identity, Password, UserProviderRef};
|
||||
use async_trait::async_trait;
|
||||
use chrono::{NaiveDate, NaiveDateTime};
|
||||
use common_catalog::parse_catalog_and_schema_from_db_string;
|
||||
@@ -41,7 +42,6 @@ use sql::parser::ParserContext;
|
||||
use sql::statements::statement::Statement;
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::error::{self, InvalidPrepareStatementSnafu, Result};
|
||||
use crate::mysql::helper::{
|
||||
self, format_placeholder, replace_placeholders, transform_placeholders,
|
||||
@@ -197,7 +197,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
}
|
||||
};
|
||||
}
|
||||
let user_info = user_info.unwrap_or_default();
|
||||
let user_info = user_info.unwrap_or_else(|| auth::userinfo_by_name(None));
|
||||
|
||||
self.session.set_user_info(user_info);
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use auth::UserProviderRef;
|
||||
use common_runtime::Runtime;
|
||||
use common_telemetry::logging::{info, warn};
|
||||
use futures::StreamExt;
|
||||
@@ -29,7 +30,6 @@ use tokio::io::BufWriter;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::ServerConfig;
|
||||
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::mysql::handler::MysqlInstanceShim;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
|
||||
@@ -28,6 +28,7 @@ use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use ::auth::UserProviderRef;
|
||||
use derive_builder::Builder;
|
||||
use pgwire::api::auth::ServerParameterProvider;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
@@ -38,7 +39,6 @@ use session::Session;
|
||||
|
||||
use self::auth_handler::PgLoginVerifier;
|
||||
use self::handler::DefaultQueryParser;
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
use crate::SqlPlan;
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Exclusive;
|
||||
|
||||
use ::auth::{userinfo_by_name, Identity, Password, UserInfoRef, UserProviderRef};
|
||||
use async_trait::async_trait;
|
||||
use common_catalog::parse_catalog_and_schema_from_db_string;
|
||||
use common_error::ext::ErrorExt;
|
||||
@@ -26,12 +27,10 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use pgwire::messages::response::ErrorResponse;
|
||||
use pgwire::messages::startup::Authentication;
|
||||
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
|
||||
use session::context::UserInfo;
|
||||
use session::Session;
|
||||
use snafu::IntoError;
|
||||
|
||||
use super::PostgresServerHandler;
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::error::{AuthSnafu, Result};
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
|
||||
@@ -74,26 +73,26 @@ impl LoginInfo {
|
||||
}
|
||||
|
||||
impl PgLoginVerifier {
|
||||
async fn auth(&self, login: &LoginInfo, password: &str) -> Result<bool> {
|
||||
async fn auth(&self, login: &LoginInfo, password: &str) -> Result<Option<UserInfoRef>> {
|
||||
let user_provider = match &self.user_provider {
|
||||
Some(provider) => provider,
|
||||
None => return Ok(false),
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let user_name = match &login.user {
|
||||
Some(name) => name,
|
||||
None => return Ok(false),
|
||||
None => return Ok(None),
|
||||
};
|
||||
let catalog = match &login.catalog {
|
||||
Some(name) => name,
|
||||
None => return Ok(false),
|
||||
None => return Ok(None),
|
||||
};
|
||||
let schema = match &login.schema {
|
||||
Some(name) => name,
|
||||
None => return Ok(false),
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
if let Err(e) = user_provider
|
||||
match user_provider
|
||||
.auth(
|
||||
Identity::UserId(user_name, None),
|
||||
Password::PlainText(password.to_string().into()),
|
||||
@@ -102,16 +101,17 @@ impl PgLoginVerifier {
|
||||
)
|
||||
.await
|
||||
{
|
||||
increment_counter!(
|
||||
crate::metrics::METRIC_AUTH_FAILURE,
|
||||
&[(
|
||||
crate::metrics::METRIC_CODE_LABEL,
|
||||
format!("{}", e.status_code())
|
||||
)]
|
||||
);
|
||||
Err(AuthSnafu.into_error(e))
|
||||
} else {
|
||||
Ok(true)
|
||||
Err(e) => {
|
||||
increment_counter!(
|
||||
crate::metrics::METRIC_AUTH_FAILURE,
|
||||
&[(
|
||||
crate::metrics::METRIC_CODE_LABEL,
|
||||
format!("{}", e.status_code())
|
||||
)]
|
||||
);
|
||||
Err(AuthSnafu.into_error(e))
|
||||
}
|
||||
Ok(user_info) => Ok(Some(user_info)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -126,9 +126,7 @@ where
|
||||
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
|
||||
session.set_schema(current_schema.clone());
|
||||
}
|
||||
if let Some(username) = client.metadata().get(super::METADATA_USER) {
|
||||
session.set_user_info(UserInfo::new(username));
|
||||
}
|
||||
// set userinfo outside
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -174,6 +172,9 @@ impl StartupHandler for PostgresServerHandler {
|
||||
))
|
||||
.await?;
|
||||
} else {
|
||||
self.session.set_user_info(userinfo_by_name(
|
||||
client.metadata().get(super::METADATA_USER).cloned(),
|
||||
));
|
||||
set_client_info(client, &self.session);
|
||||
auth::finish_authentication(client, self.param_provider.as_ref()).await;
|
||||
}
|
||||
@@ -188,7 +189,12 @@ impl StartupHandler for PostgresServerHandler {
|
||||
|
||||
// do authenticate
|
||||
let auth_result = self.login_verifier.auth(&login_info, pwd.password()).await;
|
||||
if !matches!(auth_result, Ok(true)) {
|
||||
|
||||
if let Ok(Some(user_info)) = auth_result {
|
||||
self.session.set_user_info(user_info);
|
||||
set_client_info(client, &self.session);
|
||||
auth::finish_authentication(client, self.param_provider.as_ref()).await;
|
||||
} else {
|
||||
return send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
@@ -197,8 +203,6 @@ impl StartupHandler for PostgresServerHandler {
|
||||
)
|
||||
.await;
|
||||
}
|
||||
set_client_info(client, &self.session);
|
||||
auth::finish_authentication(client, self.param_provider.as_ref()).await;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use ::auth::UserProviderRef;
|
||||
use async_trait::async_trait;
|
||||
use common_runtime::Runtime;
|
||||
use common_telemetry::logging::error;
|
||||
@@ -26,7 +27,6 @@ use pgwire::tokio::process_socket;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use super::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::error::Result;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
use crate::server::{AbortableStream, BaseTcpServer, Server};
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use ::auth::UserProviderRef;
|
||||
use async_trait::async_trait;
|
||||
use axum::body::BoxBody;
|
||||
use axum::extract::{Path, Query, State};
|
||||
@@ -52,7 +53,6 @@ use tower_http::auth::AsyncRequireAuthorizationLayer;
|
||||
use tower_http::compression::CompressionLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::error::{
|
||||
AlreadyStartedSnafu, CollectRecordbatchSnafu, Error, InternalSnafu, InvalidQuerySnafu, Result,
|
||||
StartHttpSnafu, UnexpectedResultSnafu,
|
||||
|
||||
@@ -19,9 +19,10 @@ use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::Basic;
|
||||
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
|
||||
use async_trait::async_trait;
|
||||
use auth::tests::MockUserProvider;
|
||||
use auth::UserProviderRef;
|
||||
use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_runtime::{Builder as RuntimeBuilder, Runtime};
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu};
|
||||
use servers::grpc::flight::FlightHandler;
|
||||
use servers::grpc::handler::GreptimeRequestHandler;
|
||||
@@ -32,7 +33,6 @@ use table::test_util::MemTable;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_stream::wrappers::TcpListenerStream;
|
||||
|
||||
use crate::auth::MockUserProvider;
|
||||
use crate::{create_testing_grpc_query_handler, LOCALHOST_WITH_0};
|
||||
|
||||
struct MockGrpcServer {
|
||||
|
||||
@@ -14,16 +14,14 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use auth::tests::MockUserProvider;
|
||||
use auth::{UserInfoRef, UserProvider};
|
||||
use axum::body::BoxBody;
|
||||
use axum::http;
|
||||
use hyper::Request;
|
||||
use servers::auth::UserProvider;
|
||||
use servers::http::authorize::HttpAuth;
|
||||
use session::context::UserInfo;
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use crate::auth::MockUserProvider;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_auth() {
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth::new(None);
|
||||
@@ -31,8 +29,8 @@ async fn test_http_auth() {
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let auth_res = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
let user_info: &UserInfoRef = auth_res.extensions().get().unwrap();
|
||||
let default = auth::userinfo_by_name(None);
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
@@ -42,8 +40,8 @@ async fn test_http_auth() {
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
|
||||
let req = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = req.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
let user_info: &UserInfoRef = req.extensions().get().unwrap();
|
||||
let default = auth::userinfo_by_name(None);
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
@@ -72,8 +70,8 @@ async fn test_schema_validating() {
|
||||
)
|
||||
.unwrap();
|
||||
let req = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = req.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
let user_info: &UserInfoRef = req.extensions().get().unwrap();
|
||||
let default = auth::userinfo_by_name(None);
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
// wrong database
|
||||
|
||||
@@ -26,7 +26,6 @@ use servers::http::{
|
||||
JsonOutput,
|
||||
};
|
||||
use servers::metrics_handler::MetricsHandler;
|
||||
use session::context::UserInfo;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::{
|
||||
@@ -43,7 +42,7 @@ async fn test_sql_not_provided() {
|
||||
script_handler: None,
|
||||
}),
|
||||
Query(http_handler::SqlQuery::default()),
|
||||
axum::Extension(UserInfo::default()),
|
||||
axum::Extension(auth::userinfo_by_name(None)),
|
||||
Form(http_handler::SqlQuery::default()),
|
||||
)
|
||||
.await;
|
||||
@@ -68,7 +67,7 @@ async fn test_sql_output_rows() {
|
||||
script_handler: None,
|
||||
}),
|
||||
query,
|
||||
axum::Extension(UserInfo::default()),
|
||||
axum::Extension(auth::userinfo_by_name(None)),
|
||||
Form(http_handler::SqlQuery::default()),
|
||||
)
|
||||
.await;
|
||||
@@ -114,7 +113,7 @@ async fn test_sql_form() {
|
||||
script_handler: None,
|
||||
}),
|
||||
Query(http_handler::SqlQuery::default()),
|
||||
axum::Extension(UserInfo::default()),
|
||||
axum::Extension(auth::userinfo_by_name(None)),
|
||||
form,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::sync::Arc;
|
||||
use api::v1::greptime_request::Request;
|
||||
use api::v1::InsertRequests;
|
||||
use async_trait::async_trait;
|
||||
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
|
||||
use axum::{http, Router};
|
||||
use axum_test_helper::TestClient;
|
||||
use common_query::Output;
|
||||
@@ -33,8 +34,6 @@ use servers::query_handler::InfluxdbLineProtocolHandler;
|
||||
use session::context::QueryContextRef;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
|
||||
|
||||
struct DummyInstance {
|
||||
tx: Arc<mpsc::Sender<(String, String)>>,
|
||||
}
|
||||
|
||||
@@ -36,7 +36,6 @@ use snafu::ensure;
|
||||
use sql::statements::statement::Statement;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
mod auth;
|
||||
mod grpc;
|
||||
mod http;
|
||||
mod interceptor;
|
||||
|
||||
@@ -16,6 +16,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
|
||||
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
|
||||
use common_recordbatch::RecordBatch;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
@@ -32,7 +33,6 @@ use servers::server::Server;
|
||||
use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
|
||||
use crate::create_testing_sql_query_handler;
|
||||
use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData};
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
|
||||
use auth::UserProviderRef;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use pgwire::api::Type;
|
||||
@@ -23,7 +25,6 @@ use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::{Certificate, Error, ServerName};
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::error::Result;
|
||||
use servers::postgres::PostgresServer;
|
||||
use servers::server::Server;
|
||||
@@ -31,7 +32,6 @@ use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
|
||||
|
||||
use crate::auth::{DatabaseAuthInfo, MockUserProvider};
|
||||
use crate::create_testing_instance;
|
||||
|
||||
fn create_postgres_server(
|
||||
|
||||
@@ -9,6 +9,7 @@ testing = []
|
||||
|
||||
[dependencies]
|
||||
arc-swap = "1.5"
|
||||
auth.workspace = true
|
||||
common-catalog = { workspace = true }
|
||||
common-telemetry = { workspace = true }
|
||||
common-time = { workspace = true }
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use auth::UserInfoRef;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
|
||||
use common_time::TimeZone;
|
||||
@@ -32,6 +33,7 @@ pub type ConnInfoRef = Arc<ConnInfo>;
|
||||
pub struct QueryContext {
|
||||
current_catalog: String,
|
||||
current_schema: String,
|
||||
current_user: ArcSwap<Option<UserInfoRef>>,
|
||||
time_zone: ArcSwap<Option<TimeZone>>,
|
||||
sql_dialect: Box<dyn Dialect + Send + Sync>,
|
||||
trace_id: u64,
|
||||
@@ -109,6 +111,16 @@ impl QueryContext {
|
||||
let _ = self.time_zone.swap(Arc::new(tz));
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn current_user(&self) -> Option<UserInfoRef> {
|
||||
self.current_user.load().as_ref().clone()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn set_current_user(&self, user: Option<UserInfoRef>) {
|
||||
let _ = self.current_user.swap(Arc::new(user));
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn trace_id(&self) -> u64 {
|
||||
self.trace_id
|
||||
@@ -124,6 +136,9 @@ impl QueryContextBuilder {
|
||||
current_schema: self
|
||||
.current_schema
|
||||
.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()),
|
||||
current_user: self
|
||||
.current_user
|
||||
.unwrap_or_else(|| ArcSwap::new(Arc::new(None))),
|
||||
time_zone: self
|
||||
.time_zone
|
||||
.unwrap_or_else(|| ArcSwap::new(Arc::new(None))),
|
||||
@@ -140,33 +155,6 @@ impl QueryContextBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub const DEFAULT_USERNAME: &str = "greptime";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UserInfo {
|
||||
username: String,
|
||||
}
|
||||
|
||||
impl Default for UserInfo {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
username: DEFAULT_USERNAME.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserInfo {
|
||||
pub fn username(&self) -> &str {
|
||||
self.username.as_str()
|
||||
}
|
||||
|
||||
pub fn new(username: impl Into<String>) -> Self {
|
||||
Self {
|
||||
username: username.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnInfo {
|
||||
pub client_addr: Option<SocketAddr>,
|
||||
@@ -225,7 +213,7 @@ mod test {
|
||||
use common_catalog::consts::DEFAULT_CATALOG_NAME;
|
||||
|
||||
use super::*;
|
||||
use crate::context::{Channel, UserInfo};
|
||||
use crate::context::Channel;
|
||||
use crate::Session;
|
||||
|
||||
#[test]
|
||||
@@ -233,8 +221,6 @@ mod test {
|
||||
let session = Session::new(Some("127.0.0.1:9000".parse().unwrap()), Channel::Mysql);
|
||||
// test user_info
|
||||
assert_eq!(session.user_info().username(), "greptime");
|
||||
session.set_user_info(UserInfo::new("root"));
|
||||
assert_eq!(session.user_info().username(), "root");
|
||||
|
||||
// test channel
|
||||
assert_eq!(session.conn_info().channel, Channel::Mysql);
|
||||
|
||||
@@ -18,18 +18,19 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use auth::UserInfoRef;
|
||||
use common_catalog::build_db_string;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use context::QueryContextBuilder;
|
||||
|
||||
use crate::context::{Channel, ConnInfo, QueryContextRef, UserInfo};
|
||||
use crate::context::{Channel, ConnInfo, QueryContextRef};
|
||||
|
||||
/// Session for persistent connection such as MySQL, PostgreSQL etc.
|
||||
#[derive(Debug)]
|
||||
pub struct Session {
|
||||
catalog: ArcSwap<String>,
|
||||
schema: ArcSwap<String>,
|
||||
user_info: ArcSwap<UserInfo>,
|
||||
user_info: ArcSwap<UserInfoRef>,
|
||||
conn_info: ConnInfo,
|
||||
}
|
||||
|
||||
@@ -40,7 +41,7 @@ impl Session {
|
||||
Session {
|
||||
catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.into())),
|
||||
schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())),
|
||||
user_info: ArcSwap::new(Arc::new(UserInfo::default())),
|
||||
user_info: ArcSwap::new(Arc::new(auth::userinfo_by_name(None))),
|
||||
conn_info: ConnInfo::new(addr, channel),
|
||||
}
|
||||
}
|
||||
@@ -48,6 +49,9 @@ impl Session {
|
||||
#[inline]
|
||||
pub fn new_query_context(&self) -> QueryContextRef {
|
||||
QueryContextBuilder::default()
|
||||
.current_user(ArcSwap::new(Arc::new(Some(
|
||||
self.user_info.load().as_ref().clone(),
|
||||
))))
|
||||
.current_catalog(self.catalog.load().to_string())
|
||||
.current_schema(self.schema.load().to_string())
|
||||
.sql_dialect(self.conn_info.channel.dialect())
|
||||
@@ -65,12 +69,12 @@ impl Session {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn user_info(&self) -> Arc<UserInfo> {
|
||||
self.user_info.load().clone()
|
||||
pub fn user_info(&self) -> UserInfoRef {
|
||||
self.user_info.load().clone().as_ref().clone()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn set_user_info(&self, user_info: UserInfo) {
|
||||
pub fn set_user_info(&self, user_info: UserInfoRef) {
|
||||
self.user_info.store(Arc::new(user_info));
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ dashboard = []
|
||||
[dependencies]
|
||||
api = { workspace = true }
|
||||
async-trait = "0.1"
|
||||
auth.workspace = true
|
||||
axum = "0.6"
|
||||
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
|
||||
catalog = { workspace = true }
|
||||
|
||||
@@ -19,6 +19,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use auth::UserProviderRef;
|
||||
use axum::Router;
|
||||
use catalog::{CatalogManagerRef, RegisterTableRequest};
|
||||
use common_catalog::consts::{
|
||||
@@ -48,7 +49,6 @@ use object_store::services::{Azblob, Gcs, Oss, S3};
|
||||
use object_store::test_util::TempFolder;
|
||||
use object_store::ObjectStore;
|
||||
use secrecy::ExposeSecret;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::grpc::GrpcServer;
|
||||
use servers::http::{HttpOptions, HttpServerBuilder};
|
||||
use servers::metrics_handler::MetricsHandler;
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::alter_expr::Kind;
|
||||
use api::v1::promql_request::Promql;
|
||||
use api::v1::{
|
||||
@@ -21,10 +19,10 @@ use api::v1::{
|
||||
CreateTableExpr, InsertRequest, InsertRequests, PromInstantQuery, PromRangeQuery,
|
||||
PromqlRequest, RequestHeader, SemanticType, TableId,
|
||||
};
|
||||
use auth::user_provider_from_option;
|
||||
use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_catalog::consts::{MIN_USER_TABLE_ID, MITO_ENGINE};
|
||||
use common_query::Output;
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::prometheus::{PromData, PromSeries, PrometheusJsonResponse, PrometheusResponse};
|
||||
use servers::server::Server;
|
||||
use tests_integration::test_util::{
|
||||
@@ -118,14 +116,13 @@ pub async fn test_dbname(store_type: StorageType) {
|
||||
}
|
||||
|
||||
pub async fn test_grpc_auth(store_type: StorageType) {
|
||||
let user_provider = StaticUserProvider::try_from("cmd:greptime_user=greptime_pwd").unwrap();
|
||||
|
||||
let (addr, mut guard, fe_grpc_server) = setup_grpc_server_with_user_provider(
|
||||
store_type,
|
||||
"auto_create_table",
|
||||
Some(Arc::new(user_provider)),
|
||||
let user_provider = user_provider_from_option(
|
||||
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
|
||||
)
|
||||
.await;
|
||||
.unwrap();
|
||||
let (addr, mut guard, fe_grpc_server) =
|
||||
setup_grpc_server_with_user_provider(store_type, "auto_create_table", Some(user_provider))
|
||||
.await;
|
||||
|
||||
let grpc_client = Client::with_urls(vec![addr]);
|
||||
let mut db = Database::new_with_dbname(
|
||||
|
||||
@@ -13,13 +13,12 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use auth::user_provider_from_option;
|
||||
use axum::http::StatusCode;
|
||||
use axum_test_helper::TestClient;
|
||||
use common_error::status_code::StatusCode as ErrorCode;
|
||||
use serde_json::json;
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use servers::http::handler::HealthResponse;
|
||||
use servers::http::{JsonOutput, JsonResponse};
|
||||
use servers::prometheus::{PrometheusJsonResponse, PrometheusResponse};
|
||||
@@ -76,12 +75,15 @@ macro_rules! http_tests {
|
||||
pub async fn test_http_auth(store_type: StorageType) {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let user_provider = StaticUserProvider::try_from("cmd:greptime_user=greptime_pwd").unwrap();
|
||||
let user_provider = user_provider_from_option(
|
||||
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (app, mut guard) = setup_test_http_app_with_frontend_and_user_provider(
|
||||
store_type,
|
||||
"sql_api",
|
||||
Some(Arc::new(user_provider)),
|
||||
Some(user_provider),
|
||||
)
|
||||
.await;
|
||||
let client = TestClient::new(app);
|
||||
|
||||
@@ -11,10 +11,9 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
use std::sync::Arc;
|
||||
|
||||
use auth::user_provider_from_option;
|
||||
use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
|
||||
use servers::auth::user_provider::StaticUserProvider;
|
||||
use sqlx::mysql::{MySqlDatabaseError, MySqlPoolOptions};
|
||||
use sqlx::postgres::{PgDatabaseError, PgPoolOptions};
|
||||
use sqlx::Row;
|
||||
@@ -65,13 +64,13 @@ macro_rules! sql_tests {
|
||||
}
|
||||
|
||||
pub async fn test_mysql_auth(store_type: StorageType) {
|
||||
let user_provider = StaticUserProvider::try_from("cmd:greptime_user=greptime_pwd").unwrap();
|
||||
let (addr, mut guard, fe_mysql_server) = setup_mysql_server_with_user_provider(
|
||||
store_type,
|
||||
"sql_crud",
|
||||
Some(Arc::new(user_provider)),
|
||||
let user_provider = user_provider_from_option(
|
||||
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
|
||||
)
|
||||
.await;
|
||||
.unwrap();
|
||||
|
||||
let (addr, mut guard, fe_mysql_server) =
|
||||
setup_mysql_server_with_user_provider(store_type, "sql_crud", Some(user_provider)).await;
|
||||
|
||||
// 1. no auth
|
||||
let conn_re = MySqlPoolOptions::new()
|
||||
@@ -204,10 +203,13 @@ pub async fn test_mysql_crud(store_type: StorageType) {
|
||||
}
|
||||
|
||||
pub async fn test_postgres_auth(store_type: StorageType) {
|
||||
let user_provider = StaticUserProvider::try_from("cmd:greptime_user=greptime_pwd").unwrap();
|
||||
let user_provider = user_provider_from_option(
|
||||
&"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (addr, mut guard, fe_pg_server) =
|
||||
setup_pg_server_with_user_provider(store_type, "sql_crud", Some(Arc::new(user_provider)))
|
||||
.await;
|
||||
setup_pg_server_with_user_provider(store_type, "sql_crud", Some(user_provider)).await;
|
||||
|
||||
// 1. no auth
|
||||
let conn_re = PgPoolOptions::new()
|
||||
|
||||
Reference in New Issue
Block a user