chore(http): change authorization header (#5389)

* chore/change-authorization-header:
 ### Add Custom Authorization Header Support

 - **Files Modified**: `http.rs`, `authorize.rs`, `authorize.rs` (tests)
 - **Key Changes**:
   - Introduced a custom authorization header `x-greptime-auth` in `http.rs`.
   - Updated authorization logic in `authorize.rs` to support both `x-greptime-auth` and the standard `Authorization` header.
   - Enhanced test cases in `authorize.rs` to validate the new custom header functionality.

* chore: add more tests
This commit is contained in:
Lei, HUANG
2025-01-20 15:09:44 +08:00
committed by GitHub
parent 80790daae0
commit 64ce9d3744
3 changed files with 37 additions and 13 deletions

View File

@@ -104,6 +104,9 @@ pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M).
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);
/// Authorization header
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";
// TODO(fys): This is a temporary workaround, it will be improved later
pub static PUBLIC_APIS: [&str; 2] = ["/v1/influxdb/ping", "/v1/influxdb/health"];

View File

@@ -31,7 +31,7 @@ use session::context::QueryContextBuilder;
use snafu::{ensure, OptionExt, ResultExt};
use super::header::{GreptimeDbName, GREPTIME_TIMEZONE_HEADER_NAME};
use super::PUBLIC_APIS;
use super::{AUTHORIZATION_HEADER, PUBLIC_APIS};
use crate::error::{
self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu,
NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu, UrlDecodeSnafu,
@@ -246,7 +246,8 @@ type Credential<'a> = &'a str;
fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
let auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
.get(AUTHORIZATION_HEADER)
.or_else(|| req.headers().get(http::header::AUTHORIZATION))
.context(error::NotFoundAuthHeaderSnafu)?
.to_str()
.context(InvalidAuthHeaderInvisibleASCIISnafu)?;

View File

@@ -20,12 +20,12 @@ use axum::http;
use http_body::Body;
use hyper::{Request, StatusCode};
use servers::http::authorize::inner_auth;
use servers::http::AUTHORIZATION_HEADER;
use session::context::QueryContext;
#[tokio::test]
async fn test_http_auth() {
async fn check_http_auth(header_key: &str) {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let req = mock_http_request(header_key, Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let req = inner_auth(None, req).await.unwrap();
let ctx: &QueryContext = req.extensions().get().unwrap();
let user_info = ctx.current_user();
@@ -36,14 +36,14 @@ async fn test_http_auth() {
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
let req = mock_http_request(header_key, Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
let req = inner_auth(mock_user_provider.clone(), req).await.unwrap();
let ctx: &QueryContext = req.extensions().get().unwrap();
let user_info = ctx.current_user();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());
let req = mock_http_request(None, None).unwrap();
let req = mock_http_request(header_key, None, None).unwrap();
let auth_res = inner_auth(mock_user_provider.clone(), req).await;
assert!(auth_res.is_err());
let mut resp = auth_res.unwrap_err();
@@ -54,7 +54,8 @@ async fn test_http_auth() {
);
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let wrong_req =
mock_http_request(header_key, Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_res = inner_auth(mock_user_provider, wrong_req).await;
assert!(auth_res.is_err());
let mut resp = auth_res.unwrap_err();
@@ -66,7 +67,12 @@ async fn test_http_auth() {
}
#[tokio::test]
async fn test_schema_validating() {
async fn test_http_auth() {
check_http_auth(http::header::AUTHORIZATION.as_str()).await;
check_http_auth(AUTHORIZATION_HEADER).await;
}
async fn check_schema_validating(header: &str) {
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
@@ -74,6 +80,7 @@ async fn test_schema_validating() {
// http://localhost/{http_api_version}/sql?db=greptime
let version = servers::http::HTTP_API_VERSION;
let req = mock_http_request(
header,
Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="),
Some(format!("http://localhost/{version}/sql?db=public").as_str()),
)
@@ -86,6 +93,7 @@ async fn test_schema_validating() {
// wrong database
let req = mock_http_request(
header,
Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="),
Some(format!("http://localhost/{version}/sql?db=wrong").as_str()),
)
@@ -101,13 +109,18 @@ async fn test_schema_validating() {
}
#[tokio::test]
async fn test_whitelist_no_auth() {
async fn test_schema_validating() {
check_schema_validating(http::header::AUTHORIZATION.as_str()).await;
check_schema_validating(AUTHORIZATION_HEADER).await;
}
async fn check_auth_header(header_key: &str) {
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc<dyn UserProvider>);
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
// try auth path first
let req = mock_http_request(None, None).unwrap();
let req = mock_http_request(header_key, None, None).unwrap();
let auth_res = inner_auth(mock_user_provider.clone(), req).await;
assert!(auth_res.is_err());
let mut resp = auth_res.unwrap_err();
@@ -118,13 +131,20 @@ async fn test_whitelist_no_auth() {
);
// try whitelist path
let req = mock_http_request(None, Some("http://localhost/health")).unwrap();
let req = mock_http_request(header_key, None, Some("http://localhost/health")).unwrap();
let req = inner_auth(mock_user_provider, req).await;
assert!(req.is_ok());
}
#[tokio::test]
async fn test_whitelist_no_auth() {
check_auth_header(http::header::AUTHORIZATION.as_str()).await;
check_auth_header(AUTHORIZATION_HEADER).await;
}
// copy from http::authorize
fn mock_http_request(
auth_header_key: &str,
auth_header: Option<&str>,
uri: Option<&str>,
) -> servers::error::Result<Request<()>> {
@@ -132,7 +152,7 @@ fn mock_http_request(
let mut req = Request::builder()
.uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql?db=public").as_str()));
if let Some(auth_header) = auth_header {
req = req.header(http::header::AUTHORIZATION, auth_header);
req = req.header(auth_header_key, auth_header);
}
Ok(req.body(()).unwrap())
}