diff --git a/Cargo.lock b/Cargo.lock index bb751e53b1..6439d25f14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6940,7 +6940,6 @@ dependencies = [ "script", "serde", "serde_json", - "serde_urlencoded", "session", "sha1", "snafu", diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index d3e2634fbd..48bd54d466 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -52,7 +52,6 @@ rustls-pemfile = "1.0" schemars = "0.8" serde.workspace = true serde_json = "1.0" -serde_urlencoded = "0.7" session = { path = "../session" } sha1 = "0.10" snafu = { version = "0.7", features = ["backtraces"] } diff --git a/src/servers/src/auth.rs b/src/servers/src/auth.rs index 35b5ac533b..adff1a74d8 100644 --- a/src/servers/src/auth.rs +++ b/src/servers/src/auth.rs @@ -82,7 +82,7 @@ pub enum Error { #[snafu(display("Invalid config value: {}, {}", value, msg))] InvalidConfig { value: String, msg: String }, - #[snafu(display("Illegal runtime param: {}", msg))] + #[snafu(display("Illegal param: {}", msg))] IllegalParam { msg: String }, #[snafu(display("Internal state error: {}", msg))] diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index d438fe4e1a..4259912d6f 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; use std::marker::PhantomData; use axum::http::{self, Request, StatusCode}; @@ -106,21 +105,12 @@ async fn authorize( ) -> crate::auth::Result<()> { // try get database name let query = request.uri().query().unwrap_or_default(); - let input_database = match serde_urlencoded::from_str::>(query) { - Ok(query_map) => query_map - .get("db") - .context(IllegalParamSnafu { - msg: "fail to get valid database from http query", - })? - .to_owned(), - Err(e) => IllegalParamSnafu { - msg: format!("fail to parse http query: {e}"), - } - .fail()?, - }; + let input_database = extract_db_from_query(query).context(IllegalParamSnafu { + msg: "db not provided or corrupted", + })?; let (catalog, database) = - crate::parse_catalog_and_schema_from_client_database_name(&input_database); + crate::parse_catalog_and_schema_from_client_database_name(input_database); let user_info = request .extensions() @@ -157,20 +147,7 @@ fn get_influxdb_credentials( // try v1 let Some(query_str) = request.uri().query() else { return Ok(None) }; - // TODO(shuiyisong): remove this for performance optimization - // `authorize` would deserialize query from urlencoded again - let query = match serde_urlencoded::from_str::>(query_str) { - Ok(query_map) => query_map, - Err(e) => IllegalParamSnafu { - msg: format!("fail to parse http query: {e}"), - } - .fail()?, - }; - - let username = query.get("u"); - let password = query.get("p"); - - match (username, password) { + match extract_influxdb_user_from_query(query_str) { (None, None) => Ok(None), (Some(username), Some(password)) => { Ok(Some((username.to_string(), password.to_string()))) @@ -281,8 +258,33 @@ fn need_auth(req: &Request) -> bool { path.starts_with(HTTP_API_PREFIX) } +fn extract_db_from_query(query: &str) -> Option<&str> { + for pair in query.split('&') { + if let Some(db) = pair.strip_prefix("db=") { + return if db.is_empty() { None } else { Some(db) }; + } + } + None +} + +fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) { + let mut username = None; + let mut password = None; + + for pair in query.split('&') { + if pair.starts_with("u=") && pair.len() > 2 { + username = Some(&pair[2..]); + } else if pair.starts_with("p=") && pair.len() > 2 { + password = Some(&pair[2..]); + } + } + (username, password) +} + #[cfg(test)] mod tests { + use std::assert_matches::assert_matches; + use super::*; #[test] @@ -319,7 +321,7 @@ mod tests { let wrong_credential = "dXNlcm5hbWU6cG Fzc3dvcmQ="; let result = decode_basic(wrong_credential); - matches!(result.err(), Some(error::Error::InvalidBase64Value { .. })); + assert_matches!(result.err(), Some(error::Error::InvalidBase64Value { .. })); } #[test] @@ -330,7 +332,7 @@ mod tests { let auth_scheme_str = "basic dGVzdDp0ZXN0"; let scheme: AuthScheme = auth_scheme_str.try_into().unwrap(); - matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test"); + assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test"); let unsupported = "digest"; let auth_scheme: Result = unsupported.try_into(); @@ -343,18 +345,18 @@ mod tests { let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); let auth_scheme = auth_header(&req).unwrap(); - matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password"); + assert_matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password"); let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap(); let res = auth_header(&wrong_req); - matches!( + assert_matches!( res.err(), Some(error::Error::InvalidAuthorizationHeader { .. }) ); let wrong_req = mock_http_request(Some("Digest dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); let res = auth_header(&wrong_req); - matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. })); + assert_matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. })); } fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result> { @@ -367,4 +369,50 @@ mod tests { Ok(req.body(()).unwrap()) } + + #[test] + fn test_extract_db() { + assert_matches!(extract_db_from_query(""), None); + assert_matches!(extract_db_from_query("&"), None); + assert_matches!(extract_db_from_query("db="), None); + assert_matches!(extract_db_from_query("db=foo"), Some("foo")); + assert_matches!(extract_db_from_query("name=bar"), None); + assert_matches!(extract_db_from_query("db=&name=bar"), None); + assert_matches!(extract_db_from_query("db=foo&name=bar"), Some("foo")); + assert_matches!(extract_db_from_query("name=bar&db="), None); + assert_matches!(extract_db_from_query("name=bar&db=foo"), Some("foo")); + assert_matches!(extract_db_from_query("name=bar&db=&name=bar"), None); + assert_matches!( + extract_db_from_query("name=bar&db=foo&name=bar"), + Some("foo") + ); + } + + #[test] + fn test_extract_user() { + assert_matches!(extract_influxdb_user_from_query(""), (None, None)); + assert_matches!(extract_influxdb_user_from_query("u="), (None, None)); + assert_matches!( + extract_influxdb_user_from_query("u=123"), + (Some("123"), None) + ); + assert_matches!( + extract_influxdb_user_from_query("u=123&p="), + (Some("123"), None) + ); + assert_matches!( + extract_influxdb_user_from_query("u=123&p=4"), + (Some("123"), Some("4")) + ); + assert_matches!(extract_influxdb_user_from_query("p="), (None, None)); + assert_matches!(extract_influxdb_user_from_query("p=4"), (None, Some("4"))); + assert_matches!( + extract_influxdb_user_from_query("p=4&u="), + (None, Some("4")) + ); + assert_matches!( + extract_influxdb_user_from_query("p=4&u=123"), + (Some("123"), Some("4")) + ); + } } diff --git a/src/servers/tests/auth.rs b/src/servers/tests/auth.rs index 8e9c9b8546..0c38124189 100644 --- a/src/servers/tests/auth.rs +++ b/src/servers/tests/auth.rs @@ -134,10 +134,10 @@ async fn test_auth_by_plain_text() { ) .await; assert!(auth_result.is_err()); - matches!( + assert!(matches!( auth_result.err().unwrap(), servers::auth::Error::UnsupportedPasswordType { .. } - ); + )); // auth failed, err: user not exist. let auth_result = user_provider @@ -147,10 +147,10 @@ async fn test_auth_by_plain_text() { ) .await; assert!(auth_result.is_err()); - matches!( + assert!(matches!( auth_result.err().unwrap(), servers::auth::Error::UserNotFound { .. } - ); + )); // auth failed, err: wrong password let auth_result = user_provider @@ -160,10 +160,10 @@ async fn test_auth_by_plain_text() { ) .await; assert!(auth_result.is_err()); - matches!( + assert!(matches!( auth_result.err().unwrap(), servers::auth::Error::UserPasswordMismatch { .. } - ); + )) } #[tokio::test]