refactor: use split instead of serde_urlencoded in http auth (#1110)

* refactor: change from urlencoded to regex

* refactor: change from urlencoded to regex

* chore: add unit test

* chore: update comment

* chore: remove local benchmark test

* chore: minor fix

* chore: remove unused dep
This commit is contained in:
shuiyisong
2023-03-07 10:51:47 +08:00
committed by GitHub
parent e8cc9b4b29
commit 1b4236d698
5 changed files with 88 additions and 42 deletions

1
Cargo.lock generated
View File

@@ -6940,7 +6940,6 @@ dependencies = [
"script",
"serde",
"serde_json",
"serde_urlencoded",
"session",
"sha1",
"snafu",

View File

@@ -52,7 +52,6 @@ rustls-pemfile = "1.0"
schemars = "0.8"
serde.workspace = true
serde_json = "1.0"
serde_urlencoded = "0.7"
session = { path = "../session" }
sha1 = "0.10"
snafu = { version = "0.7", features = ["backtraces"] }

View File

@@ -82,7 +82,7 @@ pub enum Error {
#[snafu(display("Invalid config value: {}, {}", value, msg))]
InvalidConfig { value: String, msg: String },
#[snafu(display("Illegal runtime param: {}", msg))]
#[snafu(display("Illegal param: {}", msg))]
IllegalParam { msg: String },
#[snafu(display("Internal state error: {}", msg))]

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::marker::PhantomData;
use axum::http::{self, Request, StatusCode};
@@ -106,21 +105,12 @@ async fn authorize<B: Send + Sync + 'static>(
) -> crate::auth::Result<()> {
// try get database name
let query = request.uri().query().unwrap_or_default();
let input_database = match serde_urlencoded::from_str::<HashMap<String, String>>(query) {
Ok(query_map) => query_map
.get("db")
.context(IllegalParamSnafu {
msg: "fail to get valid database from http query",
})?
.to_owned(),
Err(e) => IllegalParamSnafu {
msg: format!("fail to parse http query: {e}"),
}
.fail()?,
};
let input_database = extract_db_from_query(query).context(IllegalParamSnafu {
msg: "db not provided or corrupted",
})?;
let (catalog, database) =
crate::parse_catalog_and_schema_from_client_database_name(&input_database);
crate::parse_catalog_and_schema_from_client_database_name(input_database);
let user_info = request
.extensions()
@@ -157,20 +147,7 @@ fn get_influxdb_credentials<B: Send + Sync + 'static>(
// try v1
let Some(query_str) = request.uri().query() else { return Ok(None) };
// TODO(shuiyisong): remove this for performance optimization
// `authorize` would deserialize query from urlencoded again
let query = match serde_urlencoded::from_str::<HashMap<String, String>>(query_str) {
Ok(query_map) => query_map,
Err(e) => IllegalParamSnafu {
msg: format!("fail to parse http query: {e}"),
}
.fail()?,
};
let username = query.get("u");
let password = query.get("p");
match (username, password) {
match extract_influxdb_user_from_query(query_str) {
(None, None) => Ok(None),
(Some(username), Some(password)) => {
Ok(Some((username.to_string(), password.to_string())))
@@ -281,8 +258,33 @@ fn need_auth<B>(req: &Request<B>) -> bool {
path.starts_with(HTTP_API_PREFIX)
}
fn extract_db_from_query(query: &str) -> Option<&str> {
for pair in query.split('&') {
if let Some(db) = pair.strip_prefix("db=") {
return if db.is_empty() { None } else { Some(db) };
}
}
None
}
fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) {
let mut username = None;
let mut password = None;
for pair in query.split('&') {
if pair.starts_with("u=") && pair.len() > 2 {
username = Some(&pair[2..]);
} else if pair.starts_with("p=") && pair.len() > 2 {
password = Some(&pair[2..]);
}
}
(username, password)
}
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use super::*;
#[test]
@@ -319,7 +321,7 @@ mod tests {
let wrong_credential = "dXNlcm5hbWU6cG Fzc3dvcmQ=";
let result = decode_basic(wrong_credential);
matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
assert_matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
}
#[test]
@@ -330,7 +332,7 @@ mod tests {
let auth_scheme_str = "basic dGVzdDp0ZXN0";
let scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test");
assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test");
let unsupported = "digest";
let auth_scheme: Result<AuthScheme> = unsupported.try_into();
@@ -343,18 +345,18 @@ mod tests {
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let auth_scheme = auth_header(&req).unwrap();
matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password");
assert_matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password");
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
let res = auth_header(&wrong_req);
matches!(
assert_matches!(
res.err(),
Some(error::Error::InvalidAuthorizationHeader { .. })
);
let wrong_req = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let res = auth_header(&wrong_req);
matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
assert_matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
}
fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
@@ -367,4 +369,50 @@ mod tests {
Ok(req.body(()).unwrap())
}
#[test]
fn test_extract_db() {
assert_matches!(extract_db_from_query(""), None);
assert_matches!(extract_db_from_query("&"), None);
assert_matches!(extract_db_from_query("db="), None);
assert_matches!(extract_db_from_query("db=foo"), Some("foo"));
assert_matches!(extract_db_from_query("name=bar"), None);
assert_matches!(extract_db_from_query("db=&name=bar"), None);
assert_matches!(extract_db_from_query("db=foo&name=bar"), Some("foo"));
assert_matches!(extract_db_from_query("name=bar&db="), None);
assert_matches!(extract_db_from_query("name=bar&db=foo"), Some("foo"));
assert_matches!(extract_db_from_query("name=bar&db=&name=bar"), None);
assert_matches!(
extract_db_from_query("name=bar&db=foo&name=bar"),
Some("foo")
);
}
#[test]
fn test_extract_user() {
assert_matches!(extract_influxdb_user_from_query(""), (None, None));
assert_matches!(extract_influxdb_user_from_query("u="), (None, None));
assert_matches!(
extract_influxdb_user_from_query("u=123"),
(Some("123"), None)
);
assert_matches!(
extract_influxdb_user_from_query("u=123&p="),
(Some("123"), None)
);
assert_matches!(
extract_influxdb_user_from_query("u=123&p=4"),
(Some("123"), Some("4"))
);
assert_matches!(extract_influxdb_user_from_query("p="), (None, None));
assert_matches!(extract_influxdb_user_from_query("p=4"), (None, Some("4")));
assert_matches!(
extract_influxdb_user_from_query("p=4&u="),
(None, Some("4"))
);
assert_matches!(
extract_influxdb_user_from_query("p=4&u=123"),
(Some("123"), Some("4"))
);
}
}

View File

@@ -134,10 +134,10 @@ async fn test_auth_by_plain_text() {
)
.await;
assert!(auth_result.is_err());
matches!(
assert!(matches!(
auth_result.err().unwrap(),
servers::auth::Error::UnsupportedPasswordType { .. }
);
));
// auth failed, err: user not exist.
let auth_result = user_provider
@@ -147,10 +147,10 @@ async fn test_auth_by_plain_text() {
)
.await;
assert!(auth_result.is_err());
matches!(
assert!(matches!(
auth_result.err().unwrap(),
servers::auth::Error::UserNotFound { .. }
);
));
// auth failed, err: wrong password
let auth_result = user_provider
@@ -160,10 +160,10 @@ async fn test_auth_by_plain_text() {
)
.await;
assert!(auth_result.is_err());
matches!(
assert!(matches!(
auth_result.err().unwrap(),
servers::auth::Error::UserPasswordMismatch { .. }
);
))
}
#[tokio::test]