proxy: cache for password hashing (#12011)

## Problem

Password hashing for sql-over-http takes up a lot of CPU. Perhaps we can
get away with temporarily caching some steps so we only need fewer
rounds, which will save some CPU time.

## Summary of changes

The output of pbkdf2 is the XOR of the outputs of each iteration round,
eg `U1 ^ U2 ^ ... U15 ^ U16 ^ U17 ^ ... ^ Un`. We cache the suffix of
the expression `U16 ^ U17 ^ ... ^ Un`. To compute the result from the
cached suffix, we only need to compute the prefix `U1 ^ U2 ^ ... U15`.
The suffix by itself is useless, which prevent's its use in brute-force
attacks should this cached memory leak.

We are also caching the full 4096 round hash in memory, which can be
used for brute-force attacks, where this suffix could be used to speed
it up. My hope/expectation is that since these will be in different
allocations, it makes any such memory exploitation much much harder.
Since the full hash cache might be invalidated while the suffix is
cached, I'm storing the timestamp of the computation as a way to
identity the match.

I also added `zeroize()` to clear the sensitive state from the
stack/heap.

For the most security conscious customers, we hope to roll out OIDC
soon, so they can disable passwords entirely.

---

The numbers for the threadpool were pretty random, but according to our
busiest region for sql-over-http, we only see about 150 unique endpoints
every minute. So storing ~100 of the most common endpoints for that
minute should be the vast majority of requests.

1 minute was chosen so we don't keep data in memory for too long.
This commit is contained in:
Conrad Ludgate
2025-07-29 07:48:14 +01:00
committed by GitHub
parent 6be572177c
commit e7a1d5de94
20 changed files with 414 additions and 130 deletions

1
Cargo.lock generated
View File

@@ -5519,6 +5519,7 @@ dependencies = [
"workspace_hack",
"x509-cert",
"zerocopy 0.8.24",
"zeroize",
]
[[package]]

View File

@@ -234,9 +234,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
rustls-native-certs = "0.8"
whoami = "1.5.1"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
zerocopy = { version = "0.8", features = ["derive", "simd"] }
zeroize = "1.8"
## TODO replace this with tracing
env_logger = "0.11"

View File

@@ -107,6 +107,7 @@ uuid.workspace = true
x509-cert.workspace = true
redis.workspace = true
zerocopy.workspace = true
zeroize.workspace = true
# uncomment this to use the real subzero-core crate
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
# this is a stub for the subzero-core crate

View File

@@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl;
use crate::stream::{self, Stream};
@@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_flow = AuthFlow::new(
client,
auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
role,
pool: config.scram_thread_pool.clone(),
},
);
let auth_outcome = {

View File

@@ -25,7 +25,7 @@ use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
@@ -273,9 +273,11 @@ async fn authenticate_with_secret(
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
.await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
@@ -499,7 +501,7 @@ mod tests {
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
@@ -46,6 +46,7 @@ pub(crate) struct PasswordHack;
pub(crate) struct CleartextPassword {
pub(crate) pool: Arc<ThreadPool>,
pub(crate) endpoint: EndpointIdInt,
pub(crate) role: RoleNameInt,
pub(crate) secret: AuthSecret,
}
@@ -111,6 +112,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
self.state.role,
password,
self.state.secret,
)
@@ -165,13 +167,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let outcome =
crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -29,7 +29,7 @@ use crate::config::{
};
use crate::control_plane::locks::ApiLocks;
use crate::http::health_server::AppMetrics;
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
use crate::metrics::{Metrics, ServiceInfo};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
@@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
// TODO: refactor these to use labels
debug!("Version: {GIT_VERSION}");
debug!("Build_tag: {BUILD_TAG}");
@@ -284,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -26,7 +26,7 @@ use utils::project_git_version;
use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
use crate::metrics::{Metrics, ServiceInfo};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
@@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
let args = cli().get_matches();
let destination: String = args
.get_one::<String>("dest")

View File

@@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> {
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
Metrics::get()
.proxy
.scram_pool
.0
.set(thread_pool.metrics.clone())
.ok();
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
@@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_thread_pool: thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,

View File

@@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::scram;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
@@ -75,7 +75,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_thread_pool: Arc<scram::threadpool::ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub ip_allowlist_check_enabled: bool,
pub is_vpc_acccess_proxy: bool,

View File

@@ -5,6 +5,7 @@ use measured::label::{
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
StaticLabelSet,
};
use measured::metric::group::Encoding;
use measured::metric::histogram::Thresholds;
use measured::metric::name::MetricName;
use measured::{
@@ -18,10 +19,10 @@ use crate::control_plane::messages::ColdStartInfo;
use crate::error::ErrorKind;
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new())]
pub struct Metrics {
#[metric(namespace = "proxy")]
#[metric(init = ProxyMetrics::new(thread_pool))]
#[metric(init = ProxyMetrics::new())]
pub proxy: ProxyMetrics,
#[metric(namespace = "wake_compute_lock")]
@@ -34,34 +35,27 @@ pub struct Metrics {
pub cache: CacheMetrics,
}
static SELF: OnceLock<Metrics> = OnceLock::new();
impl Metrics {
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
let mut metrics = Metrics::new(thread_pool);
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
SELF.set(metrics)
.ok()
.expect("proxy metrics must not be installed more than once");
}
#[track_caller]
pub fn get() -> &'static Self {
#[cfg(test)]
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
static SELF: OnceLock<Metrics> = OnceLock::new();
#[cfg(not(test))]
SELF.get()
.expect("proxy metrics must be installed by the main() function")
SELF.get_or_init(|| {
let mut metrics = Metrics::new();
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
metrics
})
}
}
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new())]
pub struct ProxyMetrics {
#[metric(flatten)]
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
@@ -134,6 +128,9 @@ pub struct ProxyMetrics {
/// Number of TLS handshake failures
pub tls_handshake_failures: Counter,
/// Number of SHA 256 rounds executed.
pub sha_rounds: Counter,
/// HLL approximate cardinality of endpoints that are connecting
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
@@ -151,8 +148,25 @@ pub struct ProxyMetrics {
pub connect_compute_lock: ApiLockMetrics,
#[metric(namespace = "scram_pool")]
#[metric(init = thread_pool)]
pub scram_pool: Arc<ThreadPoolMetrics>,
pub scram_pool: OnceLockWrapper<Arc<ThreadPoolMetrics>>,
}
/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`].
pub struct OnceLockWrapper<T>(pub OnceLock<T>);
impl<T> Default for OnceLockWrapper<T> {
fn default() -> Self {
Self(OnceLock::new())
}
}
impl<Enc: Encoding, T: MetricGroup<Enc>> MetricGroup<Enc> for OnceLockWrapper<T> {
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
if let Some(inner) = self.0.get() {
inner.collect_group_into(enc)?;
}
Ok(())
}
}
#[derive(MetricGroup)]
@@ -719,6 +733,7 @@ pub enum CacheKind {
ProjectInfoEndpoints,
ProjectInfoRoles,
Schema,
Pbkdf2,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]

84
proxy/src/scram/cache.rs Normal file
View File

@@ -0,0 +1,84 @@
use tokio::time::Instant;
use zeroize::Zeroize as _;
use super::pbkdf2;
use crate::cache::Cached;
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::metrics::{CacheKind, Metrics};
pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>);
pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>;
impl Cache for Pbkdf2Cache {
type Key = (EndpointIdInt, RoleNameInt);
type Value = Pbkdf2CacheEntry;
fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) {
self.0.invalidate(info);
}
}
/// To speed up password hashing for more active customers, we store the tail results of the
/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store
/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16
/// to determine the final result.
///
/// The suffix alone isn't enough to crack the password. The stored_key is still required.
/// While both are cached in memory, given they're in different locations is makes it much
/// harder to exploit, even if any such memory exploit exists in proxy.
#[derive(Clone)]
pub struct Pbkdf2CacheEntry {
/// corresponds to [`super::ServerSecret::cached_at`]
pub(super) cached_from: Instant,
pub(super) suffix: pbkdf2::Block,
}
impl Drop for Pbkdf2CacheEntry {
fn drop(&mut self) {
self.suffix.zeroize();
}
}
impl Pbkdf2Cache {
pub fn new() -> Self {
const SIZE: u64 = 100;
const TTL: std::time::Duration = std::time::Duration::from_secs(60);
let builder = moka::sync::Cache::builder()
.name("pbkdf2")
.max_capacity(SIZE)
// We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
.time_to_live(TTL);
Metrics::get()
.cache
.capacity
.set(CacheKind::Pbkdf2, SIZE as i64);
let builder =
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause));
Self(builder.build())
}
pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) {
count_cache_insert(CacheKind::Pbkdf2);
self.0.insert((endpoint, role), value);
}
fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option<Pbkdf2CacheEntry> {
count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role)))
}
pub fn get_entry(
&self,
endpoint: EndpointIdInt,
role: RoleNameInt,
) -> Option<CachedPbkdf2<'_>> {
self.get(endpoint, role).map(|value| Cached {
token: Some((self, (endpoint, role))),
value,
})
}
}

View File

@@ -4,10 +4,8 @@ use std::convert::Infallible;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tracing::{debug, trace};
use super::ScramKey;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
@@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2;
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use crate::intern::EndpointIdInt;
use super::{ScramKey, pbkdf2};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl::{self, ChannelBinding, Error as SaslError};
use crate::scram::cache::Pbkdf2CacheEntry;
/// The only channel binding mode we currently support.
#[derive(Debug)]
@@ -77,46 +77,113 @@ impl<'a> Exchange<'a> {
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await;
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes")
.chain_update(name)
.finalize();
<[u8; 32]>::from(key.into_bytes())
};
make_key(b"Client Key").into()
) -> pbkdf2::Block {
pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await
}
/// For cleartext flow, we need to derive the client key to
/// 1. authenticate the client.
/// 2. authenticate with compute.
pub(crate) async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
if secret.iterations > CACHED_ROUNDS {
exchange_with_cache(pool, endpoint, role, secret, password).await
} else {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
Ok(validate_pbkdf2(secret, &hash))
}
}
/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
/// which is not enough by itself to perform an offline brute force.
async fn exchange_with_cache(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
debug_assert!(
secret.iterations > CACHED_ROUNDS,
"we should not cache password data if there isn't enough rounds needed"
);
// compute the prefix of the pbkdf2 output.
let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
if let Some(entry) = pool.cache.get_entry(endpoint, role) {
// hot path: let's check the threadpool cache
if secret.cached_at == entry.cached_from {
// cache is valid. compute the full hash by adding the prefix to the suffix.
let mut hash = prefix;
pbkdf2::xor_assign(&mut hash, &entry.suffix);
let outcome = validate_pbkdf2(secret, &hash);
if matches!(outcome, sasl::Outcome::Success(_)) {
trace!("password validated from cache");
}
return Ok(outcome);
}
// cached key is no longer valid.
debug!("invalidating cached password");
entry.invalidate();
}
// slow path: full password hash.
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
let outcome = validate_pbkdf2(secret, &hash);
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
sasl::Outcome::Failure(_) => return Ok(outcome),
};
trace!("storing cached password");
// time to cache, compute the suffix by subtracting the prefix from the hash.
let mut suffix = hash;
pbkdf2::xor_assign(&mut suffix, &prefix);
pool.cache.insert(
endpoint,
role,
Pbkdf2CacheEntry {
cached_from: secret.cached_at,
suffix,
},
);
Ok(sasl::Outcome::Success(client_key))
}
fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
let client_key = super::ScramKey::client_key(&(*hash).into());
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
sasl::Outcome::Failure("password doesn't match")
} else {
Ok(sasl::Outcome::Success(client_key))
sasl::Outcome::Success(client_key)
}
}
const CACHED_ROUNDS: u32 = 16;
impl SaslInitial {
fn transition(
&self,

View File

@@ -1,6 +1,12 @@
//! Tools for client/server/stored key management.
use hmac::Mac as _;
use sha2::Digest as _;
use subtle::ConstantTimeEq;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_KEY_LEN: usize = 32;
@@ -14,6 +20,12 @@ pub(crate) struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl Drop for ScramKey {
fn drop(&mut self) {
self.bytes.zeroize();
}
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
@@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey {
impl ScramKey {
pub(crate) fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()
Metrics::get().proxy.sha_rounds.inc_by(1);
Self {
bytes: sha2::Sha256::digest(self.as_bytes()).into(),
}
}
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
self.bytes
}
pub(crate) fn client_key(b: &[u8; 32]) -> Self {
// Prf::new_from_slice will run 2 sha256 rounds.
// Update + Finalize run 2 sha256 rounds.
Metrics::get().proxy.sha_rounds.inc_by(4);
let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes");
prf.update(b"Client Key");
let client_key: [u8; 32] = prf.finalize().into_bytes().into();
client_key.into()
}
}
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {

View File

@@ -6,6 +6,7 @@
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod cache;
mod countmin;
mod exchange;
mod key;
@@ -18,10 +19,8 @@ pub mod threadpool;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
pub(crate) use exchange::{Exchange, exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
@@ -42,29 +41,13 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
Some(bytes)
}
/// This function essentially is `Hmac(sha256, key, input)`.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
parts.into_iter().for_each(|s| mac.update(s));
mac.finalize().into_bytes().into()
}
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut hasher = Sha256::new();
parts.into_iter().for_each(|s| hasher.update(s));
hasher.finalize().into()
}
#[cfg(test)]
mod tests {
use super::threadpool::ThreadPool;
use super::{Exchange, ServerSecret};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl::{Mechanism, Step};
use crate::types::EndpointId;
use crate::types::{EndpointId, RoleName};
#[test]
fn snapshot() {
@@ -114,23 +97,34 @@ mod tests {
);
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
async fn check(
pool: &ThreadPool,
scram_secret: &ServerSecret,
password: &[u8],
) -> Result<(), &'static str> {
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let role = RoleName::from("user");
let role = RoleNameInt::from(&role);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
let outcome = super::exchange(pool, ep, role, scram_secret, password)
.await
.unwrap();
match outcome {
crate::sasl::Outcome::Success(_) => {}
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
crate::sasl::Outcome::Success(_) => Ok(()),
crate::sasl::Outcome::Failure(r) => Err(r),
}
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
check(&pool, &scram_secret, client_password.as_bytes())
.await
.unwrap();
}
#[tokio::test]
async fn round_trip() {
run_round_trip_test("pencil", "pencil").await;
@@ -141,4 +135,27 @@ mod tests {
async fn failure() {
run_round_trip_test("pencil", "eraser").await;
}
#[tokio::test]
#[tracing_test::traced_test]
async fn password_cache() {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build("password").await.unwrap();
// wrong passwords are not added to cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("storing cached password"));
// correct passwords get cached
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("storing cached password"));
// wrong passwords do not match the cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("password validated from cache"));
// correct passwords match the cache
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("password validated from cache"));
}
}

View File

@@ -1,25 +1,50 @@
//! For postgres password authentication, we need to perform a PBKDF2 using
//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
use hmac::Mac as _;
use hmac::digest::consts::U32;
use hmac::digest::generic_array::GenericArray;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
pub type Prf = hmac::Hmac<sha2::Sha256>;
pub(crate) type Block = GenericArray<u8, U32>;
pub(crate) struct Pbkdf2 {
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
hmac: Prf,
/// U{r-1} for whatever iteration r we are currently on.
prev: Block,
/// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
hi: Block,
/// number of iterations left
iterations: u32,
}
impl Drop for Pbkdf2 {
fn drop(&mut self) {
self.prev.zeroize();
self.hi.zeroize();
}
}
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
// key the HMAC and derive the first block in-place
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
// U1 = PRF(Password, Salt + INT_32_BE(i))
// i = 1 since we only need 1 block of output.
hmac.update(salt);
hmac.update(&1u32.to_be_bytes());
let init_block = hmac.finalize_reset().into_bytes();
// Prf::new_from_slice will run 2 sha256 rounds.
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(4);
Self {
hmac,
// one iteration spent above
@@ -33,7 +58,11 @@ impl Pbkdf2 {
(self.iterations).clamp(0, 4096)
}
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
/// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
/// function that only executes a fixed number of iterations before continuing.
///
/// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
let Self {
hmac,
prev,
@@ -44,25 +73,37 @@ impl Pbkdf2 {
// only do up to 4096 iterations per turn for fairness
let n = (*iterations).clamp(0, 4096);
for _ in 0..n {
hmac.update(prev);
let block = hmac.finalize_reset().into_bytes();
for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
*hi_byte ^= b;
}
*prev = block;
let next = single_round(hmac, prev);
xor_assign(hi, &next);
*prev = next;
}
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
*iterations -= n;
if *iterations == 0 {
std::task::Poll::Ready((*hi).into())
std::task::Poll::Ready(*hi)
} else {
std::task::Poll::Pending
}
}
}
#[inline(always)]
pub fn xor_assign(x: &mut Block, y: &Block) {
for (x, &y) in std::iter::zip(x, y) {
*x ^= y;
}
}
#[inline(always)]
fn single_round(prf: &mut Prf, ui: &Block) -> Block {
// Ui = PRF(Password, Ui-1)
prf.update(ui);
prf.finalize_reset().into_bytes()
}
#[cfg(test)]
mod tests {
use pbkdf2::pbkdf2_hmac_array;
@@ -76,11 +117,11 @@ mod tests {
let pass = b"Ne0n_!5_50_C007";
let mut job = Pbkdf2::start(pass, salt, 60000);
let hash = loop {
let hash: [u8; 32] = loop {
let std::task::Poll::Ready(hash) = job.turn() else {
continue;
};
break hash;
break hash.into();
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);

View File

@@ -3,6 +3,7 @@
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use subtle::{Choice, ConstantTimeEq};
use tokio::time::Instant;
use super::base64_decode_array;
use super::key::ScramKey;
@@ -11,6 +12,9 @@ use super::key::ScramKey;
/// and is used throughout the authentication process.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct ServerSecret {
/// When this secret was cached.
pub(crate) cached_at: Instant,
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
@@ -34,6 +38,7 @@ impl ServerSecret {
params.split_once(':').zip(keys.split_once(':'))?;
let secret = ServerSecret {
cached_at: Instant::now(),
iterations: iterations.parse().ok()?,
salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
@@ -54,6 +59,7 @@ impl ServerSecret {
/// See `auth-scram.c : mock_scram_secret` for details.
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
Self {
cached_at: Instant::now(),
// this doesn't reveal much information as we're going to use
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.

View File

@@ -1,6 +1,10 @@
//! Tools for client/server signature management.
use hmac::Mac as _;
use super::key::{SCRAM_KEY_LEN, ScramKey};
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// A collection of message parts needed to derive the client's signature.
#[derive(Debug)]
@@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> {
impl SignatureBuilder<'_> {
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
let parts = [
self.client_first_message_bare.as_bytes(),
b",",
self.server_first_message.as_bytes(),
b",",
self.client_final_message_without_proof.as_bytes(),
];
// don't know exactly. this is a rough approx
Metrics::get().proxy.sha_rounds.inc_by(8);
super::hmac_sha256(key.as_ref(), parts).into()
let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes");
mac.update(self.client_first_message_bare.as_bytes());
mac.update(b",");
mac.update(self.server_first_message.as_bytes());
mac.update(b",");
mac.update(self.client_final_message_without_proof.as_bytes());
Signature {
bytes: mac.finalize().into_bytes().into(),
}
}
}

View File

@@ -15,6 +15,8 @@ use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::cache::Pbkdf2Cache;
use super::pbkdf2;
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
@@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch;
pub struct ThreadPool {
runtime: Option<tokio::runtime::Runtime>,
pub metrics: Arc<ThreadPoolMetrics>,
// we hash a lot of passwords.
// we keep a cache of partial hashes for faster validation.
pub(super) cache: Pbkdf2Cache,
}
/// How often to reset the sketch values
@@ -68,6 +74,7 @@ impl ThreadPool {
Self {
runtime: Some(runtime),
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
cache: Pbkdf2Cache::new(),
}
})
}
@@ -130,7 +137,7 @@ struct JobSpec {
}
impl Future for JobSpec {
type Output = [u8; 32];
type Output = pbkdf2::Block;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
STATE.with_borrow_mut(|state| {
@@ -166,10 +173,10 @@ impl Future for JobSpec {
}
}
pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);
impl Future for JobHandle {
type Output = [u8; 32];
type Output = pbkdf2::Block;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.0.poll_unpin(cx) {
@@ -203,10 +210,10 @@ mod tests {
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
.await;
let expected = [
let expected = &[
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
];
assert_eq!(actual, expected);
assert_eq!(actual.as_slice(), expected);
}
}

View File

@@ -26,7 +26,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::StartupMessageParams;
use crate::proxy::{connect_auth, connect_compute};
use crate::rate_limiter::EndpointRateLimiter;
@@ -76,9 +76,11 @@ impl PoolingBackend {
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let role = RoleNameInt::from(&user_info.user);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
&self.config.authentication_config.scram_thread_pool,
ep,
role,
password,
secret,
)