mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-08 14:22:58 +00:00
feat: support http basic authentication (#733)
* feat: support http auth * add some unit test and log * fix * cr * remove unused #[derive(Clone)]
This commit is contained in:
@@ -67,9 +67,13 @@ pub enum StatusCode {
|
||||
/// User not exist
|
||||
UserNotFound = 7000,
|
||||
/// Unsupported password type
|
||||
UnsupportedPwdType = 7001,
|
||||
UnsupportedPasswordType = 7001,
|
||||
/// Username and password does not match
|
||||
UserPwdMismatch = 7002,
|
||||
UserPasswordMismatch = 7002,
|
||||
/// Not found http authorization header
|
||||
AuthHeaderNotFound = 7003,
|
||||
/// Invalid http authorization header
|
||||
InvalidAuthHeader = 7004,
|
||||
// ====== End of auth related status code =====
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ api = { path = "../api" }
|
||||
async-trait = "0.1"
|
||||
axum = "0.6"
|
||||
axum-macros = "0.3"
|
||||
base64 = "0.13"
|
||||
bytes = "1.2"
|
||||
common-base = { path = "../common/base" }
|
||||
common-catalog = { path = "../common/catalog" }
|
||||
@@ -23,6 +24,7 @@ common-time = { path = "../common/time" }
|
||||
datatypes = { path = "../datatypes" }
|
||||
futures = "0.3"
|
||||
hex = { version = "0.4" }
|
||||
http-body = "0.4"
|
||||
humantime-serde = "1.1"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", branch = "feat/line-protocol" }
|
||||
|
||||
@@ -24,7 +24,7 @@ use snafu::{Backtrace, ErrorCompat, Snafu};
|
||||
pub trait UserProvider: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
|
||||
async fn auth(&self, id: Identity<'_>, pwd: Password<'_>) -> Result<UserInfo, Error>;
|
||||
async fn auth(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo, Error>;
|
||||
}
|
||||
|
||||
pub type UserProviderRef = Arc<dyn UserProvider>;
|
||||
@@ -37,17 +37,17 @@ pub enum Identity<'a> {
|
||||
UserId(Username<'a>, Option<HostOrIp<'a>>),
|
||||
}
|
||||
|
||||
pub type HashedPwd<'a> = &'a [u8];
|
||||
pub type HashedPassword<'a> = &'a [u8];
|
||||
pub type Salt<'a> = &'a [u8];
|
||||
pub type Pwd<'a> = &'a [u8];
|
||||
|
||||
/// Authentication information sent by the client.
|
||||
pub enum Password<'a> {
|
||||
PlainText(Pwd<'a>),
|
||||
MysqlNativePwd(HashedPwd<'a>, Salt<'a>),
|
||||
PgMD5(HashedPwd<'a>, Salt<'a>),
|
||||
PlainText(&'a str),
|
||||
MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
|
||||
PgMD5(HashedPassword<'a>, Salt<'a>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UserInfo {
|
||||
username: String,
|
||||
}
|
||||
@@ -76,25 +76,25 @@ impl UserInfo {
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
pub enum Error {
|
||||
#[snafu(display("User not exist"))]
|
||||
UserNotExist { backtrace: Backtrace },
|
||||
#[snafu(display("User not found"))]
|
||||
UserNotFound { backtrace: Backtrace },
|
||||
|
||||
#[snafu(display("Unsupported Password Type: {}", pwd_type))]
|
||||
UnsupportedPwdType {
|
||||
pwd_type: String,
|
||||
#[snafu(display("Unsupported password type: {}", password_type))]
|
||||
UnsupportedPasswordType {
|
||||
password_type: String,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Username and password does not match"))]
|
||||
WrongPwd { backtrace: Backtrace },
|
||||
UserPasswordMismatch { backtrace: Backtrace },
|
||||
}
|
||||
|
||||
impl ErrorExt for Error {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Error::UserNotExist { .. } => StatusCode::UserNotFound,
|
||||
Error::UnsupportedPwdType { .. } => StatusCode::UnsupportedPwdType,
|
||||
Error::WrongPwd { .. } => StatusCode::UserPwdMismatch,
|
||||
Error::UserNotFound { .. } => StatusCode::UserNotFound,
|
||||
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
|
||||
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,11 +108,10 @@ impl ErrorExt for Error {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
pub mod test {
|
||||
use super::{Identity, Password, UserInfo, UserProvider};
|
||||
use crate::auth;
|
||||
|
||||
struct MockUserProvider {}
|
||||
pub struct MockUserProvider {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl UserProvider for MockUserProvider {
|
||||
@@ -127,27 +126,34 @@ mod tests {
|
||||
) -> Result<UserInfo, super::Error> {
|
||||
match id {
|
||||
Identity::UserId(username, _host) => match password {
|
||||
Password::PlainText(pwd) => {
|
||||
Password::PlainText(password) => {
|
||||
if username == "greptime" {
|
||||
if pwd == b"greptime" {
|
||||
if password == "greptime" {
|
||||
return Ok(UserInfo {
|
||||
username: "greptime".to_string(),
|
||||
});
|
||||
} else {
|
||||
return super::WrongPwdSnafu {}.fail();
|
||||
return super::UserPasswordMismatchSnafu {}.fail();
|
||||
}
|
||||
} else {
|
||||
return super::UserNotExistSnafu {}.fail();
|
||||
return super::UserNotFoundSnafu {}.fail();
|
||||
}
|
||||
}
|
||||
_ => super::UnsupportedPwdTypeSnafu {
|
||||
pwd_type: "mysql_native_pwd",
|
||||
_ => super::UnsupportedPasswordTypeSnafu {
|
||||
password_type: "mysql_native_password",
|
||||
}
|
||||
.fail(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::test::MockUserProvider;
|
||||
use super::{Identity, Password, UserProvider};
|
||||
use crate::auth;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_by_plain_text() {
|
||||
@@ -158,43 +164,46 @@ mod tests {
|
||||
let auth_result = user_provider
|
||||
.auth(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::PlainText(b"greptime"),
|
||||
Password::PlainText("greptime"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_ok());
|
||||
assert_eq!("greptime", auth_result.unwrap().user_name());
|
||||
|
||||
// auth failed, unsupported pwd type
|
||||
// auth failed, unsupported password type
|
||||
let auth_result = user_provider
|
||||
.auth(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::MysqlNativePwd(b"hashed_value", b"salt"),
|
||||
Password::MysqlNativePassword(b"hashed_value", b"salt"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
auth_result.err().unwrap(),
|
||||
auth::Error::UnsupportedPwdType { .. }
|
||||
auth::Error::UnsupportedPasswordType { .. }
|
||||
);
|
||||
|
||||
// auth failed, err: user not exist.
|
||||
let auth_result = user_provider
|
||||
.auth(
|
||||
Identity::UserId("not_exist_username", None),
|
||||
Password::PlainText(b"greptime"),
|
||||
Password::PlainText("greptime"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(auth_result.err().unwrap(), auth::Error::UserNotExist { .. });
|
||||
matches!(auth_result.err().unwrap(), auth::Error::UserNotFound { .. });
|
||||
|
||||
// auth failed, err: wrong password
|
||||
let auth_result = user_provider
|
||||
.auth(
|
||||
Identity::UserId("greptime", None),
|
||||
Password::PlainText(b"wrong_pwd"),
|
||||
Password::PlainText("wrong_password"),
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(auth_result.err().unwrap(), auth::Error::WrongPwd { .. });
|
||||
matches!(
|
||||
auth_result.err().unwrap(),
|
||||
auth::Error::UserPasswordMismatch { .. }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,11 +85,11 @@ pub struct ClientInfo {
|
||||
pub channel: Channel,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum Channel {
|
||||
GRPC,
|
||||
HTTP,
|
||||
MYSQL,
|
||||
Grpc,
|
||||
Http,
|
||||
Mysql,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -105,7 +105,7 @@ mod test {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::auth::UserInfo;
|
||||
use crate::context::Channel::{self, HTTP};
|
||||
use crate::context::Channel::{self, Http};
|
||||
use crate::context::{ClientInfo, Context, CtxBuilder};
|
||||
|
||||
#[test]
|
||||
@@ -113,7 +113,7 @@ mod test {
|
||||
let mut ctx = Context {
|
||||
client_info: ClientInfo {
|
||||
client_host: Default::default(),
|
||||
channel: Channel::GRPC,
|
||||
channel: Channel::Grpc,
|
||||
},
|
||||
user_info: UserInfo::new("greptime"),
|
||||
quota: Default::default(),
|
||||
@@ -137,7 +137,7 @@ mod test {
|
||||
fn test_build() {
|
||||
let ctx = CtxBuilder::new()
|
||||
.client_addr("127.0.0.1:4001".to_string())
|
||||
.set_channel(HTTP)
|
||||
.set_channel(Http)
|
||||
.set_user_info(UserInfo::new("greptime"))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
@@ -14,11 +14,14 @@
|
||||
|
||||
use std::any::Any;
|
||||
use std::net::SocketAddr;
|
||||
use std::string::FromUtf8Error;
|
||||
|
||||
use axum::http::StatusCode as HttpStatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use base64::DecodeError;
|
||||
use common_error::prelude::*;
|
||||
use hyper::header::ToStrError;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::auth;
|
||||
@@ -203,6 +206,33 @@ pub enum Error {
|
||||
#[snafu(backtrace)]
|
||||
source: auth::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("Not found http authorization header"))]
|
||||
NotFoundAuthHeader {},
|
||||
|
||||
#[snafu(display("Invalid visibility ASCII chars, source: {}", source))]
|
||||
InvisibleASCII {
|
||||
source: ToStrError,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Unsupported http auth scheme, name: {}", name))]
|
||||
UnsupportedAuthScheme { name: String },
|
||||
|
||||
#[snafu(display("Invalid http authorization header"))]
|
||||
InvalidAuthorizationHeader { backtrace: Backtrace },
|
||||
|
||||
#[snafu(display("Invalid base64 value, source: {:?}", source))]
|
||||
InvalidBase64Value {
|
||||
source: DecodeError,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Invalid utf-8 value, source: {:?}", source))]
|
||||
InvalidUtf8Value {
|
||||
source: FromUtf8Error,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -248,6 +278,13 @@ impl ErrorExt for Error {
|
||||
TlsRequired { .. } => StatusCode::Unknown,
|
||||
StartFrontend { source, .. } => source.status_code(),
|
||||
Auth { source, .. } => source.status_code(),
|
||||
|
||||
NotFoundAuthHeader { .. } => StatusCode::AuthHeaderNotFound,
|
||||
InvisibleASCII { .. }
|
||||
| UnsupportedAuthScheme { .. }
|
||||
| InvalidAuthorizationHeader { .. }
|
||||
| InvalidBase64Value { .. }
|
||||
| InvalidUtf8Value { .. } => StatusCode::InvalidAuthHeader,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod context;
|
||||
mod authorize;
|
||||
pub mod handler;
|
||||
pub mod influxdb;
|
||||
pub mod opentsdb;
|
||||
@@ -26,8 +26,8 @@ use std::time::Duration;
|
||||
use aide::axum::{routing as apirouting, ApiRouter, IntoApiResponse};
|
||||
use aide::openapi::{Info, OpenApi, Server as OpenAPIServer};
|
||||
use async_trait::async_trait;
|
||||
use axum::body::BoxBody;
|
||||
use axum::error_handling::HandleErrorLayer;
|
||||
use axum::middleware::{self};
|
||||
use axum::response::{Html, Json};
|
||||
use axum::{routing, BoxError, Extension, Router};
|
||||
use common_error::prelude::ErrorExt;
|
||||
@@ -45,9 +45,12 @@ use tokio::sync::oneshot::{self, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
use tower::timeout::TimeoutLayer;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::auth::AsyncRequireAuthorizationLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use self::authorize::HttpAuth;
|
||||
use self::influxdb::influxdb_write;
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
|
||||
use crate::query_handler::{
|
||||
InfluxdbLineProtocolHandlerRef, OpentsdbProtocolHandlerRef, PrometheusProtocolHandlerRef,
|
||||
@@ -65,6 +68,7 @@ pub struct HttpServer {
|
||||
prom_handler: Option<PrometheusProtocolHandlerRef>,
|
||||
script_handler: Option<ScriptHandlerRef>,
|
||||
shutdown_tx: Mutex<Option<Sender<()>>>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
@@ -295,6 +299,7 @@ impl HttpServer {
|
||||
opentsdb_handler: None,
|
||||
influxdb_handler: None,
|
||||
prom_handler: None,
|
||||
user_provider: None,
|
||||
script_handler: None,
|
||||
shutdown_tx: Mutex::new(None),
|
||||
}
|
||||
@@ -332,6 +337,14 @@ impl HttpServer {
|
||||
self.prom_handler.get_or_insert(handler);
|
||||
}
|
||||
|
||||
pub fn set_user_provider(&mut self, user_provider: UserProviderRef) {
|
||||
debug_assert!(
|
||||
self.user_provider.is_none(),
|
||||
"User provider can be set only once!"
|
||||
);
|
||||
self.user_provider.get_or_insert(user_provider);
|
||||
}
|
||||
|
||||
pub fn make_app(&self) -> Router {
|
||||
let mut api = OpenApi {
|
||||
info: Info {
|
||||
@@ -393,7 +406,9 @@ impl HttpServer {
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(TimeoutLayer::new(self.options.timeout))
|
||||
// custom layer
|
||||
.layer(middleware::from_fn(context::build_ctx)),
|
||||
.layer(AsyncRequireAuthorizationLayer::new(
|
||||
HttpAuth::<BoxBody>::new(self.user_provider.clone()),
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
282
src/servers/src/http/authorize.rs
Normal file
282
src/servers/src/http/authorize.rs
Normal file
@@ -0,0 +1,282 @@
|
||||
// Copyright 2022 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use axum::http::{self, Request, StatusCode};
|
||||
use axum::response::Response;
|
||||
use common_telemetry::error;
|
||||
use futures::future::BoxFuture;
|
||||
use http_body::Body;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use crate::auth::{Identity, UserInfo, UserProviderRef};
|
||||
use crate::error::{self, Result};
|
||||
|
||||
pub struct HttpAuth<RespBody> {
|
||||
user_provider: Option<UserProviderRef>,
|
||||
_ty: PhantomData<RespBody>,
|
||||
}
|
||||
|
||||
impl<RespBody> HttpAuth<RespBody> {
|
||||
pub fn new(user_provider: Option<UserProviderRef>) -> Self {
|
||||
Self {
|
||||
user_provider,
|
||||
_ty: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<RespBody> Clone for HttpAuth<RespBody> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
user_provider: self.user_provider.clone(),
|
||||
_ty: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, RespBody> AsyncAuthorizeRequest<B> for HttpAuth<RespBody>
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
RespBody: Body + Default,
|
||||
{
|
||||
type RequestBody = B;
|
||||
type ResponseBody = RespBody;
|
||||
type Future = BoxFuture<'static, std::result::Result<Request<B>, Response<Self::ResponseBody>>>;
|
||||
|
||||
fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
|
||||
let user_provider = self.user_provider.clone();
|
||||
Box::pin(async move {
|
||||
let user_provider = if let Some(user_provider) = &user_provider {
|
||||
user_provider
|
||||
} else {
|
||||
request.extensions_mut().insert(UserInfo::default());
|
||||
return Ok(request);
|
||||
};
|
||||
|
||||
let (scheme, credential) = match auth_header(&request) {
|
||||
Ok(auth_header) => auth_header,
|
||||
Err(e) => {
|
||||
error!("failed to get http authorize header, err: {:?}", e);
|
||||
return Err(unauthorized_resp());
|
||||
}
|
||||
};
|
||||
|
||||
match scheme {
|
||||
AuthScheme::Basic => {
|
||||
let (username, password) = match decode_basic(credential) {
|
||||
Ok(basic_auth) => basic_auth,
|
||||
Err(e) => {
|
||||
error!("failed to decode basic authorize, err: {:?}", e);
|
||||
return Err(unauthorized_resp());
|
||||
}
|
||||
};
|
||||
match user_provider
|
||||
.auth(
|
||||
Identity::UserId(&username, None),
|
||||
crate::auth::Password::PlainText(&password),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(user_info) => {
|
||||
request.extensions_mut().insert(user_info);
|
||||
Ok(request)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to auth, err: {:?}", e);
|
||||
Err(unauthorized_resp())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn unauthorized_resp<RespBody>() -> Response<RespBody>
|
||||
where
|
||||
RespBody: Body + Default,
|
||||
{
|
||||
let mut res = Response::new(RespBody::default());
|
||||
*res.status_mut() = StatusCode::UNAUTHORIZED;
|
||||
res
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AuthScheme {
|
||||
Basic,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for AuthScheme {
|
||||
type Error = error::Error;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self> {
|
||||
match value.to_lowercase().as_str() {
|
||||
"basic" => Ok(AuthScheme::Basic),
|
||||
other => error::UnsupportedAuthSchemeSnafu { name: other }.fail(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Credential<'a> = &'a str;
|
||||
|
||||
fn auth_header<B>(req: &Request<B>) -> Result<(AuthScheme, Credential)> {
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.context(error::NotFoundAuthHeaderSnafu)?
|
||||
.to_str()
|
||||
.context(error::InvisibleASCIISnafu)?;
|
||||
|
||||
let (auth_scheme, encoded_credentials) = auth_header
|
||||
.split_once(' ')
|
||||
.context(error::InvalidAuthorizationHeaderSnafu)?;
|
||||
|
||||
if encoded_credentials.contains(' ') {
|
||||
return error::InvalidAuthorizationHeaderSnafu {}.fail();
|
||||
}
|
||||
|
||||
Ok((auth_scheme.try_into()?, encoded_credentials))
|
||||
}
|
||||
|
||||
type Username = String;
|
||||
type Password = String;
|
||||
|
||||
fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
|
||||
let decoded = base64::decode(credential).context(error::InvalidBase64ValueSnafu)?;
|
||||
let as_utf8 = String::from_utf8(decoded).context(error::InvalidUtf8ValueSnafu)?;
|
||||
|
||||
if let Some((user_id, password)) = as_utf8.split_once(':') {
|
||||
return Ok((user_id.to_string(), password.to_string()));
|
||||
}
|
||||
|
||||
error::InvalidAuthorizationHeaderSnafu {}.fail()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::body::BoxBody;
|
||||
use axum::http;
|
||||
use hyper::Request;
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use super::{auth_header, decode_basic, AuthScheme, HttpAuth};
|
||||
use crate::auth::test::MockUserProvider;
|
||||
use crate::auth::{UserInfo, UserProvider};
|
||||
use crate::error;
|
||||
use crate::error::Result;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_auth() {
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
|
||||
user_provider: None,
|
||||
_ty: PhantomData,
|
||||
};
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
|
||||
let auth_res = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.user_name(), user_info.user_name());
|
||||
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
|
||||
user_provider: mock_user_provider,
|
||||
_ty: PhantomData,
|
||||
};
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
let req = mock_http_request("Basic Z3JlcHRpbWU6Z3JlcHRpbWU=").unwrap();
|
||||
let req = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = req.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.user_name(), user_info.user_name());
|
||||
|
||||
let req = mock_http_request_no_auth().unwrap();
|
||||
let auth_res = http_auth.authorize(req).await;
|
||||
assert!(auth_res.is_err());
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
|
||||
let auth_res = http_auth.authorize(wrong_req).await;
|
||||
assert!(auth_res.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_basic() {
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let credential = "dXNlcm5hbWU6cGFzc3dvcmQ=";
|
||||
let (username, pwd) = decode_basic(credential).unwrap();
|
||||
assert_eq!("username", username);
|
||||
assert_eq!("password", pwd);
|
||||
|
||||
let wrong_credential = "dXNlcm5hbWU6cG Fzc3dvcmQ=";
|
||||
let result = decode_basic(wrong_credential);
|
||||
matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_try_into_auth_scheme() {
|
||||
let auth_scheme_str = "basic";
|
||||
let auth_scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
|
||||
matches!(auth_scheme, AuthScheme::Basic);
|
||||
|
||||
let unsupported = "digest";
|
||||
let auth_scheme: Result<AuthScheme> = unsupported.try_into();
|
||||
assert!(auth_scheme.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_header() {
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
|
||||
|
||||
let (auth_scheme, credential) = auth_header(&req).unwrap();
|
||||
matches!(auth_scheme, AuthScheme::Basic);
|
||||
assert_eq!("dXNlcm5hbWU6cGFzc3dvcmQ=", credential);
|
||||
|
||||
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6 cGFzc3dvcmQ=").unwrap();
|
||||
let res = auth_header(&wrong_req);
|
||||
matches!(
|
||||
res.err(),
|
||||
Some(error::Error::InvalidAuthorizationHeader { .. })
|
||||
);
|
||||
|
||||
let wrong_req = mock_http_request("Digest dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
|
||||
let res = auth_header(&wrong_req);
|
||||
matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
|
||||
}
|
||||
|
||||
fn mock_http_request(auth_header: &str) -> Result<Request<()>> {
|
||||
Ok(Request::builder()
|
||||
.uri("https://www.rust-lang.org/")
|
||||
.header(http::header::AUTHORIZATION, auth_header)
|
||||
.body(())
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
fn mock_http_request_no_auth() -> Result<Request<()>> {
|
||||
Ok(Request::builder()
|
||||
.uri("https://www.rust-lang.org/")
|
||||
.body(())
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
// Copyright 2022 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::http::{Request, StatusCode};
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
|
||||
pub async fn build_ctx<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
|
||||
// TODO(fys): auth and set context
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
@@ -18,12 +18,14 @@ use std::time::Instant;
|
||||
|
||||
use aide::transform::TransformOperation;
|
||||
use axum::extract::{Json, Query, State};
|
||||
use axum::Extension;
|
||||
use common_error::status_code::StatusCode;
|
||||
use common_telemetry::metric;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use session::context::QueryContext;
|
||||
|
||||
use crate::auth::UserInfo;
|
||||
use crate::http::{ApiState, JsonResponse};
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -37,6 +39,8 @@ pub struct SqlQuery {
|
||||
pub async fn sql(
|
||||
State(state): State<ApiState>,
|
||||
Query(params): Query<SqlQuery>,
|
||||
// TODO(fys): pass _user_info into query context
|
||||
_user_info: Extension<UserInfo>,
|
||||
) -> Json<JsonResponse> {
|
||||
let sql_handler = &state.sql_handler;
|
||||
let start = Instant::now();
|
||||
|
||||
@@ -27,7 +27,7 @@ use tokio::io::AsyncWrite;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::context::Channel::MYSQL;
|
||||
use crate::context::Channel::Mysql;
|
||||
use crate::context::{Context, CtxBuilder};
|
||||
use crate::error::{self, Result};
|
||||
use crate::mysql::writer::MysqlResultWriter;
|
||||
@@ -121,14 +121,14 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
if let Some(user_provider) = &self.user_provider {
|
||||
let user_id = Identity::UserId(&username, Some(&client_addr));
|
||||
|
||||
let pwd = match auth_plugin {
|
||||
"mysql_native_password" => Password::MysqlNativePwd(auth_data, salt),
|
||||
let password = match auth_plugin {
|
||||
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
|
||||
other => {
|
||||
error!("Unsupported mysql auth plugin: {}", other);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
match user_provider.auth(user_id, pwd).await {
|
||||
match user_provider.auth(user_id, password).await {
|
||||
Ok(userinfo) => {
|
||||
user_info = Some(userinfo);
|
||||
}
|
||||
@@ -142,7 +142,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
|
||||
return match CtxBuilder::new()
|
||||
.client_addr(client_addr)
|
||||
.set_channel(MYSQL)
|
||||
.set_channel(Mysql)
|
||||
.set_user_info(user_info)
|
||||
.build()
|
||||
{
|
||||
|
||||
@@ -57,7 +57,7 @@ impl LoginInfo {
|
||||
}
|
||||
|
||||
impl PgPwdVerifier {
|
||||
async fn verify_pwd(&self, pwd: &str, login: LoginInfo) -> Result<bool> {
|
||||
async fn verify_pwd(&self, password: &str, login: LoginInfo) -> Result<bool> {
|
||||
if let Some(user_provider) = &self.user_provider {
|
||||
let user_name = match login.user {
|
||||
Some(name) => name,
|
||||
@@ -68,7 +68,7 @@ impl PgPwdVerifier {
|
||||
let _user_info = user_provider
|
||||
.auth(
|
||||
Identity::UserId(&user_name, None),
|
||||
Password::PlainText(pwd.as_bytes()),
|
||||
Password::PlainText(password),
|
||||
)
|
||||
.await
|
||||
.context(error::AuthSnafu)?;
|
||||
|
||||
@@ -18,6 +18,7 @@ use axum::body::Body;
|
||||
use axum::extract::{Json, Query, RawBody, State};
|
||||
use common_telemetry::metric;
|
||||
use metrics::counter;
|
||||
use servers::auth::UserInfo;
|
||||
use servers::http::{handler as http_handler, script as script_handler, ApiState, JsonOutput};
|
||||
use table::test_util::MemTable;
|
||||
|
||||
@@ -32,6 +33,7 @@ async fn test_sql_not_provided() {
|
||||
script_handler: None,
|
||||
}),
|
||||
Query(http_handler::SqlQuery::default()),
|
||||
axum::Extension(UserInfo::default()),
|
||||
)
|
||||
.await;
|
||||
assert!(!json.success());
|
||||
@@ -55,6 +57,7 @@ async fn test_sql_output_rows() {
|
||||
script_handler: None,
|
||||
}),
|
||||
query,
|
||||
axum::Extension(UserInfo::default()),
|
||||
)
|
||||
.await;
|
||||
assert!(json.success(), "{:?}", json);
|
||||
|
||||
Reference in New Issue
Block a user