diff --git a/src/servers/src/auth.rs b/src/servers/src/auth.rs index 7ab9994f7c..0ad3cd0213 100644 --- a/src/servers/src/auth.rs +++ b/src/servers/src/auth.rs @@ -14,13 +14,12 @@ pub mod user_provider; -pub const DEFAULT_USERNAME: &str = "greptime"; - use std::sync::Arc; use common_error::ext::BoxedError; use common_error::prelude::ErrorExt; use common_error::status_code::StatusCode; +use session::context::UserInfo; use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu}; use crate::auth::user_provider::StaticUserProvider; @@ -52,31 +51,6 @@ pub enum Password<'a> { PgMD5(HashedPassword<'a>, Salt<'a>), } -#[derive(Clone, Debug)] -pub struct UserInfo { - username: String, -} - -impl Default for UserInfo { - fn default() -> Self { - Self { - username: DEFAULT_USERNAME.to_string(), - } - } -} - -impl UserInfo { - pub fn user_name(&self) -> &str { - &self.username - } - - pub fn new(username: impl Into) -> Self { - Self { - username: username.into(), - } - } -} - pub fn user_provider_from_option(opt: &String) -> Result { let (name, content) = opt.split_once(':').context(InvalidConfigSnafu { value: opt.to_string(), @@ -213,7 +187,7 @@ mod tests { ) .await; assert!(auth_result.is_ok()); - assert_eq!("greptime", auth_result.unwrap().user_name()); + assert_eq!("greptime", auth_result.unwrap().username()); // auth failed, unsupported password type let auth_result = user_provider diff --git a/src/servers/src/auth/user_provider.rs b/src/servers/src/auth/user_provider.rs index 262ae777fa..84edb725c7 100644 --- a/src/servers/src/auth/user_provider.rs +++ b/src/servers/src/auth/user_provider.rs @@ -21,13 +21,13 @@ use std::path::Path; use async_trait::async_trait; use digest; use digest::Digest; +use session::context::UserInfo; use sha1::Sha1; use snafu::{ensure, OptionExt, ResultExt}; use crate::auth::{ Error, HashedPassword, Identity, InvalidConfigSnafu, IoSnafu, Password, Result, Salt, - UnsupportedPasswordTypeSnafu, UserInfo, UserNotFoundSnafu, UserPasswordMismatchSnafu, - UserProvider, + UnsupportedPasswordTypeSnafu, UserNotFoundSnafu, UserPasswordMismatchSnafu, UserProvider, }; pub const STATIC_USER_PROVIDER: &str = "static_user_provider"; diff --git a/src/servers/src/context.rs b/src/servers/src/context.rs deleted file mode 100644 index b068faba52..0000000000 --- a/src/servers/src/context.rs +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2022 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use snafu::OptionExt; - -use crate::auth::UserInfo; -use crate::error::{BuildingContextSnafu, Result}; - -type CtxFnRef = Arc bool + Send + Sync>; - -pub struct Context { - pub client_info: ClientInfo, - pub user_info: UserInfo, - pub quota: Quota, - pub predicates: Vec, -} - -impl Context { - pub fn add_predicate(&mut self, predicate: CtxFnRef) { - self.predicates.push(predicate); - } -} - -#[derive(Default)] -pub struct CtxBuilder { - client_addr: Option, - from_channel: Option, - user_info: Option, -} - -impl CtxBuilder { - pub fn new() -> CtxBuilder { - CtxBuilder::default() - } - - pub fn client_addr(mut self, addr: String) -> CtxBuilder { - self.client_addr = Some(addr); - self - } - - pub fn set_channel(mut self, channel: Channel) -> CtxBuilder { - self.from_channel = Some(channel); - self - } - - pub fn set_user_info(mut self, user_info: UserInfo) -> CtxBuilder { - self.user_info = Some(user_info); - self - } - - pub fn build(self) -> Result { - Ok(Context { - client_info: ClientInfo { - client_host: self.client_addr.context(BuildingContextSnafu { - err_msg: "unknown client addr while building ctx", - })?, - channel: self.from_channel.context(BuildingContextSnafu { - err_msg: "unknown channel while building ctx", - })?, - }, - user_info: self.user_info.context(BuildingContextSnafu { - err_msg: "missing user info while building ctx", - })?, - quota: Quota::default(), - predicates: vec![], - }) - } -} - -pub struct ClientInfo { - pub client_host: String, - pub channel: Channel, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum Channel { - Grpc, - Http, - Mysql, -} - -#[derive(Default)] -pub struct Quota { - pub total: u64, - pub consumed: u64, - pub estimated: u64, -} - -#[cfg(test)] -mod test { - - use std::sync::Arc; - - use crate::auth::UserInfo; - use crate::context::Channel::{self, Http}; - use crate::context::{ClientInfo, Context, CtxBuilder}; - - #[test] - fn test_predicate() { - let mut ctx = Context { - client_info: ClientInfo { - client_host: Default::default(), - channel: Channel::Grpc, - }, - user_info: UserInfo::new("greptime"), - quota: Default::default(), - predicates: vec![], - }; - ctx.add_predicate(Arc::new(|ctx: &Context| { - ctx.quota.total > ctx.quota.consumed - })); - ctx.quota.total = 10; - ctx.quota.consumed = 5; - - let predicates = ctx.predicates.clone(); - let mut re = true; - for predicate in predicates { - re &= predicate(&ctx); - } - assert!(re); - } - - #[test] - fn test_build() { - let ctx = CtxBuilder::new() - .client_addr("127.0.0.1:4001".to_string()) - .set_channel(Http) - .set_user_info(UserInfo::new("greptime")) - .build() - .unwrap(); - - assert_eq!(ctx.client_info.client_host, String::from("127.0.0.1:4001")); - - assert_eq!(ctx.quota.total, 0); - assert_eq!(ctx.quota.consumed, 0); - assert_eq!(ctx.quota.estimated, 0); - - assert_eq!(ctx.predicates.capacity(), 0); - } -} diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index 06fdb6a894..b9dadaccdf 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -19,10 +19,11 @@ use axum::response::Response; use common_telemetry::error; use futures::future::BoxFuture; use http_body::Body; +use session::context::UserInfo; use snafu::{OptionExt, ResultExt}; use tower_http::auth::AsyncAuthorizeRequest; -use crate::auth::{Identity, UserInfo, UserProviderRef}; +use crate::auth::{Identity, UserProviderRef}; use crate::error::{self, Result}; pub struct HttpAuth { @@ -174,11 +175,12 @@ mod tests { use axum::body::BoxBody; use axum::http; use hyper::Request; + use session::context::UserInfo; use tower_http::auth::AsyncAuthorizeRequest; use super::{auth_header, decode_basic, AuthScheme, HttpAuth}; use crate::auth::test::MockUserProvider; - use crate::auth::{UserInfo, UserProvider}; + use crate::auth::UserProvider; use crate::error; use crate::error::Result; @@ -194,7 +196,7 @@ mod tests { let auth_res = http_auth.authorize(req).await.unwrap(); let user_info: &UserInfo = auth_res.extensions().get().unwrap(); let default = UserInfo::default(); - assert_eq!(default.user_name(), user_info.user_name()); + assert_eq!(default.username(), user_info.username()); // In mock user provider, right username:password == "greptime:greptime" let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc); @@ -208,7 +210,7 @@ mod tests { let req = http_auth.authorize(req).await.unwrap(); let user_info: &UserInfo = req.extensions().get().unwrap(); let default = UserInfo::default(); - assert_eq!(default.user_name(), user_info.user_name()); + assert_eq!(default.username(), user_info.username()); let req = mock_http_request_no_auth().unwrap(); let auth_res = http_auth.authorize(req).await; diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 361a00dab2..623d9cbe02 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -24,9 +24,8 @@ use common_error::status_code::StatusCode; use common_telemetry::metric; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use session::context::QueryContext; +use session::context::{QueryContext, UserInfo}; -use crate::auth::UserInfo; use crate::http::{ApiState, JsonResponse}; #[derive(Debug, Default, Serialize, Deserialize, JsonSchema)] diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index e18caf7fa3..7e80333a1f 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -17,7 +17,6 @@ use serde::{Deserialize, Serialize}; pub mod auth; -pub mod context; pub mod error; pub mod grpc; pub mod http; diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 483a4956b4..ae7de92ba1 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::net::SocketAddr; use std::sync::Arc; use std::time::Instant; @@ -22,13 +23,11 @@ use opensrv_mysql::{ AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter, }; use rand::RngCore; +use session::context::Channel; use session::Session; use tokio::io::AsyncWrite; -use tokio::sync::RwLock; use crate::auth::{Identity, Password, UserProviderRef}; -use crate::context::Channel::Mysql; -use crate::context::{Context, CtxBuilder}; use crate::error::{self, Result}; use crate::mysql::writer::MysqlResultWriter; use crate::query_handler::SqlQueryHandlerRef; @@ -37,9 +36,6 @@ use crate::query_handler::SqlQueryHandlerRef; pub struct MysqlInstanceShim { query_handler: SqlQueryHandlerRef, salt: [u8; 20], - client_addr: String, - // TODO(LFC): Break `Context` struct into different fields in `Session`, each with its own purpose. - ctx: Arc>>, session: Arc, user_provider: Option, } @@ -47,7 +43,7 @@ pub struct MysqlInstanceShim { impl MysqlInstanceShim { pub fn create( query_handler: SqlQueryHandlerRef, - client_addr: String, + client_addr: SocketAddr, user_provider: Option, ) -> MysqlInstanceShim { // init a random salt @@ -66,9 +62,7 @@ impl MysqlInstanceShim { MysqlInstanceShim { query_handler, salt: scramble, - client_addr, - ctx: Arc::new(RwLock::new(None)), - session: Arc::new(Session::new()), + session: Arc::new(Session::new(client_addr, Channel::Mysql)), user_provider, } } @@ -115,11 +109,11 @@ impl AsyncMysqlShim for MysqlInstanceShi ) -> bool { // if not specified then **greptime** will be used let username = String::from_utf8_lossy(username); - let client_addr = self.client_addr.clone(); let mut user_info = None; + let addr = self.session.conn_info().client_host.to_string(); if let Some(user_provider) = &self.user_provider { - let user_id = Identity::UserId(&username, Some(&client_addr)); + let user_id = Identity::UserId(&username, Some(addr.as_str())); let password = match auth_plugin { "mysql_native_password" => Password::MysqlNativePassword(auth_data, salt), @@ -140,22 +134,9 @@ impl AsyncMysqlShim for MysqlInstanceShi } let user_info = user_info.unwrap_or_default(); - return match CtxBuilder::new() - .client_addr(client_addr) - .set_channel(Mysql) - .set_user_info(user_info) - .build() - { - Ok(ctx) => { - let mut a = self.ctx.write().await; - *a = Some(ctx); - true - } - Err(e) => { - error!(e; "create ctx failed when authing mysql conn"); - false - } - }; + self.session.set_user_info(user_info); + + true } async fn on_prepare<'a>(&'a mut self, _: &'a str, w: StatementMetaWriter<'a, W>) -> Result<()> { diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index 3bec0ebbbc..2a205094c7 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -127,11 +127,7 @@ impl MysqlServer { force_tls: bool, user_provider: Option, ) -> Result<()> { - let mut shim = MysqlInstanceShim::create( - query_handler, - stream.peer_addr()?.to_string(), - user_provider, - ); + let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?, user_provider); let (mut r, w) = stream.into_split(); let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w); let ops = IntermediaryOptions::default(); diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs index b15d3844cd..05ed54cfb1 100644 --- a/src/servers/tests/http/http_handler_test.rs +++ b/src/servers/tests/http/http_handler_test.rs @@ -18,8 +18,8 @@ use axum::body::Body; use axum::extract::{Json, Query, RawBody, State}; use common_telemetry::metric; use metrics::counter; -use servers::auth::UserInfo; use servers::http::{handler as http_handler, script as script_handler, ApiState, JsonOutput}; +use session::context::UserInfo; use table::test_util::MemTable; use crate::{create_testing_script_handler, create_testing_sql_query_handler}; diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 2a6f9bbe72..92d2cd8942 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; use common_telemetry::info; pub type QueryContextRef = Arc; +pub type ConnInfoRef = Arc; pub struct QueryContext { current_schema: ArcSwapOption, @@ -58,3 +60,78 @@ impl QueryContext { ) } } + +pub const DEFAULT_USERNAME: &str = "greptime"; + +#[derive(Clone, Debug)] +pub struct UserInfo { + username: String, +} + +impl Default for UserInfo { + fn default() -> Self { + Self { + username: DEFAULT_USERNAME.to_string(), + } + } +} + +impl UserInfo { + pub fn username(&self) -> &str { + self.username.as_str() + } + + pub fn new(username: impl Into) -> Self { + Self { + username: username.into(), + } + } +} + +pub struct ConnInfo { + pub client_host: SocketAddr, + pub channel: Channel, +} + +impl ConnInfo { + pub fn new(client_host: SocketAddr, channel: Channel) -> Self { + Self { + client_host, + channel, + } + } +} + +#[derive(Debug, PartialEq)] +pub enum Channel { + Grpc, + Http, + Mysql, + Postgres, + Opentsdb, + Influxdb, + Prometheus, +} + +#[cfg(test)] +mod test { + use crate::context::{Channel, UserInfo}; + use crate::Session; + + #[test] + fn test_session() { + let session = Session::new("127.0.0.1:9000".parse().unwrap(), Channel::Mysql); + // test user_info + assert_eq!(session.user_info().username(), "greptime"); + session.set_user_info(UserInfo::new("root")); + assert_eq!(session.user_info().username(), "root"); + + // test channel + assert_eq!(session.conn_info().channel, Channel::Mysql); + assert_eq!( + session.conn_info().client_host.ip().to_string(), + "127.0.0.1" + ); + assert_eq!(session.conn_info().client_host.port(), 9000); + } +} diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 57437c3057..2e6d4de736 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -14,23 +14,38 @@ pub mod context; +use std::net::SocketAddr; use std::sync::Arc; -use crate::context::{QueryContext, QueryContextRef}; +use arc_swap::ArcSwap; + +use crate::context::{Channel, ConnInfo, ConnInfoRef, QueryContext, QueryContextRef, UserInfo}; -#[derive(Default)] pub struct Session { query_ctx: QueryContextRef, + user_info: ArcSwap, + conn_info: ConnInfoRef, } impl Session { - pub fn new() -> Self { + pub fn new(addr: SocketAddr, channel: Channel) -> Self { Session { query_ctx: Arc::new(QueryContext::new()), + user_info: ArcSwap::new(Arc::new(UserInfo::default())), + conn_info: Arc::new(ConnInfo::new(addr, channel)), } } pub fn context(&self) -> QueryContextRef { self.query_ctx.clone() } + pub fn conn_info(&self) -> ConnInfoRef { + self.conn_info.clone() + } + pub fn user_info(&self) -> Arc { + self.user_info.load().clone() + } + pub fn set_user_info(&self, user_info: UserInfo) { + self.user_info.store(Arc::new(user_info)); + } }