diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index eadb9abd43..64ef108e11 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -87,6 +87,10 @@ impl AuthError { pub fn too_many_connections() -> Self { AuthErrorImpl::TooManyConnections.into() } + + pub fn is_auth_failed(&self) -> bool { + matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_)) + } } impl> From for AuthError { diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 0c867dfd61..923bd02560 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -192,14 +192,46 @@ async fn auth_quirks( if !check_peer_addr_is_in_list(&info.inner.peer_addr, &allowed_ips) { return Err(auth::AuthError::ip_address_not_allowed()); } - let secret = api.get_role_secret(extra, &info).await?.unwrap_or_else(|| { + let cached_secret = api.get_role_secret(extra, &info).await?; + + let secret = cached_secret.clone().unwrap_or_else(|| { // If we don't have an authentication secret, we mock one to // prevent malicious probing (possible due to missing protocol steps). // This mocked secret will never lead to successful authentication. info!("authentication info not found, mocking it"); AuthSecret::Scram(scram::ServerSecret::mock(&info.inner.user, rand::random())) }); + match authenticate_with_secret( + secret, + info, + client, + unauthenticated_password, + allow_cleartext, + config, + latency_timer, + ) + .await + { + Ok(keys) => Ok(keys), + Err(e) => { + if e.is_auth_failed() { + // The password could have been changed, so we invalidate the cache. + cached_secret.invalidate(); + } + Err(e) + } + } +} +async fn authenticate_with_secret( + secret: AuthSecret, + info: ComputeUserInfo, + client: &mut stream::PqStream>, + unauthenticated_password: Option>, + allow_cleartext: bool, + config: &'static AuthenticationConfig, + latency_timer: &mut LatencyTimer, +) -> auth::Result> { if let Some(password) = unauthenticated_password { let auth_outcome = validate_password_and_exchange(&password, secret)?; let keys = match auth_outcome { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 7ef5e950b0..e4cf1e8c8e 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -254,6 +254,7 @@ pub type NodeInfoCache = TimedLru, NodeInfo>; pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>; pub type AllowedIpsCache = TimedLru>>; pub type RoleSecretCache = TimedLru<(SmolStr, SmolStr), Option>; +pub type CachedRoleSecret = timed_lru::Cached<&'static RoleSecretCache>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. @@ -264,7 +265,7 @@ pub trait Api { &self, extra: &ConsoleReqExtra, creds: &ComputeUserInfo, - ) -> Result, errors::GetAuthInfoError>; + ) -> Result; async fn get_allowed_ips( &self, diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 9c4a7447c6..dba5e5863f 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -6,6 +6,7 @@ use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; +use crate::console::provider::CachedRoleSecret; use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; use async_trait::async_trait; use futures::TryFutureExt; @@ -146,8 +147,10 @@ impl super::Api for Api { &self, _extra: &ConsoleReqExtra, creds: &ComputeUserInfo, - ) -> Result, GetAuthInfoError> { - Ok(self.do_get_auth_info(creds).await?.secret) + ) -> Result { + Ok(CachedRoleSecret::new_uncached( + self.do_get_auth_info(creds).await?.secret, + )) } async fn get_allowed_ips( diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 5bb91313c4..628d98df49 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -3,7 +3,8 @@ use super::{ super::messages::{ConsoleError, GetRoleSecret, WakeCompute}, errors::{ApiError, GetAuthInfoError, WakeComputeError}, - ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, + ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, CachedRoleSecret, ConsoleReqExtra, + NodeInfo, }; use crate::metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}; use crate::{auth::backend::ComputeUserInfo, compute, http, scram}; @@ -163,20 +164,21 @@ impl super::Api for Api { &self, extra: &ConsoleReqExtra, creds: &ComputeUserInfo, - ) -> Result, GetAuthInfoError> { + ) -> Result { let ep = creds.endpoint.clone(); let user = creds.inner.user.clone(); if let Some(role_secret) = self.caches.role_secret.get(&(ep.clone(), user.clone())) { - return Ok(role_secret.clone()); + return Ok(role_secret); } let auth_info = self.do_get_auth_info(extra, creds).await?; - self.caches + let (_, secret) = self + .caches .role_secret .insert((ep.clone(), user), auth_info.secret.clone()); self.caches .allowed_ips .insert(ep, Arc::new(auth_info.allowed_ips)); - Ok(auth_info.secret) + Ok(secret) } async fn get_allowed_ips(