mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-04 12:22:55 +00:00
chore(http): change authorization header (#5389)
* chore/change-authorization-header: ### Add Custom Authorization Header Support - **Files Modified**: `http.rs`, `authorize.rs`, `authorize.rs` (tests) - **Key Changes**: - Introduced a custom authorization header `x-greptime-auth` in `http.rs`. - Updated authorization logic in `authorize.rs` to support both `x-greptime-auth` and the standard `Authorization` header. - Enhanced test cases in `authorize.rs` to validate the new custom header functionality. * chore: add more tests
This commit is contained in:
@@ -104,6 +104,9 @@ pub const HTTP_API_PREFIX: &str = "/v1/";
|
||||
/// Default http body limit (64M).
|
||||
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);
|
||||
|
||||
/// Authorization header
|
||||
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";
|
||||
|
||||
// TODO(fys): This is a temporary workaround, it will be improved later
|
||||
pub static PUBLIC_APIS: [&str; 2] = ["/v1/influxdb/ping", "/v1/influxdb/health"];
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ use session::context::QueryContextBuilder;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
|
||||
use super::header::{GreptimeDbName, GREPTIME_TIMEZONE_HEADER_NAME};
|
||||
use super::PUBLIC_APIS;
|
||||
use super::{AUTHORIZATION_HEADER, PUBLIC_APIS};
|
||||
use crate::error::{
|
||||
self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu,
|
||||
NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu, UrlDecodeSnafu,
|
||||
@@ -246,7 +246,8 @@ type Credential<'a> = &'a str;
|
||||
fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.get(AUTHORIZATION_HEADER)
|
||||
.or_else(|| req.headers().get(http::header::AUTHORIZATION))
|
||||
.context(error::NotFoundAuthHeaderSnafu)?
|
||||
.to_str()
|
||||
.context(InvalidAuthHeaderInvisibleASCIISnafu)?;
|
||||
|
||||
@@ -20,12 +20,12 @@ use axum::http;
|
||||
use http_body::Body;
|
||||
use hyper::{Request, StatusCode};
|
||||
use servers::http::authorize::inner_auth;
|
||||
use servers::http::AUTHORIZATION_HEADER;
|
||||
use session::context::QueryContext;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_auth() {
|
||||
async fn check_http_auth(header_key: &str) {
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let req = mock_http_request(header_key, Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let req = inner_auth(None, req).await.unwrap();
|
||||
let ctx: &QueryContext = req.extensions().get().unwrap();
|
||||
let user_info = ctx.current_user();
|
||||
@@ -36,14 +36,14 @@ async fn test_http_auth() {
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
|
||||
let req = mock_http_request(header_key, Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
|
||||
let req = inner_auth(mock_user_provider.clone(), req).await.unwrap();
|
||||
let ctx: &QueryContext = req.extensions().get().unwrap();
|
||||
let user_info = ctx.current_user();
|
||||
let default = auth::userinfo_by_name(None);
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
let req = mock_http_request(header_key, None, None).unwrap();
|
||||
let auth_res = inner_auth(mock_user_provider.clone(), req).await;
|
||||
assert!(auth_res.is_err());
|
||||
let mut resp = auth_res.unwrap_err();
|
||||
@@ -54,7 +54,8 @@ async fn test_http_auth() {
|
||||
);
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let wrong_req =
|
||||
mock_http_request(header_key, Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let auth_res = inner_auth(mock_user_provider, wrong_req).await;
|
||||
assert!(auth_res.is_err());
|
||||
let mut resp = auth_res.unwrap_err();
|
||||
@@ -66,7 +67,12 @@ async fn test_http_auth() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_validating() {
|
||||
async fn test_http_auth() {
|
||||
check_http_auth(http::header::AUTHORIZATION.as_str()).await;
|
||||
check_http_auth(AUTHORIZATION_HEADER).await;
|
||||
}
|
||||
|
||||
async fn check_schema_validating(header: &str) {
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
|
||||
|
||||
@@ -74,6 +80,7 @@ async fn test_schema_validating() {
|
||||
// http://localhost/{http_api_version}/sql?db=greptime
|
||||
let version = servers::http::HTTP_API_VERSION;
|
||||
let req = mock_http_request(
|
||||
header,
|
||||
Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="),
|
||||
Some(format!("http://localhost/{version}/sql?db=public").as_str()),
|
||||
)
|
||||
@@ -86,6 +93,7 @@ async fn test_schema_validating() {
|
||||
|
||||
// wrong database
|
||||
let req = mock_http_request(
|
||||
header,
|
||||
Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="),
|
||||
Some(format!("http://localhost/{version}/sql?db=wrong").as_str()),
|
||||
)
|
||||
@@ -101,13 +109,18 @@ async fn test_schema_validating() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_whitelist_no_auth() {
|
||||
async fn test_schema_validating() {
|
||||
check_schema_validating(http::header::AUTHORIZATION.as_str()).await;
|
||||
check_schema_validating(AUTHORIZATION_HEADER).await;
|
||||
}
|
||||
|
||||
async fn check_auth_header(header_key: &str) {
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
// try auth path first
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
let req = mock_http_request(header_key, None, None).unwrap();
|
||||
let auth_res = inner_auth(mock_user_provider.clone(), req).await;
|
||||
assert!(auth_res.is_err());
|
||||
let mut resp = auth_res.unwrap_err();
|
||||
@@ -118,13 +131,20 @@ async fn test_whitelist_no_auth() {
|
||||
);
|
||||
|
||||
// try whitelist path
|
||||
let req = mock_http_request(None, Some("http://localhost/health")).unwrap();
|
||||
let req = mock_http_request(header_key, None, Some("http://localhost/health")).unwrap();
|
||||
let req = inner_auth(mock_user_provider, req).await;
|
||||
assert!(req.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_whitelist_no_auth() {
|
||||
check_auth_header(http::header::AUTHORIZATION.as_str()).await;
|
||||
check_auth_header(AUTHORIZATION_HEADER).await;
|
||||
}
|
||||
|
||||
// copy from http::authorize
|
||||
fn mock_http_request(
|
||||
auth_header_key: &str,
|
||||
auth_header: Option<&str>,
|
||||
uri: Option<&str>,
|
||||
) -> servers::error::Result<Request<()>> {
|
||||
@@ -132,7 +152,7 @@ fn mock_http_request(
|
||||
let mut req = Request::builder()
|
||||
.uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql?db=public").as_str()));
|
||||
if let Some(auth_header) = auth_header {
|
||||
req = req.header(http::header::AUTHORIZATION, auth_header);
|
||||
req = req.header(auth_header_key, auth_header);
|
||||
}
|
||||
Ok(req.body(()).unwrap())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user