diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 68c5becafb..1dd51d1c0e 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -143,6 +143,7 @@ impl BackendType<'_, ClientCredentials<'_>> { &mut self, extra: &ConsoleReqExtra<'_>, client: &mut stream::PqStream, + use_cleartext_password_flow: bool, ) -> auth::Result>> { use BackendType::*; @@ -190,7 +191,7 @@ impl BackendType<'_, ClientCredentials<'_>> { (node, payload) } - Console(endpoint, creds) if creds.use_cleartext_password_flow => { + Console(endpoint, creds) if use_cleartext_password_flow => { // This is a hack to allow cleartext password in secure connections (wss). let payload = fetch_plaintext_password(client).await?; let creds = creds.as_ref(); @@ -220,17 +221,25 @@ impl BackendType<'_, ClientCredentials<'_>> { } /// Authenticate the client via the requested backend, possibly using credentials. + /// + /// If `use_cleartext_password_flow` is true, we use the old cleartext password + /// flow. It is used for websocket connections, which want to minimize the number + /// of round trips. #[instrument(skip_all)] pub async fn authenticate( mut self, extra: &ConsoleReqExtra<'_>, client: &mut stream::PqStream, + use_cleartext_password_flow: bool, ) -> auth::Result> { use BackendType::*; // Handle cases when `project` is missing in `creds`. // TODO: type safety: return `creds` with irrefutable `project`. - if let Some(res) = self.try_password_hack(extra, client).await? { + if let Some(res) = self + .try_password_hack(extra, client, use_cleartext_password_flow) + .await? + { info!("user successfully authenticated (using the password hack)"); return Ok(res); } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 3b71bef9aa..0a3b84bb52 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -34,9 +34,6 @@ pub struct ClientCredentials<'a> { pub user: &'a str, pub dbname: &'a str, pub project: Option>, - /// If `True`, we'll use the old cleartext password flow. This is used for - /// websocket connections, which want to minimize the number of round trips. - pub use_cleartext_password_flow: bool, } impl ClientCredentials<'_> { @@ -53,7 +50,6 @@ impl<'a> ClientCredentials<'a> { user: self.user, dbname: self.dbname, project: self.project().map(Cow::Borrowed), - use_cleartext_password_flow: self.use_cleartext_password_flow, } } } @@ -63,7 +59,6 @@ impl<'a> ClientCredentials<'a> { params: &'a StartupMessageParams, sni: Option<&str>, common_name: Option<&str>, - use_cleartext_password_flow: bool, ) -> Result { use ClientCredsParseError::*; @@ -113,7 +108,6 @@ impl<'a> ClientCredentials<'a> { user = user, dbname = dbname, project = project.as_deref(), - use_cleartext_password_flow = use_cleartext_password_flow, "credentials" ); @@ -121,7 +115,6 @@ impl<'a> ClientCredentials<'a> { user, dbname, project, - use_cleartext_password_flow, }) } } @@ -148,7 +141,7 @@ mod tests { let options = StartupMessageParams::new([("user", "john_doe")]); // TODO: check that `creds.dbname` is None. - let creds = ClientCredentials::parse(&options, None, None, false)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); Ok(()) @@ -158,7 +151,7 @@ mod tests { fn parse_missing_project() -> anyhow::Result<()> { let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]); - let creds = ClientCredentials::parse(&options, None, None, false)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project, None); @@ -173,7 +166,7 @@ mod tests { let sni = Some("foo.localhost"); let common_name = Some("localhost"); - let creds = ClientCredentials::parse(&options, sni, common_name, false)?; + let creds = ClientCredentials::parse(&options, sni, common_name)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("foo")); @@ -189,7 +182,7 @@ mod tests { ("options", "-ckey=1 project=bar -c geqo=off"), ]); - let creds = ClientCredentials::parse(&options, None, None, false)?; + let creds = ClientCredentials::parse(&options, None, None)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("bar")); @@ -208,7 +201,7 @@ mod tests { let sni = Some("baz.localhost"); let common_name = Some("localhost"); - let creds = ClientCredentials::parse(&options, sni, common_name, false)?; + let creds = ClientCredentials::parse(&options, sni, common_name)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.dbname, "world"); assert_eq!(creds.project.as_deref(), Some("baz")); @@ -227,8 +220,7 @@ mod tests { let sni = Some("second.localhost"); let common_name = Some("localhost"); - let err = - ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail"); + let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); match err { InconsistentProjectNames { domain, option } => { assert_eq!(option, "first"); @@ -245,8 +237,7 @@ mod tests { let sni = Some("project.localhost"); let common_name = Some("example.com"); - let err = - ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail"); + let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail"); match err { InconsistentSni { sni, cn } => { assert_eq!(sni, "project.localhost"); diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index b219cd0fa2..304214fe93 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -25,12 +25,11 @@ impl CancelMap { cancel_closure.try_cancel_query().await } - /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result - where - F: FnOnce(Session<'a>) -> R, - R: std::future::Future>, - { + /// Create a new session, with a new client-facing random cancellation key. + /// + /// Use `enable_query_cancellation` to register a database cancellation + /// key with it, and to get the client-facing key. + pub fn new_session<'a>(&'a self) -> anyhow::Result> { // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the @@ -44,17 +43,9 @@ impl CancelMap { .lock() .try_insert(key, None) .map_err(|_| anyhow!("query cancellation key already exists: {key}"))?; - - // This will guarantee that the session gets dropped - // as soon as the future is finished. - scopeguard::defer! { - self.0.lock().remove(&key); - info!("dropped query cancellation key {key}"); - } - info!("registered new query cancellation key {key}"); - let session = Session::new(key, self); - f(session).await + + Ok(Session::new(key, self)) } #[cfg(test)] @@ -111,7 +102,7 @@ impl<'a> Session<'a> { impl Session<'_> { /// Store the cancel token for the given session. /// This enables query cancellation in [`crate::proxy::handshake`]. - pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData { + pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); self.cancel_map .0 @@ -122,6 +113,14 @@ impl Session<'_> { } } +impl<'a> Drop for Session<'a> { + fn drop(&mut self) { + let key = &self.key; + self.cancel_map.0.lock().remove(key); + info!("dropped query cancellation key {key}"); + } +} + #[cfg(test)] mod tests { use super::*; @@ -132,14 +131,14 @@ mod tests { static CANCEL_MAP: Lazy = Lazy::new(Default::default); let (tx, rx) = tokio::sync::oneshot::channel(); - let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move { + + let session = CANCEL_MAP.new_session()?; + let task = tokio::spawn(async move { assert!(CANCEL_MAP.contains(&session)); tx.send(()).expect("failed to send"); futures::future::pending::<()>().await; // sleep forever - - Ok(()) - })); + }); // Wait until the task has been spawned. rx.await.context("failed to hear from the task")?; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index e9e21b157c..2d69ee7435 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -109,16 +109,20 @@ pub async fn handle_ws_client( let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name, true)) + .map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new(stream, creds, ¶ms, session_id); - cancel_map - .with_session(|session| client.handle_connection(session)) - .await + let client = Client::new( + stream, + creds, + ¶ms, + session_id, + cancel_map.new_session()?, + ); + client.handle_connection(true).await } /// Handle an incoming client connection, handshake and authentication. @@ -150,16 +154,20 @@ async fn handle_client( let result = config .auth_backend .as_ref() - .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name, false)) + .map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name)) .transpose(); async { result }.or_else(|e| stream.throw_error(e)).await? }; - let client = Client::new(stream, creds, ¶ms, session_id); - cancel_map - .with_session(|session| client.handle_connection(session)) - .await + let client = Client::new( + stream, + creds, + ¶ms, + session_id, + cancel_map.new_session()?, + ); + client.handle_connection(false).await } /// Establish a (most probably, secure) connection with the client. @@ -238,6 +246,8 @@ struct Client<'a, S> { params: &'a StartupMessageParams, /// Unique connection ID. session_id: uuid::Uuid, + + session: cancellation::Session<'a>, } impl<'a, S> Client<'a, S> { @@ -247,19 +257,21 @@ impl<'a, S> Client<'a, S> { creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, params: &'a StartupMessageParams, session_id: uuid::Uuid, + session: cancellation::Session<'a>, ) -> Self { Self { stream, creds, params, session_id, + session, } } } impl Client<'_, S> { - async fn handle_connection(self, session: cancellation::Session<'_>) -> anyhow::Result<()> { - let (mut client, mut db) = self.connect_to_db(session).await?; + async fn handle_connection(self, use_cleartext_password_flow: bool) -> anyhow::Result<()> { + let (mut client, mut db) = self.connect_to_db(use_cleartext_password_flow).await?; // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); @@ -271,13 +283,14 @@ impl Client<'_, S> { #[instrument(skip_all)] async fn connect_to_db( self, - session: cancellation::Session<'_>, + use_cleartext_password_flow: bool, ) -> anyhow::Result<(MeasuredStream, MeasuredStream)> { let Self { mut stream, creds, params, session_id, + session, } = self; let extra = auth::ConsoleReqExtra { @@ -287,7 +300,9 @@ impl Client<'_, S> { let auth_result = async { // `&mut stream` doesn't let us merge those 2 lines. - let res = creds.authenticate(&extra, &mut stream).await; + let res = creds + .authenticate(&extra, &mut stream, use_cleartext_password_flow) + .await; async { res }.or_else(|e| stream.throw_error(e)).await } .await?;