chore: add path check to http auth (#866)

* chore: add whitelist to http auth

* chore: use const instead of format everytime
This commit is contained in:
shuiyisong
2023-01-12 10:20:18 +08:00
committed by GitHub
parent 4015dd8075
commit b91c77b862
2 changed files with 40 additions and 20 deletions

View File

@@ -93,6 +93,7 @@ pub(crate) fn query_context_from_db(
}
const HTTP_API_VERSION: &str = "v1";
const HTTP_API_PREFIX: &str = "/v1/";
pub struct HttpServer {
sql_handler: SqlQueryHandlerRef,

View File

@@ -25,6 +25,7 @@ use tower_http::auth::AsyncAuthorizeRequest;
use crate::auth::{Identity, UserProviderRef};
use crate::error::{self, Result};
use crate::http::HTTP_API_PREFIX;
pub struct HttpAuth<RespBody> {
user_provider: Option<UserProviderRef>,
@@ -61,7 +62,8 @@ where
fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
let user_provider = self.user_provider.clone();
Box::pin(async move {
let user_provider = if let Some(user_provider) = &user_provider {
let need_auth = request.uri().path().starts_with(HTTP_API_PREFIX);
let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
user_provider
} else {
request.extensions_mut().insert(UserInfo::default());
@@ -192,7 +194,7 @@ mod tests {
};
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_res = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
let default = UserInfo::default();
@@ -206,22 +208,43 @@ mod tests {
};
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
let req = mock_http_request("Basic Z3JlcHRpbWU6Z3JlcHRpbWU=").unwrap();
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.username(), user_info.username());
let req = mock_http_request_no_auth().unwrap();
let req = mock_http_request(None, None).unwrap();
let auth_res = http_auth.authorize(req).await;
assert!(auth_res.is_err());
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_res = http_auth.authorize(wrong_req).await;
assert!(auth_res.is_err());
}
#[tokio::test]
async fn test_whitelist_no_auth() {
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
user_provider: mock_user_provider,
_ty: PhantomData,
};
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
// try auth path first
let req = mock_http_request(None, None).unwrap();
let req = http_auth.authorize(req).await;
assert!(req.is_err());
// try whitelist path
let req = mock_http_request(None, Some("http://localhost/health")).unwrap();
let req = http_auth.authorize(req).await;
assert!(req.is_ok());
}
#[test]
fn test_decode_basic() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
@@ -249,36 +272,32 @@ mod tests {
#[test]
fn test_auth_header() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let (auth_scheme, credential) = auth_header(&req).unwrap();
matches!(auth_scheme, AuthScheme::Basic);
assert_eq!("dXNlcm5hbWU6cGFzc3dvcmQ=", credential);
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6 cGFzc3dvcmQ=").unwrap();
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
let res = auth_header(&wrong_req);
matches!(
res.err(),
Some(error::Error::InvalidAuthorizationHeader { .. })
);
let wrong_req = mock_http_request("Digest dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let wrong_req = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let res = auth_header(&wrong_req);
matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
}
fn mock_http_request(auth_header: &str) -> Result<Request<()>> {
Ok(Request::builder()
.uri("https://www.rust-lang.org/")
.header(http::header::AUTHORIZATION, auth_header)
.body(())
.unwrap())
}
fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
let http_api_version = crate::http::HTTP_API_VERSION;
let mut req = Request::builder()
.uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql").as_str()));
if let Some(auth_header) = auth_header {
req = req.header(http::header::AUTHORIZATION, auth_header);
}
fn mock_http_request_no_auth() -> Result<Request<()>> {
Ok(Request::builder()
.uri("https://www.rust-lang.org/")
.body(())
.unwrap())
Ok(req.body(()).unwrap())
}
}