mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-14 17:23:09 +00:00
chore: add path check to http auth (#866)
* chore: add whitelist to http auth * chore: use const instead of format everytime
This commit is contained in:
@@ -93,6 +93,7 @@ pub(crate) fn query_context_from_db(
|
||||
}
|
||||
|
||||
const HTTP_API_VERSION: &str = "v1";
|
||||
const HTTP_API_PREFIX: &str = "/v1/";
|
||||
|
||||
pub struct HttpServer {
|
||||
sql_handler: SqlQueryHandlerRef,
|
||||
|
||||
@@ -25,6 +25,7 @@ use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use crate::auth::{Identity, UserProviderRef};
|
||||
use crate::error::{self, Result};
|
||||
use crate::http::HTTP_API_PREFIX;
|
||||
|
||||
pub struct HttpAuth<RespBody> {
|
||||
user_provider: Option<UserProviderRef>,
|
||||
@@ -61,7 +62,8 @@ where
|
||||
fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
|
||||
let user_provider = self.user_provider.clone();
|
||||
Box::pin(async move {
|
||||
let user_provider = if let Some(user_provider) = &user_provider {
|
||||
let need_auth = request.uri().path().starts_with(HTTP_API_PREFIX);
|
||||
let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
|
||||
user_provider
|
||||
} else {
|
||||
request.extensions_mut().insert(UserInfo::default());
|
||||
@@ -192,7 +194,7 @@ mod tests {
|
||||
};
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
|
||||
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let auth_res = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
@@ -206,22 +208,43 @@ mod tests {
|
||||
};
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
let req = mock_http_request("Basic Z3JlcHRpbWU6Z3JlcHRpbWU=").unwrap();
|
||||
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
|
||||
let req = http_auth.authorize(req).await.unwrap();
|
||||
let user_info: &UserInfo = req.extensions().get().unwrap();
|
||||
let default = UserInfo::default();
|
||||
assert_eq!(default.username(), user_info.username());
|
||||
|
||||
let req = mock_http_request_no_auth().unwrap();
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
let auth_res = http_auth.authorize(req).await;
|
||||
assert!(auth_res.is_err());
|
||||
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
|
||||
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let auth_res = http_auth.authorize(wrong_req).await;
|
||||
assert!(auth_res.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_whitelist_no_auth() {
|
||||
// In mock user provider, right username:password == "greptime:greptime"
|
||||
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
|
||||
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
|
||||
user_provider: mock_user_provider,
|
||||
_ty: PhantomData,
|
||||
};
|
||||
|
||||
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
|
||||
// try auth path first
|
||||
let req = mock_http_request(None, None).unwrap();
|
||||
let req = http_auth.authorize(req).await;
|
||||
assert!(req.is_err());
|
||||
|
||||
// try whitelist path
|
||||
let req = mock_http_request(None, Some("http://localhost/health")).unwrap();
|
||||
let req = http_auth.authorize(req).await;
|
||||
assert!(req.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_basic() {
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
@@ -249,36 +272,32 @@ mod tests {
|
||||
#[test]
|
||||
fn test_auth_header() {
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
|
||||
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
|
||||
let (auth_scheme, credential) = auth_header(&req).unwrap();
|
||||
matches!(auth_scheme, AuthScheme::Basic);
|
||||
assert_eq!("dXNlcm5hbWU6cGFzc3dvcmQ=", credential);
|
||||
|
||||
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6 cGFzc3dvcmQ=").unwrap();
|
||||
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
|
||||
let res = auth_header(&wrong_req);
|
||||
matches!(
|
||||
res.err(),
|
||||
Some(error::Error::InvalidAuthorizationHeader { .. })
|
||||
);
|
||||
|
||||
let wrong_req = mock_http_request("Digest dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
|
||||
let wrong_req = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let res = auth_header(&wrong_req);
|
||||
matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
|
||||
}
|
||||
|
||||
fn mock_http_request(auth_header: &str) -> Result<Request<()>> {
|
||||
Ok(Request::builder()
|
||||
.uri("https://www.rust-lang.org/")
|
||||
.header(http::header::AUTHORIZATION, auth_header)
|
||||
.body(())
|
||||
.unwrap())
|
||||
}
|
||||
fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
|
||||
let http_api_version = crate::http::HTTP_API_VERSION;
|
||||
let mut req = Request::builder()
|
||||
.uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql").as_str()));
|
||||
if let Some(auth_header) = auth_header {
|
||||
req = req.header(http::header::AUTHORIZATION, auth_header);
|
||||
}
|
||||
|
||||
fn mock_http_request_no_auth() -> Result<Request<()>> {
|
||||
Ok(Request::builder()
|
||||
.uri("https://www.rust-lang.org/")
|
||||
.body(())
|
||||
.unwrap())
|
||||
Ok(req.body(()).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user