mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-22 22:20:02 +00:00
refactor: use split instead of serde_urlencoded in http auth (#1110)
* refactor: change from urlencoded to regex * refactor: change from urlencoded to regex * chore: add unit test * chore: update comment * chore: remove local benchmark test * chore: minor fix * chore: remove unused dep
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -6940,7 +6940,6 @@ dependencies = [
|
||||
"script",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"session",
|
||||
"sha1",
|
||||
"snafu",
|
||||
|
||||
@@ -52,7 +52,6 @@ rustls-pemfile = "1.0"
|
||||
schemars = "0.8"
|
||||
serde.workspace = true
|
||||
serde_json = "1.0"
|
||||
serde_urlencoded = "0.7"
|
||||
session = { path = "../session" }
|
||||
sha1 = "0.10"
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
|
||||
@@ -82,7 +82,7 @@ pub enum Error {
|
||||
#[snafu(display("Invalid config value: {}, {}", value, msg))]
|
||||
InvalidConfig { value: String, msg: String },
|
||||
|
||||
#[snafu(display("Illegal runtime param: {}", msg))]
|
||||
#[snafu(display("Illegal param: {}", msg))]
|
||||
IllegalParam { msg: String },
|
||||
|
||||
#[snafu(display("Internal state error: {}", msg))]
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use axum::http::{self, Request, StatusCode};
|
||||
@@ -106,21 +105,12 @@ async fn authorize<B: Send + Sync + 'static>(
|
||||
) -> crate::auth::Result<()> {
|
||||
// try get database name
|
||||
let query = request.uri().query().unwrap_or_default();
|
||||
let input_database = match serde_urlencoded::from_str::<HashMap<String, String>>(query) {
|
||||
Ok(query_map) => query_map
|
||||
.get("db")
|
||||
.context(IllegalParamSnafu {
|
||||
msg: "fail to get valid database from http query",
|
||||
})?
|
||||
.to_owned(),
|
||||
Err(e) => IllegalParamSnafu {
|
||||
msg: format!("fail to parse http query: {e}"),
|
||||
}
|
||||
.fail()?,
|
||||
};
|
||||
let input_database = extract_db_from_query(query).context(IllegalParamSnafu {
|
||||
msg: "db not provided or corrupted",
|
||||
})?;
|
||||
|
||||
let (catalog, database) =
|
||||
crate::parse_catalog_and_schema_from_client_database_name(&input_database);
|
||||
crate::parse_catalog_and_schema_from_client_database_name(input_database);
|
||||
|
||||
let user_info = request
|
||||
.extensions()
|
||||
@@ -157,20 +147,7 @@ fn get_influxdb_credentials<B: Send + Sync + 'static>(
|
||||
// try v1
|
||||
let Some(query_str) = request.uri().query() else { return Ok(None) };
|
||||
|
||||
// TODO(shuiyisong): remove this for performance optimization
|
||||
// `authorize` would deserialize query from urlencoded again
|
||||
let query = match serde_urlencoded::from_str::<HashMap<String, String>>(query_str) {
|
||||
Ok(query_map) => query_map,
|
||||
Err(e) => IllegalParamSnafu {
|
||||
msg: format!("fail to parse http query: {e}"),
|
||||
}
|
||||
.fail()?,
|
||||
};
|
||||
|
||||
let username = query.get("u");
|
||||
let password = query.get("p");
|
||||
|
||||
match (username, password) {
|
||||
match extract_influxdb_user_from_query(query_str) {
|
||||
(None, None) => Ok(None),
|
||||
(Some(username), Some(password)) => {
|
||||
Ok(Some((username.to_string(), password.to_string())))
|
||||
@@ -281,8 +258,33 @@ fn need_auth<B>(req: &Request<B>) -> bool {
|
||||
path.starts_with(HTTP_API_PREFIX)
|
||||
}
|
||||
|
||||
fn extract_db_from_query(query: &str) -> Option<&str> {
|
||||
for pair in query.split('&') {
|
||||
if let Some(db) = pair.strip_prefix("db=") {
|
||||
return if db.is_empty() { None } else { Some(db) };
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) {
|
||||
let mut username = None;
|
||||
let mut password = None;
|
||||
|
||||
for pair in query.split('&') {
|
||||
if pair.starts_with("u=") && pair.len() > 2 {
|
||||
username = Some(&pair[2..]);
|
||||
} else if pair.starts_with("p=") && pair.len() > 2 {
|
||||
password = Some(&pair[2..]);
|
||||
}
|
||||
}
|
||||
(username, password)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::assert_matches::assert_matches;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@@ -319,7 +321,7 @@ mod tests {
|
||||
|
||||
let wrong_credential = "dXNlcm5hbWU6cG Fzc3dvcmQ=";
|
||||
let result = decode_basic(wrong_credential);
|
||||
matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
|
||||
assert_matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -330,7 +332,7 @@ mod tests {
|
||||
|
||||
let auth_scheme_str = "basic dGVzdDp0ZXN0";
|
||||
let scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
|
||||
matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test");
|
||||
assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test");
|
||||
|
||||
let unsupported = "digest";
|
||||
let auth_scheme: Result<AuthScheme> = unsupported.try_into();
|
||||
@@ -343,18 +345,18 @@ mod tests {
|
||||
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
|
||||
let auth_scheme = auth_header(&req).unwrap();
|
||||
matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password");
|
||||
assert_matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password");
|
||||
|
||||
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
|
||||
let res = auth_header(&wrong_req);
|
||||
matches!(
|
||||
assert_matches!(
|
||||
res.err(),
|
||||
Some(error::Error::InvalidAuthorizationHeader { .. })
|
||||
);
|
||||
|
||||
let wrong_req = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
let res = auth_header(&wrong_req);
|
||||
matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
|
||||
assert_matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
|
||||
}
|
||||
|
||||
fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
|
||||
@@ -367,4 +369,50 @@ mod tests {
|
||||
|
||||
Ok(req.body(()).unwrap())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_db() {
|
||||
assert_matches!(extract_db_from_query(""), None);
|
||||
assert_matches!(extract_db_from_query("&"), None);
|
||||
assert_matches!(extract_db_from_query("db="), None);
|
||||
assert_matches!(extract_db_from_query("db=foo"), Some("foo"));
|
||||
assert_matches!(extract_db_from_query("name=bar"), None);
|
||||
assert_matches!(extract_db_from_query("db=&name=bar"), None);
|
||||
assert_matches!(extract_db_from_query("db=foo&name=bar"), Some("foo"));
|
||||
assert_matches!(extract_db_from_query("name=bar&db="), None);
|
||||
assert_matches!(extract_db_from_query("name=bar&db=foo"), Some("foo"));
|
||||
assert_matches!(extract_db_from_query("name=bar&db=&name=bar"), None);
|
||||
assert_matches!(
|
||||
extract_db_from_query("name=bar&db=foo&name=bar"),
|
||||
Some("foo")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_user() {
|
||||
assert_matches!(extract_influxdb_user_from_query(""), (None, None));
|
||||
assert_matches!(extract_influxdb_user_from_query("u="), (None, None));
|
||||
assert_matches!(
|
||||
extract_influxdb_user_from_query("u=123"),
|
||||
(Some("123"), None)
|
||||
);
|
||||
assert_matches!(
|
||||
extract_influxdb_user_from_query("u=123&p="),
|
||||
(Some("123"), None)
|
||||
);
|
||||
assert_matches!(
|
||||
extract_influxdb_user_from_query("u=123&p=4"),
|
||||
(Some("123"), Some("4"))
|
||||
);
|
||||
assert_matches!(extract_influxdb_user_from_query("p="), (None, None));
|
||||
assert_matches!(extract_influxdb_user_from_query("p=4"), (None, Some("4")));
|
||||
assert_matches!(
|
||||
extract_influxdb_user_from_query("p=4&u="),
|
||||
(None, Some("4"))
|
||||
);
|
||||
assert_matches!(
|
||||
extract_influxdb_user_from_query("p=4&u=123"),
|
||||
(Some("123"), Some("4"))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,10 +134,10 @@ async fn test_auth_by_plain_text() {
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
assert!(matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UnsupportedPasswordType { .. }
|
||||
);
|
||||
));
|
||||
|
||||
// auth failed, err: user not exist.
|
||||
let auth_result = user_provider
|
||||
@@ -147,10 +147,10 @@ async fn test_auth_by_plain_text() {
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
assert!(matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UserNotFound { .. }
|
||||
);
|
||||
));
|
||||
|
||||
// auth failed, err: wrong password
|
||||
let auth_result = user_provider
|
||||
@@ -160,10 +160,10 @@ async fn test_auth_by_plain_text() {
|
||||
)
|
||||
.await;
|
||||
assert!(auth_result.is_err());
|
||||
matches!(
|
||||
assert!(matches!(
|
||||
auth_result.err().unwrap(),
|
||||
servers::auth::Error::UserPasswordMismatch { .. }
|
||||
);
|
||||
))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user