diff --git a/Cargo.lock b/Cargo.lock index d75257c314..b77aaf75a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3416,6 +3416,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.5.0" @@ -4355,16 +4361,18 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.5.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dacbf864d6cb6a0e676c9a1162ab7b315b5c8e6c87fa9b6e0ba9ba0a569adb1" +checksum = "d90fd7db2eab0a1b9cdde0ef2393f99b83c6198b1c2e62595e8d269d59b8ffca" dependencies = [ "async-trait", "bytes", "derive-new", "futures", "getset", + "hex", "log", + "md5", "postgres-types", "rand 0.8.5", "thiserror", diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 2e2c133416..c1d3cab2b5 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -31,7 +31,7 @@ num_cpus = "1.13" once_cell = "1.16" openmetrics-parser = "0.4" opensrv-mysql = "0.3" -pgwire = "0.5" +pgwire = "0.6.1" prost = "0.11" rand = "0.8" regex = "1.6" diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 611fc7d94a..3b5149f063 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -33,10 +33,33 @@ struct PgPwdVerifier { user_provider: Option, } +#[allow(dead_code)] +struct LoginInfo { + user: Option, + database: Option, + host: String, +} + +impl LoginInfo { + pub fn from_client_info(client: &C) -> LoginInfo + where + C: ClientInfo, + { + LoginInfo { + user: client.metadata().get(super::METADATA_USER).map(Into::into), + database: client + .metadata() + .get(super::METADATA_DATABASE) + .map(Into::into), + host: client.socket_addr().ip().to_string(), + } + } +} + impl PgPwdVerifier { - async fn verify_pwd(&self, pwd: &str, meta: HashMap) -> Result { + async fn verify_pwd(&self, pwd: &str, login: LoginInfo) -> Result { if let Some(user_provider) = &self.user_provider { - let user_name = match meta.get("user") { + let user_name = match login.user { Some(name) => name, None => return Ok(false), }; @@ -44,7 +67,7 @@ impl PgPwdVerifier { // TODO(fys): pass user_info to context let _user_info = user_provider .auth( - Identity::UserId(user_name, None), + Identity::UserId(&user_name, None), Password::PlainText(pwd.as_bytes()), ) .await @@ -140,8 +163,8 @@ impl StartupHandler for PgAuthStartupHandler { } } PgWireFrontendMessage::Password(ref pwd) => { - let meta = client.metadata().clone(); - if let Ok(true) = self.verifier.verify_pwd(pwd.password(), meta).await { + let login_info = LoginInfo::from_client_info(client); + if let Ok(true) = self.verifier.verify_pwd(pwd.password(), login_info).await { auth::finish_authentication(client, &self.param_provider).await } else { let error_info = ErrorInfo::new( diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 6cf82465a0..36dbd80d33 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -42,14 +42,12 @@ impl PostgresServerHandler { } } -const CLIENT_METADATA_DATABASE: &str = "database"; - fn query_context_from_client_info(client: &C) -> Arc where C: ClientInfo, { let query_context = QueryContext::new(); - if let Some(current_schema) = client.metadata().get(CLIENT_METADATA_DATABASE) { + if let Some(current_schema) = client.metadata().get(super::METADATA_DATABASE) { query_context.set_current_schema(current_schema); } diff --git a/src/servers/src/postgres/mod.rs b/src/servers/src/postgres/mod.rs index 46a02e7fc1..5b325ec374 100644 --- a/src/servers/src/postgres/mod.rs +++ b/src/servers/src/postgres/mod.rs @@ -16,4 +16,7 @@ mod auth_handler; mod handler; mod server; +pub(crate) const METADATA_USER: &str = "user"; +pub(crate) const METADATA_DATABASE: &str = "database"; + pub use server::PostgresServer;