From e7a1d5de94c992e8811e62e84070d49c203b9d06 Mon Sep 17 00:00:00 2001
From: Conrad Ludgate
Date: Tue, 29 Jul 2025 07:48:14 +0100
Subject: [PATCH] proxy: cache for password hashing (#12011)

## Problem

Password hashing for sql-over-http takes up a lot of CPU. Perhaps we can get away with temporarily caching some of the intermediate work so we need fewer rounds per request, which will save some CPU time.

## Summary of changes

The output of pbkdf2 is the XOR of the outputs of each iteration round, e.g. `U1 ^ U2 ^ ... ^ U16 ^ U17 ^ ... ^ Un`. We cache the suffix of that expression, `U17 ^ ... ^ Un`. To compute the result from the cached suffix, we only need to compute the prefix `U1 ^ U2 ^ ... ^ U16` (a standalone sketch of this identity follows the diff). The suffix by itself is useless, which prevents its use in brute-force attacks should this cached memory leak. We are also caching the full 4096-round hash in memory, which can be used for brute-force attacks, and this suffix could be used to speed such an attack up. My hope/expectation is that since these will live in different allocations, any such memory exploitation becomes much, much harder.

Since the full hash cache might be invalidated while the suffix is still cached, I'm storing the timestamp of the computation as a way to identify a matching pair. I also added `zeroize()` to clear the sensitive state from the stack/heap. For the most security-conscious customers, we hope to roll out OIDC soon, so they can disable passwords entirely.

---

The numbers for the thread pool's pbkdf2 cache were chosen fairly arbitrarily, but according to our busiest region for sql-over-http, we only see about 150 unique endpoints every minute, so storing ~100 of the most common endpoints for that minute should cover the vast majority of requests. 1 minute was chosen so we don't keep data in memory for too long.

--- Cargo.lock | 1 + Cargo.toml | 3 +- proxy/Cargo.toml | 1 + proxy/src/auth/backend/hacks.rs | 6 +- proxy/src/auth/backend/mod.rs | 8 ++- proxy/src/auth/flow.rs | 8 ++- proxy/src/binary/local_proxy.rs | 6 +- proxy/src/binary/pg_sni_router.rs | 4 +- proxy/src/binary/proxy.rs | 9 ++- proxy/src/config.rs | 4 +- proxy/src/metrics.rs | 65 ++++++++++------- proxy/src/scram/cache.rs | 84 ++++++++++++++++++++++ proxy/src/scram/exchange.rs | 113 ++++++++++++++++++++++++------ proxy/src/scram/key.rs | 28 +++++++- proxy/src/scram/mod.rs | 71 ++++++++++++------- proxy/src/scram/pbkdf2.rs | 81 +++++++++++++++------ proxy/src/scram/secret.rs | 6 ++ proxy/src/scram/signature.rs | 23 +++--- proxy/src/scram/threadpool.rs | 17 +++-- proxy/src/serverless/backend.rs | 6 +- 20 files changed, 414 insertions(+), 130 deletions(-) create mode 100644 proxy/src/scram/cache.rs diff --git a/Cargo.lock b/Cargo.lock index b43f4fdea0..2c15a47c96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5519,6 +5519,7 @@ dependencies = [ "workspace_hack", "x509-cert", "zerocopy 0.8.24", + "zeroize", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index d2dab67220..8051b3ee3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -234,9 +234,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] } walkdir = "2.3.2" rustls-native-certs = "0.8" whoami = "1.5.1" -zerocopy = { version = "0.8", features = ["derive", "simd"] } json-structural-diff = { version = "0.2.0" } x509-cert = { version = "0.2.5" } +zerocopy = { version = "0.8", features = ["derive", "simd"] } +zeroize = "1.8" ## TODO replace this with tracing env_logger = "0.11" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index eb8c0ed037..0ece79c329 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -107,6 +107,7 @@ uuid.workspace = true x509-cert.workspace
= true redis.workspace = true zerocopy.workspace = true +zeroize.workspace = true # uncomment this to use the real subzero-core crate # subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true } # this is a stub for the subzero-core crate diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 1e5c076fb9..491f14b1b6 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::AuthSecret; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl; use crate::stream::{self, Stream}; @@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext( ctx.set_auth_method(crate::context::AuthMethod::Cleartext); let ep = EndpointIdInt::from(&info.endpoint); + let role = RoleNameInt::from(&info.user); let auth_flow = AuthFlow::new( client, auth::CleartextPassword { secret, endpoint: ep, - pool: config.thread_pool.clone(), + role, + pool: config.scram_thread_pool.clone(), }, ); let auth_outcome = { diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 30229cec23..a6df2a7011 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -25,7 +25,7 @@ use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl, }; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::wake_compute::WakeComputeBackend; @@ -273,9 +273,11 @@ async fn authenticate_with_secret( ) -> auth::Result { if let Some(password) = unauthenticated_password { let ep = EndpointIdInt::from(&info.endpoint); + let role = RoleNameInt::from(&info.user); let auth_outcome = - validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; + validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret) + .await?; let keys = match auth_outcome { crate::sasl::Outcome::Success(key) => key, crate::sasl::Outcome::Failure(reason) => { @@ -499,7 +501,7 @@ mod tests { static CONFIG: Lazy = Lazy::new(|| AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool: ThreadPool::new(1), + scram_thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index c825d5bf4b..00cd274e99 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys; use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; @@ -46,6 +46,7 @@ pub(crate) struct PasswordHack; pub(crate) struct CleartextPassword { pub(crate) pool: Arc, pub(crate) endpoint: EndpointIdInt, + pub(crate) role: RoleNameInt, pub(crate) secret: AuthSecret, } @@ -111,6 +112,7 @@ impl AuthFlow<'_, S, CleartextPassword> { let outcome = 
validate_password_and_exchange( &self.state.pool, self.state.endpoint, + self.state.role, password, self.state.secret, ) @@ -165,13 +167,15 @@ impl AuthFlow<'_, S, Scram<'_>> { pub(crate) async fn validate_password_and_exchange( pool: &ThreadPool, endpoint: EndpointIdInt, + role: RoleNameInt, password: &[u8], secret: AuthSecret, ) -> super::Result> { match secret { // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { - let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; + let outcome = + crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?; let client_key = match outcome { sasl::Outcome::Success(client_key) => client_key, diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 3a17e6da83..86b64c62c9 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -29,7 +29,7 @@ use crate::config::{ }; use crate::control_plane::locks::ApiLocks; use crate::http::health_server::AppMetrics; -use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics}; +use crate::metrics::{Metrics, ServiceInfo}; use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; @@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); - // TODO: refactor these to use labels debug!("Version: {GIT_VERSION}"); debug!("Build_tag: {BUILD_TAG}"); @@ -284,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig http_config, authentication_config: AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool: ThreadPool::new(0), + scram_thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index b22b413b74..cdbf0f09ac 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -26,7 +26,7 @@ use utils::project_git_version; use utils::sentry_init::init_sentry; use crate::context::RequestContext; -use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics}; +use crate::metrics::{Metrics, ServiceInfo}; use crate::pglb::TlsRequired; use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; @@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); - let args = cli().get_matches(); let destination: String = args .get_one::("dest") diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 0ea5a89945..29b0ad53f2 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> { /// ProxyConfig is created at proxy startup, and lives forever. 
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let thread_pool = ThreadPool::new(args.scram_thread_pool_size); - Metrics::install(thread_pool.metrics.clone()); + Metrics::get() + .proxy + .scram_pool + .0 + .set(thread_pool.metrics.clone()) + .ok(); let tls_config = match (&args.tls_key, &args.tls_cert) { (Some(key_path), Some(cert_path)) => Some(config::configure_tls( @@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { }; let authentication_config = AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool, + scram_thread_pool: thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 04080efcca..22902dbcab 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings}; use crate::ext::TaskExt; use crate::intern::RoleNameInt; use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig}; -use crate::scram::threadpool::ThreadPool; +use crate::scram; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; #[cfg(feature = "rest_broker")] @@ -75,7 +75,7 @@ pub struct HttpConfig { } pub struct AuthenticationConfig { - pub thread_pool: Arc, + pub scram_thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index d5905efc5a..905c9b5279 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -5,6 +5,7 @@ use measured::label::{ FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue, StaticLabelSet, }; +use measured::metric::group::Encoding; use measured::metric::histogram::Thresholds; use measured::metric::name::MetricName; use measured::{ @@ -18,10 +19,10 @@ use crate::control_plane::messages::ColdStartInfo; use crate::error::ErrorKind; #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new())] pub struct Metrics { #[metric(namespace = "proxy")] - #[metric(init = ProxyMetrics::new(thread_pool))] + #[metric(init = ProxyMetrics::new())] pub proxy: ProxyMetrics, #[metric(namespace = "wake_compute_lock")] @@ -34,34 +35,27 @@ pub struct Metrics { pub cache: CacheMetrics, } -static SELF: OnceLock = OnceLock::new(); impl Metrics { - pub fn install(thread_pool: Arc) { - let mut metrics = Metrics::new(thread_pool); - - metrics.proxy.errors_total.init_all_dense(); - metrics.proxy.redis_errors_total.init_all_dense(); - metrics.proxy.redis_events_count.init_all_dense(); - metrics.proxy.retries_metric.init_all_dense(); - metrics.proxy.connection_failures_total.init_all_dense(); - - SELF.set(metrics) - .ok() - .expect("proxy metrics must not be installed more than once"); - } - + #[track_caller] pub fn get() -> &'static Self { - #[cfg(test)] - return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)))); + static SELF: OnceLock = OnceLock::new(); - #[cfg(not(test))] - SELF.get() - .expect("proxy metrics must be installed by the main() function") + SELF.get_or_init(|| { + let mut metrics = Metrics::new(); + + metrics.proxy.errors_total.init_all_dense(); + metrics.proxy.redis_errors_total.init_all_dense(); + metrics.proxy.redis_events_count.init_all_dense(); + metrics.proxy.retries_metric.init_all_dense(); + 
metrics.proxy.connection_failures_total.init_all_dense(); + + metrics + }) } } #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new())] pub struct ProxyMetrics { #[metric(flatten)] pub db_connections: CounterPairVec, @@ -134,6 +128,9 @@ pub struct ProxyMetrics { /// Number of TLS handshake failures pub tls_handshake_failures: Counter, + /// Number of SHA 256 rounds executed. + pub sha_rounds: Counter, + /// HLL approximate cardinality of endpoints that are connecting pub connecting_endpoints: HyperLogLogVec, 32>, @@ -151,8 +148,25 @@ pub struct ProxyMetrics { pub connect_compute_lock: ApiLockMetrics, #[metric(namespace = "scram_pool")] - #[metric(init = thread_pool)] - pub scram_pool: Arc, + pub scram_pool: OnceLockWrapper>, +} + +/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`]. +pub struct OnceLockWrapper(pub OnceLock); + +impl Default for OnceLockWrapper { + fn default() -> Self { + Self(OnceLock::new()) + } +} + +impl> MetricGroup for OnceLockWrapper { + fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> { + if let Some(inner) = self.0.get() { + inner.collect_group_into(enc)?; + } + Ok(()) + } } #[derive(MetricGroup)] @@ -719,6 +733,7 @@ pub enum CacheKind { ProjectInfoEndpoints, ProjectInfoRoles, Schema, + Pbkdf2, } #[derive(FixedCardinalityLabel, Clone, Copy, Debug)] diff --git a/proxy/src/scram/cache.rs b/proxy/src/scram/cache.rs new file mode 100644 index 0000000000..9ade7af458 --- /dev/null +++ b/proxy/src/scram/cache.rs @@ -0,0 +1,84 @@ +use tokio::time::Instant; +use zeroize::Zeroize as _; + +use super::pbkdf2; +use crate::cache::Cached; +use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener}; +use crate::intern::{EndpointIdInt, RoleNameInt}; +use crate::metrics::{CacheKind, Metrics}; + +pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>); +pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>; + +impl Cache for Pbkdf2Cache { + type Key = (EndpointIdInt, RoleNameInt); + type Value = Pbkdf2CacheEntry; + + fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) { + self.0.invalidate(info); + } +} + +/// To speed up password hashing for more active customers, we store the tail results of the +/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store +/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16 +/// to determine the final result. +/// +/// The suffix alone isn't enough to crack the password. The stored_key is still required. +/// While both are cached in memory, given they're in different locations is makes it much +/// harder to exploit, even if any such memory exploit exists in proxy. +#[derive(Clone)] +pub struct Pbkdf2CacheEntry { + /// corresponds to [`super::ServerSecret::cached_at`] + pub(super) cached_from: Instant, + pub(super) suffix: pbkdf2::Block, +} + +impl Drop for Pbkdf2CacheEntry { + fn drop(&mut self) { + self.suffix.zeroize(); + } +} + +impl Pbkdf2Cache { + pub fn new() -> Self { + const SIZE: u64 = 100; + const TTL: std::time::Duration = std::time::Duration::from_secs(60); + + let builder = moka::sync::Cache::builder() + .name("pbkdf2") + .max_capacity(SIZE) + // We use time_to_live so we don't refresh the lifetime for an invalid password attempt. 
+ .time_to_live(TTL); + + Metrics::get() + .cache + .capacity + .set(CacheKind::Pbkdf2, SIZE as i64); + + let builder = + builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause)); + + Self(builder.build()) + } + + pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) { + count_cache_insert(CacheKind::Pbkdf2); + self.0.insert((endpoint, role), value); + } + + fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option { + count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role))) + } + + pub fn get_entry( + &self, + endpoint: EndpointIdInt, + role: RoleNameInt, + ) -> Option> { + self.get(endpoint, role).map(|value| Cached { + token: Some((self, (endpoint, role))), + value, + }) + } +} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index a0918fca9f..3f4b0d534b 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -4,10 +4,8 @@ use std::convert::Infallible; use base64::Engine as _; use base64::prelude::BASE64_STANDARD; -use hmac::{Hmac, Mac}; -use sha2::Sha256; +use tracing::{debug, trace}; -use super::ScramKey; use super::messages::{ ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, }; @@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2; use super::secret::ServerSecret; use super::signature::SignatureBuilder; use super::threadpool::ThreadPool; -use crate::intern::EndpointIdInt; +use super::{ScramKey, pbkdf2}; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl::{self, ChannelBinding, Error as SaslError}; +use crate::scram::cache::Pbkdf2CacheEntry; /// The only channel binding mode we currently support. #[derive(Debug)] @@ -77,46 +77,113 @@ impl<'a> Exchange<'a> { } } -// copied from async fn derive_client_key( pool: &ThreadPool, endpoint: EndpointIdInt, password: &[u8], salt: &[u8], iterations: u32, -) -> ScramKey { - let salted_password = pool - .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) - .await; - - let make_key = |name| { - let key = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes") - .chain_update(name) - .finalize(); - - <[u8; 32]>::from(key.into_bytes()) - }; - - make_key(b"Client Key").into() +) -> pbkdf2::Block { + pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) + .await } +/// For cleartext flow, we need to derive the client key to +/// 1. authenticate the client. +/// 2. authenticate with compute. pub(crate) async fn exchange( pool: &ThreadPool, endpoint: EndpointIdInt, + role: RoleNameInt, + secret: &ServerSecret, + password: &[u8], +) -> sasl::Result> { + if secret.iterations > CACHED_ROUNDS { + exchange_with_cache(pool, endpoint, role, secret, password).await + } else { + let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; + let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + Ok(validate_pbkdf2(secret, &hash)) + } +} + +/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only, +/// which is not enough by itself to perform an offline brute force. 
+async fn exchange_with_cache( + pool: &ThreadPool, + endpoint: EndpointIdInt, + role: RoleNameInt, secret: &ServerSecret, password: &[u8], ) -> sasl::Result> { let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; - let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + debug_assert!( + secret.iterations > CACHED_ROUNDS, + "we should not cache password data if there isn't enough rounds needed" + ); + + // compute the prefix of the pbkdf2 output. + let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await; + + if let Some(entry) = pool.cache.get_entry(endpoint, role) { + // hot path: let's check the threadpool cache + if secret.cached_at == entry.cached_from { + // cache is valid. compute the full hash by adding the prefix to the suffix. + let mut hash = prefix; + pbkdf2::xor_assign(&mut hash, &entry.suffix); + let outcome = validate_pbkdf2(secret, &hash); + + if matches!(outcome, sasl::Outcome::Success(_)) { + trace!("password validated from cache"); + } + + return Ok(outcome); + } + + // cached key is no longer valid. + debug!("invalidating cached password"); + entry.invalidate(); + } + + // slow path: full password hash. + let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + let outcome = validate_pbkdf2(secret, &hash); + + let client_key = match outcome { + sasl::Outcome::Success(client_key) => client_key, + sasl::Outcome::Failure(_) => return Ok(outcome), + }; + + trace!("storing cached password"); + + // time to cache, compute the suffix by subtracting the prefix from the hash. + let mut suffix = hash; + pbkdf2::xor_assign(&mut suffix, &prefix); + + pool.cache.insert( + endpoint, + role, + Pbkdf2CacheEntry { + cached_from: secret.cached_at, + suffix, + }, + ); + + Ok(sasl::Outcome::Success(client_key)) +} + +fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome { + let client_key = super::ScramKey::client_key(&(*hash).into()); if secret.is_password_invalid(&client_key).into() { - Ok(sasl::Outcome::Failure("password doesn't match")) + sasl::Outcome::Failure("password doesn't match") } else { - Ok(sasl::Outcome::Success(client_key)) + sasl::Outcome::Success(client_key) } } +const CACHED_ROUNDS: u32 = 16; + impl SaslInitial { fn transition( &self, diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index fe55ff493b..7dc52fd409 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -1,6 +1,12 @@ //! Tools for client/server/stored key management. +use hmac::Mac as _; +use sha2::Digest as _; use subtle::ConstantTimeEq; +use zeroize::Zeroize as _; + +use crate::metrics::Metrics; +use crate::scram::pbkdf2::Prf; /// Faithfully taken from PostgreSQL. pub(crate) const SCRAM_KEY_LEN: usize = 32; @@ -14,6 +20,12 @@ pub(crate) struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], } +impl Drop for ScramKey { + fn drop(&mut self) { + self.bytes.zeroize(); + } +} + impl PartialEq for ScramKey { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() @@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey { impl ScramKey { pub(crate) fn sha256(&self) -> Self { - super::sha256([self.as_ref()]).into() + Metrics::get().proxy.sha_rounds.inc_by(1); + Self { + bytes: sha2::Sha256::digest(self.as_bytes()).into(), + } } pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { self.bytes } + + pub(crate) fn client_key(b: &[u8; 32]) -> Self { + // Prf::new_from_slice will run 2 sha256 rounds. + // Update + Finalize run 2 sha256 rounds. 
+ Metrics::get().proxy.sha_rounds.inc_by(4); + + let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes"); + prf.update(b"Client Key"); + let client_key: [u8; 32] = prf.finalize().into_bytes().into(); + client_key.into() + } } impl From<[u8; SCRAM_KEY_LEN]> for ScramKey { diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index 5f627e062c..04722d920b 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -6,6 +6,7 @@ //! * //! * +mod cache; mod countmin; mod exchange; mod key; @@ -18,10 +19,8 @@ pub mod threadpool; use base64::Engine as _; use base64::prelude::BASE64_STANDARD; pub(crate) use exchange::{Exchange, exchange}; -use hmac::{Hmac, Mac}; pub(crate) use key::ScramKey; pub(crate) use secret::ServerSecret; -use sha2::{Digest, Sha256}; const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; @@ -42,29 +41,13 @@ fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N Some(bytes) } -/// This function essentially is `Hmac(sha256, key, input)`. -/// Further reading: . -fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator) -> [u8; 32] { - let mut mac = Hmac::::new_from_slice(key).expect("bad key size"); - parts.into_iter().for_each(|s| mac.update(s)); - - mac.finalize().into_bytes().into() -} - -fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { - let mut hasher = Sha256::new(); - parts.into_iter().for_each(|s| hasher.update(s)); - - hasher.finalize().into() -} - #[cfg(test)] mod tests { use super::threadpool::ThreadPool; use super::{Exchange, ServerSecret}; - use crate::intern::EndpointIdInt; + use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl::{Mechanism, Step}; - use crate::types::EndpointId; + use crate::types::{EndpointId, RoleName}; #[test] fn snapshot() { @@ -114,23 +97,34 @@ mod tests { ); } - async fn run_round_trip_test(server_password: &str, client_password: &str) { - let pool = ThreadPool::new(1); - + async fn check( + pool: &ThreadPool, + scram_secret: &ServerSecret, + password: &[u8], + ) -> Result<(), &'static str> { let ep = EndpointId::from("foo"); let ep = EndpointIdInt::from(ep); + let role = RoleName::from("user"); + let role = RoleNameInt::from(&role); - let scram_secret = ServerSecret::build(server_password).await.unwrap(); - let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes()) + let outcome = super::exchange(pool, ep, role, scram_secret, password) .await .unwrap(); match outcome { - crate::sasl::Outcome::Success(_) => {} - crate::sasl::Outcome::Failure(r) => panic!("{r}"), + crate::sasl::Outcome::Success(_) => Ok(()), + crate::sasl::Outcome::Failure(r) => Err(r), } } + async fn run_round_trip_test(server_password: &str, client_password: &str) { + let pool = ThreadPool::new(1); + let scram_secret = ServerSecret::build(server_password).await.unwrap(); + check(&pool, &scram_secret, client_password.as_bytes()) + .await + .unwrap(); + } + #[tokio::test] async fn round_trip() { run_round_trip_test("pencil", "pencil").await; @@ -141,4 +135,27 @@ mod tests { async fn failure() { run_round_trip_test("pencil", "eraser").await; } + + #[tokio::test] + #[tracing_test::traced_test] + async fn password_cache() { + let pool = ThreadPool::new(1); + let scram_secret = ServerSecret::build("password").await.unwrap(); + + // wrong passwords are not added to cache + check(&pool, &scram_secret, b"wrong").await.unwrap_err(); + assert!(!logs_contain("storing cached password")); + + // correct passwords get cached + check(&pool, 
&scram_secret, b"password").await.unwrap(); + assert!(logs_contain("storing cached password")); + + // wrong passwords do not match the cache + check(&pool, &scram_secret, b"wrong").await.unwrap_err(); + assert!(!logs_contain("password validated from cache")); + + // correct passwords match the cache + check(&pool, &scram_secret, b"password").await.unwrap(); + assert!(logs_contain("password validated from cache")); + } } diff --git a/proxy/src/scram/pbkdf2.rs b/proxy/src/scram/pbkdf2.rs index 7f48e00c41..1300310de2 100644 --- a/proxy/src/scram/pbkdf2.rs +++ b/proxy/src/scram/pbkdf2.rs @@ -1,25 +1,50 @@ +//! For postgres password authentication, we need to perform a PBKDF2 using +//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key. + +use hmac::Mac as _; use hmac::digest::consts::U32; use hmac::digest::generic_array::GenericArray; -use hmac::{Hmac, Mac}; -use sha2::Sha256; +use zeroize::Zeroize as _; + +use crate::metrics::Metrics; + +/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake. +pub type Prf = hmac::Hmac; +pub(crate) type Block = GenericArray; pub(crate) struct Pbkdf2 { - hmac: Hmac, - prev: GenericArray, - hi: GenericArray, + hmac: Prf, + /// U{r-1} for whatever iteration r we are currently on. + prev: Block, + /// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on. + hi: Block, + /// number of iterations left iterations: u32, } +impl Drop for Pbkdf2 { + fn drop(&mut self) { + self.prev.zeroize(); + self.hi.zeroize(); + } +} + // inspired from impl Pbkdf2 { - pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { + pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self { // key the HMAC and derive the first block in-place - let mut hmac = - Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes"); + + // U1 = PRF(Password, Salt + INT_32_BE(i)) + // i = 1 since we only need 1 block of output. hmac.update(salt); hmac.update(&1u32.to_be_bytes()); let init_block = hmac.finalize_reset().into_bytes(); + // Prf::new_from_slice will run 2 sha256 rounds. + // Our update + finalize run 2 sha256 rounds for each pbkdf2 round. + Metrics::get().proxy.sha_rounds.inc_by(4); + Self { hmac, // one iteration spent above @@ -33,7 +58,11 @@ impl Pbkdf2 { (self.iterations).clamp(0, 4096) } - pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> { + /// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn` + /// function that only executes a fixed number of iterations before continuing. + /// + /// Task must be rescheuled if this returns [`std::task::Poll::Pending`]. + pub(crate) fn turn(&mut self) -> std::task::Poll { let Self { hmac, prev, @@ -44,25 +73,37 @@ impl Pbkdf2 { // only do up to 4096 iterations per turn for fairness let n = (*iterations).clamp(0, 4096); for _ in 0..n { - hmac.update(prev); - let block = hmac.finalize_reset().into_bytes(); - - for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) { - *hi_byte ^= b; - } - - *prev = block; + let next = single_round(hmac, prev); + xor_assign(hi, &next); + *prev = next; } + // Our update + finalize run 2 sha256 rounds for each pbkdf2 round. 
+ Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64); + *iterations -= n; if *iterations == 0 { - std::task::Poll::Ready((*hi).into()) + std::task::Poll::Ready(*hi) } else { std::task::Poll::Pending } } } +#[inline(always)] +pub fn xor_assign(x: &mut Block, y: &Block) { + for (x, &y) in std::iter::zip(x, y) { + *x ^= y; + } +} + +#[inline(always)] +fn single_round(prf: &mut Prf, ui: &Block) -> Block { + // Ui = PRF(Password, Ui-1) + prf.update(ui); + prf.finalize_reset().into_bytes() +} + #[cfg(test)] mod tests { use pbkdf2::pbkdf2_hmac_array; @@ -76,11 +117,11 @@ mod tests { let pass = b"Ne0n_!5_50_C007"; let mut job = Pbkdf2::start(pass, salt, 60000); - let hash = loop { + let hash: [u8; 32] = loop { let std::task::Poll::Ready(hash) = job.turn() else { continue; }; - break hash; + break hash.into(); }; let expected = pbkdf2_hmac_array::(pass, salt, 60000); diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 0e070c2f27..a3a64f271c 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -3,6 +3,7 @@ use base64::Engine as _; use base64::prelude::BASE64_STANDARD; use subtle::{Choice, ConstantTimeEq}; +use tokio::time::Instant; use super::base64_decode_array; use super::key::ScramKey; @@ -11,6 +12,9 @@ use super::key::ScramKey; /// and is used throughout the authentication process. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) struct ServerSecret { + /// When this secret was cached. + pub(crate) cached_at: Instant, + /// Number of iterations for `PBKDF2` function. pub(crate) iterations: u32, /// Salt used to hash user's password. @@ -34,6 +38,7 @@ impl ServerSecret { params.split_once(':').zip(keys.split_once(':'))?; let secret = ServerSecret { + cached_at: Instant::now(), iterations: iterations.parse().ok()?, salt_base64: salt.into(), stored_key: base64_decode_array(stored_key)?.into(), @@ -54,6 +59,7 @@ impl ServerSecret { /// See `auth-scram.c : mock_scram_secret` for details. pub(crate) fn mock(nonce: [u8; 32]) -> Self { Self { + cached_at: Instant::now(), // this doesn't reveal much information as we're going to use // iteration count 1 for our generated passwords going forward. // PG16 users can set iteration count=1 already today. diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs index a5b1c3e9f4..8e074272b6 100644 --- a/proxy/src/scram/signature.rs +++ b/proxy/src/scram/signature.rs @@ -1,6 +1,10 @@ //! Tools for client/server signature management. +use hmac::Mac as _; + use super::key::{SCRAM_KEY_LEN, ScramKey}; +use crate::metrics::Metrics; +use crate::scram::pbkdf2::Prf; /// A collection of message parts needed to derive the client's signature. #[derive(Debug)] @@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> { impl SignatureBuilder<'_> { pub(crate) fn build(&self, key: &ScramKey) -> Signature { - let parts = [ - self.client_first_message_bare.as_bytes(), - b",", - self.server_first_message.as_bytes(), - b",", - self.client_final_message_without_proof.as_bytes(), - ]; + // don't know exactly. 
this is a rough approx + Metrics::get().proxy.sha_rounds.inc_by(8); - super::hmac_sha256(key.as_ref(), parts).into() + let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes"); + mac.update(self.client_first_message_bare.as_bytes()); + mac.update(b","); + mac.update(self.server_first_message.as_bytes()); + mac.update(b","); + mac.update(self.client_final_message_without_proof.as_bytes()); + Signature { + bytes: mac.finalize().into_bytes().into(), + } } } diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index ea2e29ede9..20a1df2b53 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -15,6 +15,8 @@ use futures::FutureExt; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; +use super::cache::Pbkdf2Cache; +use super::pbkdf2; use super::pbkdf2::Pbkdf2; use crate::intern::EndpointIdInt; use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId}; @@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch; pub struct ThreadPool { runtime: Option, pub metrics: Arc, + + // we hash a lot of passwords. + // we keep a cache of partial hashes for faster validation. + pub(super) cache: Pbkdf2Cache, } /// How often to reset the sketch values @@ -68,6 +74,7 @@ impl ThreadPool { Self { runtime: Some(runtime), metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)), + cache: Pbkdf2Cache::new(), } }) } @@ -130,7 +137,7 @@ struct JobSpec { } impl Future for JobSpec { - type Output = [u8; 32]; + type Output = pbkdf2::Block; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { STATE.with_borrow_mut(|state| { @@ -166,10 +173,10 @@ impl Future for JobSpec { } } -pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>); +pub(crate) struct JobHandle(tokio::task::JoinHandle); impl Future for JobHandle { - type Output = [u8; 32]; + type Output = pbkdf2::Block; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.0.poll_unpin(cx) { @@ -203,10 +210,10 @@ mod tests { .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096)) .await; - let expected = [ + let expected = &[ 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242, 178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140, ]; - assert_eq!(actual, expected); + assert_eq!(actual.as_slice(), expected); } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 0987b6927f..eb879f98e7 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -26,7 +26,7 @@ use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::error::{ErrorKind, ReportableError, UserFacingError}; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::StartupMessageParams; use crate::proxy::{connect_auth, connect_compute}; use crate::rate_limiter::EndpointRateLimiter; @@ -76,9 +76,11 @@ impl PoolingBackend { }; let ep = EndpointIdInt::from(&user_info.endpoint); + let role = RoleNameInt::from(&user_info.user); let auth_outcome = crate::auth::validate_password_and_exchange( - &self.config.authentication_config.thread_pool, + &self.config.authentication_config.scram_thread_pool, ep, + role, password, secret, )
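
---

For reference, here is a standalone sketch of the XOR identity the cache relies on: the single-block PBKDF2-HMAC-SHA-256 output over `n` rounds equals the 16-round prefix XORed with the suffix built from rounds 17..n, so a cached suffix means only 16 rounds need recomputing on a hit. This is illustrative only and not part of the patch; it uses the `pbkdf2`, `hmac`, and `sha2` crates the proxy tests already depend on, and the password, salt, and round counts are made-up stand-ins.

```rust
use hmac::{Hmac, Mac};
use pbkdf2::pbkdf2_hmac_array;
use sha2::Sha256;

fn main() {
    // Hypothetical inputs; real salts and iteration counts come from the SCRAM secret.
    let password = b"pencil";
    let salt = b"some-salt";
    let (cached_rounds, total_rounds) = (16u32, 4096u32);

    // Full PBKDF2 output: U1 ^ U2 ^ ... ^ Un.
    let full = pbkdf2_hmac_array::<Sha256, 32>(password, salt, total_rounds);
    // Prefix: U1 ^ ... ^ U16 -- all that needs recomputing on a cache hit.
    let prefix = pbkdf2_hmac_array::<Sha256, 32>(password, salt, cached_rounds);

    // Suffix: U17 ^ ... ^ Un -- the part the cache would store. Walk the HMAC chain
    // U_i = HMAC(password, U_{i-1}) and XOR together only the later rounds.
    let mut mac = Hmac::<Sha256>::new_from_slice(password).expect("HMAC accepts any key size");
    mac.update(salt);
    mac.update(&1u32.to_be_bytes());
    let mut u: [u8; 32] = mac.finalize_reset().into_bytes().into(); // U1
    let mut suffix = [0u8; 32];
    for round in 2..=total_rounds {
        mac.update(&u);
        u = mac.finalize_reset().into_bytes().into();
        if round > cached_rounds {
            for (s, b) in suffix.iter_mut().zip(u) {
                *s ^= b;
            }
        }
    }

    // The cached suffix plus a cheap 16-round prefix reproduces the full hash.
    let recombined: [u8; 32] = std::array::from_fn(|i| prefix[i] ^ suffix[i]);
    assert_eq!(recombined, full);
    println!("prefix ^ suffix matches pbkdf2(password, salt, {total_rounds})");
}
```

On a cache hit, proxy therefore pays for the 16-round prefix plus a handful of HMAC/SHA-256 invocations for the SCRAM keys, rather than the secret's full iteration count.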