mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-17 02:32:56 +00:00
refactor: merge servers::context into session (#811)
* refactor: move context to session * chore: add unit test * chore: add pg, opentsdb, influxdb and prometheus to channel enum
This commit is contained in:
@@ -14,13 +14,12 @@
|
||||
|
||||
pub mod user_provider;
|
||||
|
||||
pub const DEFAULT_USERNAME: &str = "greptime";
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_error::ext::BoxedError;
|
||||
use common_error::prelude::ErrorExt;
|
||||
use common_error::status_code::StatusCode;
|
||||
use session::context::UserInfo;
|
||||
use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
|
||||
|
||||
use crate::auth::user_provider::StaticUserProvider;
|
||||
@@ -52,31 +51,6 @@ pub enum Password<'a> {
|
||||
PgMD5(HashedPassword<'a>, Salt<'a>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UserInfo {
|
||||
username: String,
|
||||
}
|
||||
|
||||
impl Default for UserInfo {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
username: DEFAULT_USERNAME.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserInfo {
|
||||
pub fn user_name(&self) -> &str {
|
||||
&self.username
|
||||
}
|
||||
|
||||
pub fn new(username: impl Into<String>) -> Self {
|
||||
Self {
|
||||
username: username.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
|
||||
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
|
||||
value: opt.to_string(),
|
||||
@@ -213,7 +187,7 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_ok());
|
||||
assert_eq!("greptime", auth_result.unwrap().user_name());
|
||||
assert_eq!("greptime", auth_result.unwrap().username());
|
||||
|
||||
// auth failed, unsupported password type
|
||||
let auth_result = user_provider
|
||||
|
||||
@@ -21,13 +21,13 @@ use std::path::Path;
|
||||
use async_trait::async_trait;
|
||||
use digest;
|
||||
use digest::Digest;
|
||||
use session::context::UserInfo;
|
||||
use sha1::Sha1;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
use crate::auth::{
|
||||
Error, HashedPassword, Identity, InvalidConfigSnafu, IoSnafu, Password, Result, Salt,
|
||||
UnsupportedPasswordTypeSnafu, UserInfo, UserNotFoundSnafu, UserPasswordMismatchSnafu,
|
||||
UserProvider,
|
||||
UnsupportedPasswordTypeSnafu, UserNotFoundSnafu, UserPasswordMismatchSnafu, UserProvider,
|
||||
};
|
||||
|
||||
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
// Copyright 2022 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use snafu::OptionExt;
|
||||
|
||||
use crate::auth::UserInfo;
|
||||
use crate::error::{BuildingContextSnafu, Result};
|
||||
|
||||
type CtxFnRef = Arc<dyn Fn(&Context) -> bool + Send + Sync>;
|
||||
|
||||
pub struct Context {
|
||||
pub client_info: ClientInfo,
|
||||
pub user_info: UserInfo,
|
||||
pub quota: Quota,
|
||||
pub predicates: Vec<CtxFnRef>,
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub fn add_predicate(&mut self, predicate: CtxFnRef) {
|
||||
self.predicates.push(predicate);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CtxBuilder {
|
||||
client_addr: Option<String>,
|
||||
from_channel: Option<Channel>,
|
||||
user_info: Option<UserInfo>,
|
||||
}
|
||||
|
||||
impl CtxBuilder {
|
||||
pub fn new() -> CtxBuilder {
|
||||
CtxBuilder::default()
|
||||
}
|
||||
|
||||
pub fn client_addr(mut self, addr: String) -> CtxBuilder {
|
||||
self.client_addr = Some(addr);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_channel(mut self, channel: Channel) -> CtxBuilder {
|
||||
self.from_channel = Some(channel);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_user_info(mut self, user_info: UserInfo) -> CtxBuilder {
|
||||
self.user_info = Some(user_info);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<Context> {
|
||||
Ok(Context {
|
||||
client_info: ClientInfo {
|
||||
client_host: self.client_addr.context(BuildingContextSnafu {
|
||||
err_msg: "unknown client addr while building ctx",
|
||||
})?,
|
||||
channel: self.from_channel.context(BuildingContextSnafu {
|
||||
err_msg: "unknown channel while building ctx",
|
||||
})?,
|
||||
},
|
||||
user_info: self.user_info.context(BuildingContextSnafu {
|
||||
err_msg: "missing user info while building ctx",
|
||||
})?,
|
||||
quota: Quota::default(),
|
||||
predicates: vec![],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ClientInfo {
|
||||
pub client_host: String,
|
||||
pub channel: Channel,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum Channel {
|
||||
Grpc,
|
||||
Http,
|
||||
Mysql,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Quota {
|
||||
pub total: u64,
|
||||
pub consumed: u64,
|
||||
pub estimated: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::auth::UserInfo;
|
||||
use crate::context::Channel::{self, Http};
|
||||
use crate::context::{ClientInfo, Context, CtxBuilder};
|
||||
|
||||
#[test]
|
||||
fn test_predicate() {
|
||||
let mut ctx = Context {
|
||||
client_info: ClientInfo {
|
||||
client_host: Default::default(),
|
||||
channel: Channel::Grpc,
|
||||
},
|
||||
user_info: UserInfo::new("greptime"),
|
||||
quota: Default::default(),
|
||||
predicates: vec![],
|
||||
};
|
||||
ctx.add_predicate(Arc::new(|ctx: &Context| {
|
||||
ctx.quota.total > ctx.quota.consumed
|
||||
}));
|
||||
ctx.quota.total = 10;
|
||||
ctx.quota.consumed = 5;
|
||||
|
||||
let predicates = ctx.predicates.clone();
|
||||
let mut re = true;
|
||||
for predicate in predicates {
|
||||
re &= predicate(&ctx);
|
||||
}
|
||||
assert!(re);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build() {
|
||||
let ctx = CtxBuilder::new()
|
||||
.client_addr("127.0.0.1:4001".to_string())
|
||||
.set_channel(Http)
|
||||
.set_user_info(UserInfo::new("greptime"))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ctx.client_info.client_host, String::from("127.0.0.1:4001"));
|
||||
|
||||
assert_eq!(ctx.quota.total, 0);
|
||||
assert_eq!(ctx.quota.consumed, 0);
|
||||
assert_eq!(ctx.quota.estimated, 0);
|
||||
|
||||
assert_eq!(ctx.predicates.capacity(), 0);
|
||||
}
|
||||
}
|
||||
@@ -19,10 +19,11 @@ use axum::response::Response;
|
||||
use common_telemetry::error;
|
||||
use futures::future::BoxFuture;
|
||||
use http_body::Body;
|
||||
use session::context::UserInfo;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use crate::auth::{Identity, UserInfo, UserProviderRef};
|
||||
use crate::auth::{Identity, UserProviderRef};
|
||||
use crate::error::{self, Result};
|
||||
|
||||
pub struct HttpAuth<RespBody> {
|
||||
@@ -174,11 +175,12 @@ mod tests {
|
||||
use axum::body::BoxBody;
|
||||
use axum::http;
|
||||
use hyper::Request;
|
||||
use session::context::UserInfo;
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use super::{auth_header, decode_basic, AuthScheme, HttpAuth};
|
||||
use crate::auth::test::MockUserProvider;
|
||||
use crate::auth::{UserInfo, UserProvider};
|
||||
use crate::auth::UserProvider;
|
||||
use crate::error;
|
||||
use crate::error::Result;
|
||||
|
||||
@@ -194,7 +196,7 @@ mod tests {
|
||||
let auth_res = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.user_name(), user_info.user_name());
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
|
||||
@@ -208,7 +210,7 @@ mod tests {
|
||||
let req = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = req.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.user_name(), user_info.user_name());
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
let req = mock_http_request_no_auth().unwrap();
|
||||
let auth_res = http_auth.authorize(req).await;
|
||||
|
||||
@@ -24,9 +24,8 @@ use common_error::status_code::StatusCode;
|
||||
use common_telemetry::metric;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use session::context::QueryContext;
|
||||
use session::context::{QueryContext, UserInfo};
|
||||
|
||||
use crate::auth::UserInfo;
|
||||
use crate::http::{ApiState, JsonResponse};
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema)]
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub mod auth;
|
||||
pub mod context;
|
||||
pub mod error;
|
||||
pub mod grpc;
|
||||
pub mod http;
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
@@ -22,13 +23,11 @@ use opensrv_mysql::{
|
||||
AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter,
|
||||
};
|
||||
use rand::RngCore;
|
||||
use session::context::Channel;
|
||||
use session::Session;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::context::Channel::Mysql;
|
||||
use crate::context::{Context, CtxBuilder};
|
||||
use crate::error::{self, Result};
|
||||
use crate::mysql::writer::MysqlResultWriter;
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
@@ -37,9 +36,6 @@ use crate::query_handler::SqlQueryHandlerRef;
|
||||
pub struct MysqlInstanceShim {
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
salt: [u8; 20],
|
||||
client_addr: String,
|
||||
// TODO(LFC): Break `Context` struct into different fields in `Session`, each with its own purpose.
|
||||
ctx: Arc<RwLock<Option<Context>>>,
|
||||
session: Arc<Session>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
}
|
||||
@@ -47,7 +43,7 @@ pub struct MysqlInstanceShim {
|
||||
impl MysqlInstanceShim {
|
||||
pub fn create(
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
client_addr: String,
|
||||
client_addr: SocketAddr,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
) -> MysqlInstanceShim {
|
||||
// init a random salt
|
||||
@@ -66,9 +62,7 @@ impl MysqlInstanceShim {
|
||||
MysqlInstanceShim {
|
||||
query_handler,
|
||||
salt: scramble,
|
||||
client_addr,
|
||||
ctx: Arc::new(RwLock::new(None)),
|
||||
session: Arc::new(Session::new()),
|
||||
session: Arc::new(Session::new(client_addr, Channel::Mysql)),
|
||||
user_provider,
|
||||
}
|
||||
}
|
||||
@@ -115,11 +109,11 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
) -> bool {
|
||||
// if not specified then **greptime** will be used
|
||||
let username = String::from_utf8_lossy(username);
|
||||
let client_addr = self.client_addr.clone();
|
||||
|
||||
let mut user_info = None;
|
||||
let addr = self.session.conn_info().client_host.to_string();
|
||||
if let Some(user_provider) = &self.user_provider {
|
||||
let user_id = Identity::UserId(&username, Some(&client_addr));
|
||||
let user_id = Identity::UserId(&username, Some(addr.as_str()));
|
||||
|
||||
let password = match auth_plugin {
|
||||
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
|
||||
@@ -140,22 +134,9 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
}
|
||||
let user_info = user_info.unwrap_or_default();
|
||||
|
||||
return match CtxBuilder::new()
|
||||
.client_addr(client_addr)
|
||||
.set_channel(Mysql)
|
||||
.set_user_info(user_info)
|
||||
.build()
|
||||
{
|
||||
Ok(ctx) => {
|
||||
let mut a = self.ctx.write().await;
|
||||
*a = Some(ctx);
|
||||
true
|
||||
}
|
||||
Err(e) => {
|
||||
error!(e; "create ctx failed when authing mysql conn");
|
||||
false
|
||||
}
|
||||
};
|
||||
self.session.set_user_info(user_info);
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
async fn on_prepare<'a>(&'a mut self, _: &'a str, w: StatementMetaWriter<'a, W>) -> Result<()> {
|
||||
|
||||
@@ -127,11 +127,7 @@ impl MysqlServer {
|
||||
force_tls: bool,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
) -> Result<()> {
|
||||
let mut shim = MysqlInstanceShim::create(
|
||||
query_handler,
|
||||
stream.peer_addr()?.to_string(),
|
||||
user_provider,
|
||||
);
|
||||
let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?, user_provider);
|
||||
let (mut r, w) = stream.into_split();
|
||||
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
|
||||
let ops = IntermediaryOptions::default();
|
||||
|
||||
@@ -18,8 +18,8 @@ use axum::body::Body;
|
||||
use axum::extract::{Json, Query, RawBody, State};
|
||||
use common_telemetry::metric;
|
||||
use metrics::counter;
|
||||
use servers::auth::UserInfo;
|
||||
use servers::http::{handler as http_handler, script as script_handler, ApiState, JsonOutput};
|
||||
use session::context::UserInfo;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::{create_testing_script_handler, create_testing_sql_query_handler};
|
||||
|
||||
@@ -12,12 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
use common_telemetry::info;
|
||||
|
||||
pub type QueryContextRef = Arc<QueryContext>;
|
||||
pub type ConnInfoRef = Arc<ConnInfo>;
|
||||
|
||||
pub struct QueryContext {
|
||||
current_schema: ArcSwapOption<String>,
|
||||
@@ -58,3 +60,78 @@ impl QueryContext {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub const DEFAULT_USERNAME: &str = "greptime";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UserInfo {
|
||||
username: String,
|
||||
}
|
||||
|
||||
impl Default for UserInfo {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
username: DEFAULT_USERNAME.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserInfo {
|
||||
pub fn username(&self) -> &str {
|
||||
self.username.as_str()
|
||||
}
|
||||
|
||||
pub fn new(username: impl Into<String>) -> Self {
|
||||
Self {
|
||||
username: username.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConnInfo {
|
||||
pub client_host: SocketAddr,
|
||||
pub channel: Channel,
|
||||
}
|
||||
|
||||
impl ConnInfo {
|
||||
pub fn new(client_host: SocketAddr, channel: Channel) -> Self {
|
||||
Self {
|
||||
client_host,
|
||||
channel,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Channel {
|
||||
Grpc,
|
||||
Http,
|
||||
Mysql,
|
||||
Postgres,
|
||||
Opentsdb,
|
||||
Influxdb,
|
||||
Prometheus,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::context::{Channel, UserInfo};
|
||||
use crate::Session;
|
||||
|
||||
#[test]
|
||||
fn test_session() {
|
||||
let session = Session::new("127.0.0.1:9000".parse().unwrap(), Channel::Mysql);
|
||||
// test user_info
|
||||
assert_eq!(session.user_info().username(), "greptime");
|
||||
session.set_user_info(UserInfo::new("root"));
|
||||
assert_eq!(session.user_info().username(), "root");
|
||||
|
||||
// test channel
|
||||
assert_eq!(session.conn_info().channel, Channel::Mysql);
|
||||
assert_eq!(
|
||||
session.conn_info().client_host.ip().to_string(),
|
||||
"127.0.0.1"
|
||||
);
|
||||
assert_eq!(session.conn_info().client_host.port(), 9000);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,23 +14,38 @@
|
||||
|
||||
pub mod context;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::context::{QueryContext, QueryContextRef};
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
use crate::context::{Channel, ConnInfo, ConnInfoRef, QueryContext, QueryContextRef, UserInfo};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Session {
|
||||
query_ctx: QueryContextRef,
|
||||
user_info: ArcSwap<UserInfo>,
|
||||
conn_info: ConnInfoRef,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(addr: SocketAddr, channel: Channel) -> Self {
|
||||
Session {
|
||||
query_ctx: Arc::new(QueryContext::new()),
|
||||
user_info: ArcSwap::new(Arc::new(UserInfo::default())),
|
||||
conn_info: Arc::new(ConnInfo::new(addr, channel)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context(&self) -> QueryContextRef {
|
||||
self.query_ctx.clone()
|
||||
}
|
||||
pub fn conn_info(&self) -> ConnInfoRef {
|
||||
self.conn_info.clone()
|
||||
}
|
||||
pub fn user_info(&self) -> Arc<UserInfo> {
|
||||
self.user_info.load().clone()
|
||||
}
|
||||
pub fn set_user_info(&self, user_info: UserInfo) {
|
||||
self.user_info.store(Arc::new(user_info));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user