From 4d68e3108fbf1822044e8c5d370c1cbb3a1c0e8f Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Tue, 24 Jan 2023 21:35:56 +0200 Subject: [PATCH] Refactor use_cleartext_password_flow. It's not a property of the credentials that we receive from the client, so remove it from ClientCredentials. Instead, pass it as an argument directly to 'authenticate' function, where it's actually used. All the rest of the changes is just plumbing to pass it through the call stack to 'authenticate' --- proxy/src/auth/backend.rs | 13 +++++++++-- proxy/src/auth/credentials.rs | 23 ++++++------------- proxy/src/cancellation.rs | 41 ++++++++++++++++----------------- proxy/src/proxy.rs | 43 +++++++++++++++++++++++------------ 4 files changed, 67 insertions(+), 53 deletions(-) diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 68c5becafb..1dd51d1c0e 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -143,6 +143,7 @@ impl BackendType<'_, ClientCredentials<'_>> { &mut self, extra: &ConsoleReqExtra<'_>, client: &mut stream::PqStream, + use_cleartext_password_flow: bool, ) -> auth::Result>> { use BackendType::*; @@ -190,7 +191,7 @@ impl BackendType<'_, ClientCredentials<'_>> { (node, payload) } - Console(endpoint, creds) if creds.use_cleartext_password_flow => { + Console(endpoint, creds) if use_cleartext_password_flow => { // This is a hack to allow cleartext password in secure connections (wss). let payload = fetch_plaintext_password(client).await?; let creds = creds.as_ref(); @@ -220,17 +221,25 @@ impl BackendType<'_, ClientCredentials<'_>> { } /// Authenticate the client via the requested backend, possibly using credentials. + /// + /// If `use_cleartext_password_flow` is true, we use the old cleartext password + /// flow. It is used for websocket connections, which want to minimize the number + /// of round trips. #[instrument(skip_all)] pub async fn authenticate( mut self, extra: &ConsoleReqExtra<'_>, client: &mut stream::PqStream, + use_cleartext_password_flow: bool, ) -> auth::Result> { use BackendType::*; // Handle cases when `project` is missing in `creds`. // TODO: type safety: return `creds` with irrefutable `project`. - if let Some(res) = self.try_password_hack(extra, client).await? { + if let Some(res) = self + .try_password_hack(extra, client, use_cleartext_password_flow) + .await? + { info!("user successfully authenticated (using the password hack)"); return Ok(res); } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 3b71bef9aa..0a3b84bb52 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -34,9 +34,6 @@ pub struct ClientCredentials<'a> { pub user: &'a str, pub dbname: &'a str, pub project: Option>, - /// If `True`, we'll use the old cleartext password flow. This is used for - /// websocket connections, which want to minimize the number of round trips. - pub use_cleartext_password_flow: bool, } impl ClientCredentials<'_> { @@ -53,7 +50,6 @@ impl<'a> ClientCredentials<'a> { user: self.user, dbname: self.dbname, project: self.project().map(Cow::Borrowed), - use_cleartext_password_flow: self.use_cleartext_password_flow, } } } @@ -63,7 +59,6 @@ impl<'a> ClientCredentials<'a> { params: &'a StartupMessageParams, sni: Option<&str>, common_name: Option<&str>, - use_cleartext_password_flow: bool, ) -> Result { use ClientCredsParseError::*; @@ -113,7 +108,6 @@ impl<'a> ClientCredentials<'a> { user = user, dbname = dbname, project = project.as_deref(), - use_cleartext_password_flow = use_cleartext_password_flow, "credentials" ); @@ -121,7 +115,6 @@ impl<'a> ClientCredentials<'a> { user, dbname, project, - use_cleartext_password_flow, }) } } @@ -148,7 +141,7 @@ mod tests { let options = StartupMessageParams::new([("user", "john_doe")]); // TODO: check that `creds.dbname` is None. - let creds = ClientCredentials::parse(&options, None, None, false)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); Ok(()) @@ -158,7 +151,7 @@ mod tests { fn parse_missing_project() -> anyhow::Result<()> { let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]); - let creds = ClientCredentials::parse(&options, None, None, false)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project, None); @@ -173,7 +166,7 @@ mod tests { let sni = Some("foo.localhost"); let common_name = Some("localhost"); - let creds = ClientCredentials::parse(&options, sni, common_name, false)?; + let creds = ClientCredentials::parse(&options, sni, common_name)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("foo")); @@ -189,7 +182,7 @@ mod tests { ("options", "-ckey=1 project=bar -c geqo=off"), ]); - let creds = ClientCredentials::parse(&options, None, None, false)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -208,7 +201,7 @@ mod tests { let sni = Some("baz.localhost"); let common_name = Some("localhost"); - let creds = ClientCredentials::parse(&options, sni, common_name, false)?; + let creds = ClientCredentials::parse(&options, sni, common_name)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("baz")); @@ -227,8 +220,7 @@ mod tests { let sni = Some("second.localhost"); let common_name = Some("localhost"); - let err = - ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail"); + let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { assert_eq!(option, "first"); @@ -245,8 +237,7 @@ mod tests { let sni = Some("project.localhost"); let common_name = Some("example.com"); - let err = - ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail"); + let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); match err { InconsistentSni { sni, cn } => { assert_eq!(sni, "project.localhost"); diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index b219cd0fa2..304214fe93 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -25,12 +25,11 @@ impl CancelMap { cancel_closure.try_cancel_query().await } - /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result - where - F: FnOnce(Session<'a>) -> R, - R: std::future::Future>, - { + /// Create a new session, with a new client-facing random cancellation key. + /// + /// Use `enable_query_cancellation` to register a database cancellation + /// key with it, and to get the client-facing key. + pub fn new_session<'a>(&'a self) -> anyhow::Result> { // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the @@ -44,17 +43,9 @@ impl CancelMap { .lock() .try_insert(key, None) .map_err(|_| anyhow!("query cancellation key already exists: {key}"))?; - - // This will guarantee that the session gets dropped - // as soon as the future is finished. - scopeguard::defer! { - self.0.lock().remove(&key); - info!("dropped query cancellation key {key}"); - } - info!("registered new query cancellation key {key}"); - let session = Session::new(key, self); - f(session).await + + Ok(Session::new(key, self)) } #[cfg(test)] @@ -111,7 +102,7 @@ impl<'a> Session<'a> { impl Session<'_> { /// Store the cancel token for the given session. /// This enables query cancellation in [`crate::proxy::handshake`]. - pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData { + pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); self.cancel_map .0 @@ -122,6 +113,14 @@ impl Session<'_> { } } +impl<'a> Drop for Session<'a> { + fn drop(&mut self) { + let key = &self.key; + self.cancel_map.0.lock().remove(key); + info!("dropped query cancellation key {key}"); + } +} + #[cfg(test)] mod tests { use super::*; @@ -132,14 +131,14 @@ mod tests { static CANCEL_MAP: Lazy = Lazy::new(Default::default); let (tx, rx) = tokio::sync::oneshot::channel(); - let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move { + + let session = CANCEL_MAP.new_session()?; + let task = tokio::spawn(async move { assert!(CANCEL_MAP.contains(&session)); tx.send(()).expect("failed to send"); futures::future::pending::<()>().await; // sleep forever - - Ok(()) - })); + }); // Wait until the task has been spawned. rx.await.context("failed to hear from the task")?; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index e9e21b157c..2d69ee7435 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -109,16 +109,20 @@ pub async fn handle_ws_client( let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name, true)) + .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new(stream, creds, ¶ms, session_id); - cancel_map - .with_session(|session| client.handle_connection(session)) - .await + let client = Client::new( + stream, + creds, + ¶ms, + session_id, + cancel_map.new_session()?, + ); + client.handle_connection(true).await } /// Handle an incoming client connection, handshake and authentication. @@ -150,16 +154,20 @@ async fn handle_client( let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name, false)) + .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new(stream, creds, ¶ms, session_id); - cancel_map - .with_session(|session| client.handle_connection(session)) - .await + let client = Client::new( + stream, + creds, + ¶ms, + session_id, + cancel_map.new_session()?, + ); + client.handle_connection(false).await } /// Establish a (most probably, secure) connection with the client. @@ -238,6 +246,8 @@ struct Client<'a, S> { params: &'a StartupMessageParams, /// Unique connection ID. session_id: uuid::Uuid, + + session: cancellation::Session<'a>, } impl<'a, S> Client<'a, S> { @@ -247,19 +257,21 @@ impl<'a, S> Client<'a, S> { creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, params: &'a StartupMessageParams, session_id: uuid::Uuid, + session: cancellation::Session<'a>, ) -> Self { Self { stream, creds, params, session_id, + session, } } } impl Client<'_, S> { - async fn handle_connection(self, session: cancellation::Session<'_>) -> anyhow::Result<()> { - let (mut client, mut db) = self.connect_to_db(session).await?; + async fn handle_connection(self, use_cleartext_password_flow: bool) -> anyhow::Result<()> { + let (mut client, mut db) = self.connect_to_db(use_cleartext_password_flow).await?; // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); @@ -271,13 +283,14 @@ impl Client<'_, S> { #[instrument(skip_all)] async fn connect_to_db( self, - session: cancellation::Session<'_>, + use_cleartext_password_flow: bool, ) -> anyhow::Result<(MeasuredStream, MeasuredStream)> { let Self { mut stream, creds, params, session_id, + session, } = self; let extra = auth::ConsoleReqExtra { @@ -287,7 +300,9 @@ impl Client<'_, S> { let auth_result = async { // `&mut stream` doesn't let us merge those 2 lines. - let res = creds.authenticate(&extra, &mut stream).await; + let res = creds + .authenticate(&extra, &mut stream, use_cleartext_password_flow) + .await; async { res }.or_else(|e| stream.throw_error(e)).await } .await?;