refactor: merge servers::context into session (#811)

* refactor: move context to session

* chore: add unit test

* chore: add pg, opentsdb, influxdb and prometheus to channel enum
This commit is contained in:
shuiyisong
2022-12-31 00:00:04 +08:00
committed by GitHub
parent 4d56d896ca
commit 179ff728df
11 changed files with 117 additions and 227 deletions

View File

@@ -14,13 +14,12 @@
pub mod user_provider;
pub const DEFAULT_USERNAME: &str = "greptime";
use std::sync::Arc;
use common_error::ext::BoxedError;
use common_error::prelude::ErrorExt;
use common_error::status_code::StatusCode;
use session::context::UserInfo;
use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
use crate::auth::user_provider::StaticUserProvider;
@@ -52,31 +51,6 @@ pub enum Password<'a> {
PgMD5(HashedPassword<'a>, Salt<'a>),
}
#[derive(Clone, Debug)]
pub struct UserInfo {
username: String,
}
impl Default for UserInfo {
fn default() -> Self {
Self {
username: DEFAULT_USERNAME.to_string(),
}
}
}
impl UserInfo {
pub fn user_name(&self) -> &str {
&self.username
}
pub fn new(username: impl Into<String>) -> Self {
Self {
username: username.into(),
}
}
}
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef> {
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
value: opt.to_string(),
@@ -213,7 +187,7 @@ mod tests {
)
.await;
assert!(auth_result.is_ok());
assert_eq!("greptime", auth_result.unwrap().user_name());
assert_eq!("greptime", auth_result.unwrap().username());
// auth failed, unsupported password type
let auth_result = user_provider

View File

@@ -21,13 +21,13 @@ use std::path::Path;
use async_trait::async_trait;
use digest;
use digest::Digest;
use session::context::UserInfo;
use sha1::Sha1;
use snafu::{ensure, OptionExt, ResultExt};
use crate::auth::{
Error, HashedPassword, Identity, InvalidConfigSnafu, IoSnafu, Password, Result, Salt,
UnsupportedPasswordTypeSnafu, UserInfo, UserNotFoundSnafu, UserPasswordMismatchSnafu,
UserProvider,
UnsupportedPasswordTypeSnafu, UserNotFoundSnafu, UserPasswordMismatchSnafu, UserProvider,
};
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";

View File

@@ -1,153 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use snafu::OptionExt;
use crate::auth::UserInfo;
use crate::error::{BuildingContextSnafu, Result};
type CtxFnRef = Arc<dyn Fn(&Context) -> bool + Send + Sync>;
pub struct Context {
pub client_info: ClientInfo,
pub user_info: UserInfo,
pub quota: Quota,
pub predicates: Vec<CtxFnRef>,
}
impl Context {
pub fn add_predicate(&mut self, predicate: CtxFnRef) {
self.predicates.push(predicate);
}
}
#[derive(Default)]
pub struct CtxBuilder {
client_addr: Option<String>,
from_channel: Option<Channel>,
user_info: Option<UserInfo>,
}
impl CtxBuilder {
pub fn new() -> CtxBuilder {
CtxBuilder::default()
}
pub fn client_addr(mut self, addr: String) -> CtxBuilder {
self.client_addr = Some(addr);
self
}
pub fn set_channel(mut self, channel: Channel) -> CtxBuilder {
self.from_channel = Some(channel);
self
}
pub fn set_user_info(mut self, user_info: UserInfo) -> CtxBuilder {
self.user_info = Some(user_info);
self
}
pub fn build(self) -> Result<Context> {
Ok(Context {
client_info: ClientInfo {
client_host: self.client_addr.context(BuildingContextSnafu {
err_msg: "unknown client addr while building ctx",
})?,
channel: self.from_channel.context(BuildingContextSnafu {
err_msg: "unknown channel while building ctx",
})?,
},
user_info: self.user_info.context(BuildingContextSnafu {
err_msg: "missing user info while building ctx",
})?,
quota: Quota::default(),
predicates: vec![],
})
}
}
pub struct ClientInfo {
pub client_host: String,
pub channel: Channel,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Channel {
Grpc,
Http,
Mysql,
}
#[derive(Default)]
pub struct Quota {
pub total: u64,
pub consumed: u64,
pub estimated: u64,
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use crate::auth::UserInfo;
use crate::context::Channel::{self, Http};
use crate::context::{ClientInfo, Context, CtxBuilder};
#[test]
fn test_predicate() {
let mut ctx = Context {
client_info: ClientInfo {
client_host: Default::default(),
channel: Channel::Grpc,
},
user_info: UserInfo::new("greptime"),
quota: Default::default(),
predicates: vec![],
};
ctx.add_predicate(Arc::new(|ctx: &Context| {
ctx.quota.total > ctx.quota.consumed
}));
ctx.quota.total = 10;
ctx.quota.consumed = 5;
let predicates = ctx.predicates.clone();
let mut re = true;
for predicate in predicates {
re &= predicate(&ctx);
}
assert!(re);
}
#[test]
fn test_build() {
let ctx = CtxBuilder::new()
.client_addr("127.0.0.1:4001".to_string())
.set_channel(Http)
.set_user_info(UserInfo::new("greptime"))
.build()
.unwrap();
assert_eq!(ctx.client_info.client_host, String::from("127.0.0.1:4001"));
assert_eq!(ctx.quota.total, 0);
assert_eq!(ctx.quota.consumed, 0);
assert_eq!(ctx.quota.estimated, 0);
assert_eq!(ctx.predicates.capacity(), 0);
}
}

View File

@@ -19,10 +19,11 @@ use axum::response::Response;
use common_telemetry::error;
use futures::future::BoxFuture;
use http_body::Body;
use session::context::UserInfo;
use snafu::{OptionExt, ResultExt};
use tower_http::auth::AsyncAuthorizeRequest;
use crate::auth::{Identity, UserInfo, UserProviderRef};
use crate::auth::{Identity, UserProviderRef};
use crate::error::{self, Result};
pub struct HttpAuth<RespBody> {
@@ -174,11 +175,12 @@ mod tests {
use axum::body::BoxBody;
use axum::http;
use hyper::Request;
use session::context::UserInfo;
use tower_http::auth::AsyncAuthorizeRequest;
use super::{auth_header, decode_basic, AuthScheme, HttpAuth};
use crate::auth::test::MockUserProvider;
use crate::auth::{UserInfo, UserProvider};
use crate::auth::UserProvider;
use crate::error;
use crate::error::Result;
@@ -194,7 +196,7 @@ mod tests {
let auth_res = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.user_name(), user_info.user_name());
assert_eq!(default.username(), user_info.username());
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
@@ -208,7 +210,7 @@ mod tests {
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.user_name(), user_info.user_name());
assert_eq!(default.username(), user_info.username());
let req = mock_http_request_no_auth().unwrap();
let auth_res = http_auth.authorize(req).await;

View File

@@ -24,9 +24,8 @@ use common_error::status_code::StatusCode;
use common_telemetry::metric;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::QueryContext;
use session::context::{QueryContext, UserInfo};
use crate::auth::UserInfo;
use crate::http::{ApiState, JsonResponse};
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema)]

View File

@@ -17,7 +17,6 @@
use serde::{Deserialize, Serialize};
pub mod auth;
pub mod context;
pub mod error;
pub mod grpc;
pub mod http;

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
@@ -22,13 +23,11 @@ use opensrv_mysql::{
AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter,
};
use rand::RngCore;
use session::context::Channel;
use session::Session;
use tokio::io::AsyncWrite;
use tokio::sync::RwLock;
use crate::auth::{Identity, Password, UserProviderRef};
use crate::context::Channel::Mysql;
use crate::context::{Context, CtxBuilder};
use crate::error::{self, Result};
use crate::mysql::writer::MysqlResultWriter;
use crate::query_handler::SqlQueryHandlerRef;
@@ -37,9 +36,6 @@ use crate::query_handler::SqlQueryHandlerRef;
pub struct MysqlInstanceShim {
query_handler: SqlQueryHandlerRef,
salt: [u8; 20],
client_addr: String,
// TODO(LFC): Break `Context` struct into different fields in `Session`, each with its own purpose.
ctx: Arc<RwLock<Option<Context>>>,
session: Arc<Session>,
user_provider: Option<UserProviderRef>,
}
@@ -47,7 +43,7 @@ pub struct MysqlInstanceShim {
impl MysqlInstanceShim {
pub fn create(
query_handler: SqlQueryHandlerRef,
client_addr: String,
client_addr: SocketAddr,
user_provider: Option<UserProviderRef>,
) -> MysqlInstanceShim {
// init a random salt
@@ -66,9 +62,7 @@ impl MysqlInstanceShim {
MysqlInstanceShim {
query_handler,
salt: scramble,
client_addr,
ctx: Arc::new(RwLock::new(None)),
session: Arc::new(Session::new()),
session: Arc::new(Session::new(client_addr, Channel::Mysql)),
user_provider,
}
}
@@ -115,11 +109,11 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
) -> bool {
// if not specified then **greptime** will be used
let username = String::from_utf8_lossy(username);
let client_addr = self.client_addr.clone();
let mut user_info = None;
let addr = self.session.conn_info().client_host.to_string();
if let Some(user_provider) = &self.user_provider {
let user_id = Identity::UserId(&username, Some(&client_addr));
let user_id = Identity::UserId(&username, Some(addr.as_str()));
let password = match auth_plugin {
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
@@ -140,22 +134,9 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
}
let user_info = user_info.unwrap_or_default();
return match CtxBuilder::new()
.client_addr(client_addr)
.set_channel(Mysql)
.set_user_info(user_info)
.build()
{
Ok(ctx) => {
let mut a = self.ctx.write().await;
*a = Some(ctx);
true
}
Err(e) => {
error!(e; "create ctx failed when authing mysql conn");
false
}
};
self.session.set_user_info(user_info);
true
}
async fn on_prepare<'a>(&'a mut self, _: &'a str, w: StatementMetaWriter<'a, W>) -> Result<()> {

View File

@@ -127,11 +127,7 @@ impl MysqlServer {
force_tls: bool,
user_provider: Option<UserProviderRef>,
) -> Result<()> {
let mut shim = MysqlInstanceShim::create(
query_handler,
stream.peer_addr()?.to_string(),
user_provider,
);
let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?, user_provider);
let (mut r, w) = stream.into_split();
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
let ops = IntermediaryOptions::default();

View File

@@ -18,8 +18,8 @@ use axum::body::Body;
use axum::extract::{Json, Query, RawBody, State};
use common_telemetry::metric;
use metrics::counter;
use servers::auth::UserInfo;
use servers::http::{handler as http_handler, script as script_handler, ApiState, JsonOutput};
use session::context::UserInfo;
use table::test_util::MemTable;
use crate::{create_testing_script_handler, create_testing_sql_query_handler};

View File

@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_telemetry::info;
pub type QueryContextRef = Arc<QueryContext>;
pub type ConnInfoRef = Arc<ConnInfo>;
pub struct QueryContext {
current_schema: ArcSwapOption<String>,
@@ -58,3 +60,78 @@ impl QueryContext {
)
}
}
pub const DEFAULT_USERNAME: &str = "greptime";
#[derive(Clone, Debug)]
pub struct UserInfo {
username: String,
}
impl Default for UserInfo {
fn default() -> Self {
Self {
username: DEFAULT_USERNAME.to_string(),
}
}
}
impl UserInfo {
pub fn username(&self) -> &str {
self.username.as_str()
}
pub fn new(username: impl Into<String>) -> Self {
Self {
username: username.into(),
}
}
}
pub struct ConnInfo {
pub client_host: SocketAddr,
pub channel: Channel,
}
impl ConnInfo {
pub fn new(client_host: SocketAddr, channel: Channel) -> Self {
Self {
client_host,
channel,
}
}
}
#[derive(Debug, PartialEq)]
pub enum Channel {
Grpc,
Http,
Mysql,
Postgres,
Opentsdb,
Influxdb,
Prometheus,
}
#[cfg(test)]
mod test {
use crate::context::{Channel, UserInfo};
use crate::Session;
#[test]
fn test_session() {
let session = Session::new("127.0.0.1:9000".parse().unwrap(), Channel::Mysql);
// test user_info
assert_eq!(session.user_info().username(), "greptime");
session.set_user_info(UserInfo::new("root"));
assert_eq!(session.user_info().username(), "root");
// test channel
assert_eq!(session.conn_info().channel, Channel::Mysql);
assert_eq!(
session.conn_info().client_host.ip().to_string(),
"127.0.0.1"
);
assert_eq!(session.conn_info().client_host.port(), 9000);
}
}

View File

@@ -14,23 +14,38 @@
pub mod context;
use std::net::SocketAddr;
use std::sync::Arc;
use crate::context::{QueryContext, QueryContextRef};
use arc_swap::ArcSwap;
use crate::context::{Channel, ConnInfo, ConnInfoRef, QueryContext, QueryContextRef, UserInfo};
#[derive(Default)]
pub struct Session {
query_ctx: QueryContextRef,
user_info: ArcSwap<UserInfo>,
conn_info: ConnInfoRef,
}
impl Session {
pub fn new() -> Self {
pub fn new(addr: SocketAddr, channel: Channel) -> Self {
Session {
query_ctx: Arc::new(QueryContext::new()),
user_info: ArcSwap::new(Arc::new(UserInfo::default())),
conn_info: Arc::new(ConnInfo::new(addr, channel)),
}
}
pub fn context(&self) -> QueryContextRef {
self.query_ctx.clone()
}
pub fn conn_info(&self) -> ConnInfoRef {
self.conn_info.clone()
}
pub fn user_info(&self) -> Arc<UserInfo> {
self.user_info.load().clone()
}
pub fn set_user_info(&self, user_info: UserInfo) {
self.user_info.store(Arc::new(user_info));
}
}