diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index fe3a46b33b..976120e215 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -104,6 +104,9 @@ pub const HTTP_API_PREFIX: &str = "/v1/"; /// Default http body limit (64M). const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64); +/// Authorization header +pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth"; + // TODO(fys): This is a temporary workaround, it will be improved later pub static PUBLIC_APIS: [&str; 2] = ["/v1/influxdb/ping", "/v1/influxdb/health"]; diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index 7aee9fa9b4..d7bacf04c3 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -31,7 +31,7 @@ use session::context::QueryContextBuilder; use snafu::{ensure, OptionExt, ResultExt}; use super::header::{GreptimeDbName, GREPTIME_TIMEZONE_HEADER_NAME}; -use super::PUBLIC_APIS; +use super::{AUTHORIZATION_HEADER, PUBLIC_APIS}; use crate::error::{ self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu, NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu, UrlDecodeSnafu, @@ -246,7 +246,8 @@ type Credential<'a> = &'a str; fn auth_header(req: &Request) -> Result { let auth_header = req .headers() - .get(http::header::AUTHORIZATION) + .get(AUTHORIZATION_HEADER) + .or_else(|| req.headers().get(http::header::AUTHORIZATION)) .context(error::NotFoundAuthHeaderSnafu)? .to_str() .context(InvalidAuthHeaderInvisibleASCIISnafu)?; diff --git a/src/servers/tests/http/authorize.rs b/src/servers/tests/http/authorize.rs index 36b721c877..fd825eed59 100644 --- a/src/servers/tests/http/authorize.rs +++ b/src/servers/tests/http/authorize.rs @@ -20,12 +20,12 @@ use axum::http; use http_body::Body; use hyper::{Request, StatusCode}; use servers::http::authorize::inner_auth; +use servers::http::AUTHORIZATION_HEADER; use session::context::QueryContext; -#[tokio::test] -async fn test_http_auth() { +async fn check_http_auth(header_key: &str) { // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" - let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); + let req = mock_http_request(header_key, Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); let req = inner_auth(None, req).await.unwrap(); let ctx: &QueryContext = req.extensions().get().unwrap(); let user_info = ctx.current_user(); @@ -36,14 +36,14 @@ async fn test_http_auth() { let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc); // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU=" - let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap(); + let req = mock_http_request(header_key, Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap(); let req = inner_auth(mock_user_provider.clone(), req).await.unwrap(); let ctx: &QueryContext = req.extensions().get().unwrap(); let user_info = ctx.current_user(); let default = auth::userinfo_by_name(None); assert_eq!(default.username(), user_info.username()); - let req = mock_http_request(None, None).unwrap(); + let req = mock_http_request(header_key, None, None).unwrap(); let auth_res = inner_auth(mock_user_provider.clone(), req).await; assert!(auth_res.is_err()); let mut resp = auth_res.unwrap_err(); @@ -54,7 +54,8 @@ async fn test_http_auth() { ); // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" - let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); + let wrong_req = + mock_http_request(header_key, Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); let auth_res = inner_auth(mock_user_provider, wrong_req).await; assert!(auth_res.is_err()); let mut resp = auth_res.unwrap_err(); @@ -66,7 +67,12 @@ async fn test_http_auth() { } #[tokio::test] -async fn test_schema_validating() { +async fn test_http_auth() { + check_http_auth(http::header::AUTHORIZATION.as_str()).await; + check_http_auth(AUTHORIZATION_HEADER).await; +} + +async fn check_schema_validating(header: &str) { // In mock user provider, right username:password == "greptime:greptime" let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc); @@ -74,6 +80,7 @@ async fn test_schema_validating() { // http://localhost/{http_api_version}/sql?db=greptime let version = servers::http::HTTP_API_VERSION; let req = mock_http_request( + header, Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), Some(format!("http://localhost/{version}/sql?db=public").as_str()), ) @@ -86,6 +93,7 @@ async fn test_schema_validating() { // wrong database let req = mock_http_request( + header, Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), Some(format!("http://localhost/{version}/sql?db=wrong").as_str()), ) @@ -101,13 +109,18 @@ async fn test_schema_validating() { } #[tokio::test] -async fn test_whitelist_no_auth() { +async fn test_schema_validating() { + check_schema_validating(http::header::AUTHORIZATION.as_str()).await; + check_schema_validating(AUTHORIZATION_HEADER).await; +} + +async fn check_auth_header(header_key: &str) { // In mock user provider, right username:password == "greptime:greptime" let mock_user_provider = Some(Arc::new(MockUserProvider::default()) as Arc); // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU=" // try auth path first - let req = mock_http_request(None, None).unwrap(); + let req = mock_http_request(header_key, None, None).unwrap(); let auth_res = inner_auth(mock_user_provider.clone(), req).await; assert!(auth_res.is_err()); let mut resp = auth_res.unwrap_err(); @@ -118,13 +131,20 @@ async fn test_whitelist_no_auth() { ); // try whitelist path - let req = mock_http_request(None, Some("http://localhost/health")).unwrap(); + let req = mock_http_request(header_key, None, Some("http://localhost/health")).unwrap(); let req = inner_auth(mock_user_provider, req).await; assert!(req.is_ok()); } +#[tokio::test] +async fn test_whitelist_no_auth() { + check_auth_header(http::header::AUTHORIZATION.as_str()).await; + check_auth_header(AUTHORIZATION_HEADER).await; +} + // copy from http::authorize fn mock_http_request( + auth_header_key: &str, auth_header: Option<&str>, uri: Option<&str>, ) -> servers::error::Result> { @@ -132,7 +152,7 @@ fn mock_http_request( let mut req = Request::builder() .uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql?db=public").as_str())); if let Some(auth_header) = auth_header { - req = req.header(http::header::AUTHORIZATION, auth_header); + req = req.header(auth_header_key, auth_header); } Ok(req.body(()).unwrap()) }