mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-22 21:59:59 +00:00
proxy: cache for password hashing (#12011)
## Problem Password hashing for sql-over-http takes up a lot of CPU. Perhaps we can get away with temporarily caching some steps so we only need fewer rounds, which will save some CPU time. ## Summary of changes The output of pbkdf2 is the XOR of the outputs of each iteration round, eg `U1 ^ U2 ^ ... U15 ^ U16 ^ U17 ^ ... ^ Un`. We cache the suffix of the expression `U16 ^ U17 ^ ... ^ Un`. To compute the result from the cached suffix, we only need to compute the prefix `U1 ^ U2 ^ ... U15`. The suffix by itself is useless, which prevents its use in brute-force attacks should this cached memory leak. We are also caching the full 4096 round hash in memory, which can be used for brute-force attacks, where this suffix could be used to speed it up. My hope/expectation is that since these will be in different allocations, it makes any such memory exploitation much much harder. Since the full hash cache might be invalidated while the suffix is cached, I'm storing the timestamp of the computation as a way to identify the match. I also added `zeroize()` to clear the sensitive state from the stack/heap. For the most security conscious customers, we hope to roll out OIDC soon, so they can disable passwords entirely. --- The numbers for the threadpool were pretty random, but according to our busiest region for sql-over-http, we only see about 150 unique endpoints every minute. So storing ~100 of the most common endpoints for that minute should be the vast majority of requests. 1 minute was chosen so we don't keep data in memory for too long.
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -5519,6 +5519,7 @@ dependencies = [
|
|||||||
"workspace_hack",
|
"workspace_hack",
|
||||||
"x509-cert",
|
"x509-cert",
|
||||||
"zerocopy 0.8.24",
|
"zerocopy 0.8.24",
|
||||||
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
@@ -234,9 +234,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
|
|||||||
walkdir = "2.3.2"
|
walkdir = "2.3.2"
|
||||||
rustls-native-certs = "0.8"
|
rustls-native-certs = "0.8"
|
||||||
whoami = "1.5.1"
|
whoami = "1.5.1"
|
||||||
zerocopy = { version = "0.8", features = ["derive", "simd"] }
|
|
||||||
json-structural-diff = { version = "0.2.0" }
|
json-structural-diff = { version = "0.2.0" }
|
||||||
x509-cert = { version = "0.2.5" }
|
x509-cert = { version = "0.2.5" }
|
||||||
|
zerocopy = { version = "0.8", features = ["derive", "simd"] }
|
||||||
|
zeroize = "1.8"
|
||||||
|
|
||||||
## TODO replace this with tracing
|
## TODO replace this with tracing
|
||||||
env_logger = "0.11"
|
env_logger = "0.11"
|
||||||
|
|||||||
@@ -107,6 +107,7 @@ uuid.workspace = true
|
|||||||
x509-cert.workspace = true
|
x509-cert.workspace = true
|
||||||
redis.workspace = true
|
redis.workspace = true
|
||||||
zerocopy.workspace = true
|
zerocopy.workspace = true
|
||||||
|
zeroize.workspace = true
|
||||||
# uncomment this to use the real subzero-core crate
|
# uncomment this to use the real subzero-core crate
|
||||||
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
|
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
|
||||||
# this is a stub for the subzero-core crate
|
# this is a stub for the subzero-core crate
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow};
|
|||||||
use crate::config::AuthenticationConfig;
|
use crate::config::AuthenticationConfig;
|
||||||
use crate::context::RequestContext;
|
use crate::context::RequestContext;
|
||||||
use crate::control_plane::AuthSecret;
|
use crate::control_plane::AuthSecret;
|
||||||
use crate::intern::EndpointIdInt;
|
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||||
use crate::sasl;
|
use crate::sasl;
|
||||||
use crate::stream::{self, Stream};
|
use crate::stream::{self, Stream};
|
||||||
|
|
||||||
@@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext(
|
|||||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||||
|
|
||||||
let ep = EndpointIdInt::from(&info.endpoint);
|
let ep = EndpointIdInt::from(&info.endpoint);
|
||||||
|
let role = RoleNameInt::from(&info.user);
|
||||||
|
|
||||||
let auth_flow = AuthFlow::new(
|
let auth_flow = AuthFlow::new(
|
||||||
client,
|
client,
|
||||||
auth::CleartextPassword {
|
auth::CleartextPassword {
|
||||||
secret,
|
secret,
|
||||||
endpoint: ep,
|
endpoint: ep,
|
||||||
pool: config.thread_pool.clone(),
|
role,
|
||||||
|
pool: config.scram_thread_pool.clone(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
let auth_outcome = {
|
let auth_outcome = {
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ use crate::control_plane::messages::EndpointRateLimitConfig;
|
|||||||
use crate::control_plane::{
|
use crate::control_plane::{
|
||||||
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
|
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
|
||||||
};
|
};
|
||||||
use crate::intern::EndpointIdInt;
|
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||||
use crate::pqproto::BeMessage;
|
use crate::pqproto::BeMessage;
|
||||||
use crate::proxy::NeonOptions;
|
use crate::proxy::NeonOptions;
|
||||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||||
@@ -273,9 +273,11 @@ async fn authenticate_with_secret(
|
|||||||
) -> auth::Result<ComputeCredentials> {
|
) -> auth::Result<ComputeCredentials> {
|
||||||
if let Some(password) = unauthenticated_password {
|
if let Some(password) = unauthenticated_password {
|
||||||
let ep = EndpointIdInt::from(&info.endpoint);
|
let ep = EndpointIdInt::from(&info.endpoint);
|
||||||
|
let role = RoleNameInt::from(&info.user);
|
||||||
|
|
||||||
let auth_outcome =
|
let auth_outcome =
|
||||||
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
|
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
|
||||||
|
.await?;
|
||||||
let keys = match auth_outcome {
|
let keys = match auth_outcome {
|
||||||
crate::sasl::Outcome::Success(key) => key,
|
crate::sasl::Outcome::Success(key) => key,
|
||||||
crate::sasl::Outcome::Failure(reason) => {
|
crate::sasl::Outcome::Failure(reason) => {
|
||||||
@@ -499,7 +501,7 @@ mod tests {
|
|||||||
|
|
||||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||||
jwks_cache: JwkCache::default(),
|
jwks_cache: JwkCache::default(),
|
||||||
thread_pool: ThreadPool::new(1),
|
scram_thread_pool: ThreadPool::new(1),
|
||||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||||
ip_allowlist_check_enabled: true,
|
ip_allowlist_check_enabled: true,
|
||||||
is_vpc_acccess_proxy: false,
|
is_vpc_acccess_proxy: false,
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys;
|
|||||||
use super::{AuthError, PasswordHackPayload};
|
use super::{AuthError, PasswordHackPayload};
|
||||||
use crate::context::RequestContext;
|
use crate::context::RequestContext;
|
||||||
use crate::control_plane::AuthSecret;
|
use crate::control_plane::AuthSecret;
|
||||||
use crate::intern::EndpointIdInt;
|
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||||
use crate::sasl;
|
use crate::sasl;
|
||||||
use crate::scram::threadpool::ThreadPool;
|
use crate::scram::threadpool::ThreadPool;
|
||||||
@@ -46,6 +46,7 @@ pub(crate) struct PasswordHack;
|
|||||||
pub(crate) struct CleartextPassword {
|
pub(crate) struct CleartextPassword {
|
||||||
pub(crate) pool: Arc<ThreadPool>,
|
pub(crate) pool: Arc<ThreadPool>,
|
||||||
pub(crate) endpoint: EndpointIdInt,
|
pub(crate) endpoint: EndpointIdInt,
|
||||||
|
pub(crate) role: RoleNameInt,
|
||||||
pub(crate) secret: AuthSecret,
|
pub(crate) secret: AuthSecret,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,6 +112,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
|||||||
let outcome = validate_password_and_exchange(
|
let outcome = validate_password_and_exchange(
|
||||||
&self.state.pool,
|
&self.state.pool,
|
||||||
self.state.endpoint,
|
self.state.endpoint,
|
||||||
|
self.state.role,
|
||||||
password,
|
password,
|
||||||
self.state.secret,
|
self.state.secret,
|
||||||
)
|
)
|
||||||
@@ -165,13 +167,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
|||||||
pub(crate) async fn validate_password_and_exchange(
|
pub(crate) async fn validate_password_and_exchange(
|
||||||
pool: &ThreadPool,
|
pool: &ThreadPool,
|
||||||
endpoint: EndpointIdInt,
|
endpoint: EndpointIdInt,
|
||||||
|
role: RoleNameInt,
|
||||||
password: &[u8],
|
password: &[u8],
|
||||||
secret: AuthSecret,
|
secret: AuthSecret,
|
||||||
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||||
match secret {
|
match secret {
|
||||||
// perform scram authentication as both client and server to validate the keys
|
// perform scram authentication as both client and server to validate the keys
|
||||||
AuthSecret::Scram(scram_secret) => {
|
AuthSecret::Scram(scram_secret) => {
|
||||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
let outcome =
|
||||||
|
crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
|
||||||
|
|
||||||
let client_key = match outcome {
|
let client_key = match outcome {
|
||||||
sasl::Outcome::Success(client_key) => client_key,
|
sasl::Outcome::Success(client_key) => client_key,
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ use crate::config::{
|
|||||||
};
|
};
|
||||||
use crate::control_plane::locks::ApiLocks;
|
use crate::control_plane::locks::ApiLocks;
|
||||||
use crate::http::health_server::AppMetrics;
|
use crate::http::health_server::AppMetrics;
|
||||||
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
|
use crate::metrics::{Metrics, ServiceInfo};
|
||||||
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
|
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
|
||||||
use crate::scram::threadpool::ThreadPool;
|
use crate::scram::threadpool::ThreadPool;
|
||||||
use crate::serverless::cancel_set::CancelSet;
|
use crate::serverless::cancel_set::CancelSet;
|
||||||
@@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> {
|
|||||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||||
|
|
||||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
|
||||||
|
|
||||||
// TODO: refactor these to use labels
|
// TODO: refactor these to use labels
|
||||||
debug!("Version: {GIT_VERSION}");
|
debug!("Version: {GIT_VERSION}");
|
||||||
debug!("Build_tag: {BUILD_TAG}");
|
debug!("Build_tag: {BUILD_TAG}");
|
||||||
@@ -284,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
|||||||
http_config,
|
http_config,
|
||||||
authentication_config: AuthenticationConfig {
|
authentication_config: AuthenticationConfig {
|
||||||
jwks_cache: JwkCache::default(),
|
jwks_cache: JwkCache::default(),
|
||||||
thread_pool: ThreadPool::new(0),
|
scram_thread_pool: ThreadPool::new(0),
|
||||||
scram_protocol_timeout: Duration::from_secs(10),
|
scram_protocol_timeout: Duration::from_secs(10),
|
||||||
ip_allowlist_check_enabled: true,
|
ip_allowlist_check_enabled: true,
|
||||||
is_vpc_acccess_proxy: false,
|
is_vpc_acccess_proxy: false,
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ use utils::project_git_version;
|
|||||||
use utils::sentry_init::init_sentry;
|
use utils::sentry_init::init_sentry;
|
||||||
|
|
||||||
use crate::context::RequestContext;
|
use crate::context::RequestContext;
|
||||||
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
|
use crate::metrics::{Metrics, ServiceInfo};
|
||||||
use crate::pglb::TlsRequired;
|
use crate::pglb::TlsRequired;
|
||||||
use crate::pqproto::FeStartupPacket;
|
use crate::pqproto::FeStartupPacket;
|
||||||
use crate::protocol2::ConnectionInfo;
|
use crate::protocol2::ConnectionInfo;
|
||||||
@@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> {
|
|||||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||||
|
|
||||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
|
||||||
|
|
||||||
let args = cli().get_matches();
|
let args = cli().get_matches();
|
||||||
let destination: String = args
|
let destination: String = args
|
||||||
.get_one::<String>("dest")
|
.get_one::<String>("dest")
|
||||||
|
|||||||
@@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> {
|
|||||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||||
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||||
Metrics::install(thread_pool.metrics.clone());
|
Metrics::get()
|
||||||
|
.proxy
|
||||||
|
.scram_pool
|
||||||
|
.0
|
||||||
|
.set(thread_pool.metrics.clone())
|
||||||
|
.ok();
|
||||||
|
|
||||||
let tls_config = match (&args.tls_key, &args.tls_cert) {
|
let tls_config = match (&args.tls_key, &args.tls_cert) {
|
||||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||||
@@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
|||||||
};
|
};
|
||||||
let authentication_config = AuthenticationConfig {
|
let authentication_config = AuthenticationConfig {
|
||||||
jwks_cache: JwkCache::default(),
|
jwks_cache: JwkCache::default(),
|
||||||
thread_pool,
|
scram_thread_pool: thread_pool,
|
||||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||||
is_vpc_acccess_proxy: args.is_private_access_proxy,
|
is_vpc_acccess_proxy: args.is_private_access_proxy,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
|
|||||||
use crate::ext::TaskExt;
|
use crate::ext::TaskExt;
|
||||||
use crate::intern::RoleNameInt;
|
use crate::intern::RoleNameInt;
|
||||||
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
|
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
|
||||||
use crate::scram::threadpool::ThreadPool;
|
use crate::scram;
|
||||||
use crate::serverless::GlobalConnPoolOptions;
|
use crate::serverless::GlobalConnPoolOptions;
|
||||||
use crate::serverless::cancel_set::CancelSet;
|
use crate::serverless::cancel_set::CancelSet;
|
||||||
#[cfg(feature = "rest_broker")]
|
#[cfg(feature = "rest_broker")]
|
||||||
@@ -75,7 +75,7 @@ pub struct HttpConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct AuthenticationConfig {
|
pub struct AuthenticationConfig {
|
||||||
pub thread_pool: Arc<ThreadPool>,
|
pub scram_thread_pool: Arc<scram::threadpool::ThreadPool>,
|
||||||
pub scram_protocol_timeout: tokio::time::Duration,
|
pub scram_protocol_timeout: tokio::time::Duration,
|
||||||
pub ip_allowlist_check_enabled: bool,
|
pub ip_allowlist_check_enabled: bool,
|
||||||
pub is_vpc_acccess_proxy: bool,
|
pub is_vpc_acccess_proxy: bool,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ use measured::label::{
|
|||||||
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
|
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
|
||||||
StaticLabelSet,
|
StaticLabelSet,
|
||||||
};
|
};
|
||||||
|
use measured::metric::group::Encoding;
|
||||||
use measured::metric::histogram::Thresholds;
|
use measured::metric::histogram::Thresholds;
|
||||||
use measured::metric::name::MetricName;
|
use measured::metric::name::MetricName;
|
||||||
use measured::{
|
use measured::{
|
||||||
@@ -18,10 +19,10 @@ use crate::control_plane::messages::ColdStartInfo;
|
|||||||
use crate::error::ErrorKind;
|
use crate::error::ErrorKind;
|
||||||
|
|
||||||
#[derive(MetricGroup)]
|
#[derive(MetricGroup)]
|
||||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
#[metric(new())]
|
||||||
pub struct Metrics {
|
pub struct Metrics {
|
||||||
#[metric(namespace = "proxy")]
|
#[metric(namespace = "proxy")]
|
||||||
#[metric(init = ProxyMetrics::new(thread_pool))]
|
#[metric(init = ProxyMetrics::new())]
|
||||||
pub proxy: ProxyMetrics,
|
pub proxy: ProxyMetrics,
|
||||||
|
|
||||||
#[metric(namespace = "wake_compute_lock")]
|
#[metric(namespace = "wake_compute_lock")]
|
||||||
@@ -34,34 +35,27 @@ pub struct Metrics {
|
|||||||
pub cache: CacheMetrics,
|
pub cache: CacheMetrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
|
||||||
impl Metrics {
|
impl Metrics {
|
||||||
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
|
#[track_caller]
|
||||||
let mut metrics = Metrics::new(thread_pool);
|
|
||||||
|
|
||||||
metrics.proxy.errors_total.init_all_dense();
|
|
||||||
metrics.proxy.redis_errors_total.init_all_dense();
|
|
||||||
metrics.proxy.redis_events_count.init_all_dense();
|
|
||||||
metrics.proxy.retries_metric.init_all_dense();
|
|
||||||
metrics.proxy.connection_failures_total.init_all_dense();
|
|
||||||
|
|
||||||
SELF.set(metrics)
|
|
||||||
.ok()
|
|
||||||
.expect("proxy metrics must not be installed more than once");
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get() -> &'static Self {
|
pub fn get() -> &'static Self {
|
||||||
#[cfg(test)]
|
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||||
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
|
|
||||||
|
|
||||||
#[cfg(not(test))]
|
SELF.get_or_init(|| {
|
||||||
SELF.get()
|
let mut metrics = Metrics::new();
|
||||||
.expect("proxy metrics must be installed by the main() function")
|
|
||||||
|
metrics.proxy.errors_total.init_all_dense();
|
||||||
|
metrics.proxy.redis_errors_total.init_all_dense();
|
||||||
|
metrics.proxy.redis_events_count.init_all_dense();
|
||||||
|
metrics.proxy.retries_metric.init_all_dense();
|
||||||
|
metrics.proxy.connection_failures_total.init_all_dense();
|
||||||
|
|
||||||
|
metrics
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(MetricGroup)]
|
#[derive(MetricGroup)]
|
||||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
#[metric(new())]
|
||||||
pub struct ProxyMetrics {
|
pub struct ProxyMetrics {
|
||||||
#[metric(flatten)]
|
#[metric(flatten)]
|
||||||
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
|
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
|
||||||
@@ -134,6 +128,9 @@ pub struct ProxyMetrics {
|
|||||||
/// Number of TLS handshake failures
|
/// Number of TLS handshake failures
|
||||||
pub tls_handshake_failures: Counter,
|
pub tls_handshake_failures: Counter,
|
||||||
|
|
||||||
|
/// Number of SHA 256 rounds executed.
|
||||||
|
pub sha_rounds: Counter,
|
||||||
|
|
||||||
/// HLL approximate cardinality of endpoints that are connecting
|
/// HLL approximate cardinality of endpoints that are connecting
|
||||||
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
|
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
|
||||||
|
|
||||||
@@ -151,8 +148,25 @@ pub struct ProxyMetrics {
|
|||||||
pub connect_compute_lock: ApiLockMetrics,
|
pub connect_compute_lock: ApiLockMetrics,
|
||||||
|
|
||||||
#[metric(namespace = "scram_pool")]
|
#[metric(namespace = "scram_pool")]
|
||||||
#[metric(init = thread_pool)]
|
pub scram_pool: OnceLockWrapper<Arc<ThreadPoolMetrics>>,
|
||||||
pub scram_pool: Arc<ThreadPoolMetrics>,
|
}
|
||||||
|
|
||||||
|
/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`].
|
||||||
|
pub struct OnceLockWrapper<T>(pub OnceLock<T>);
|
||||||
|
|
||||||
|
impl<T> Default for OnceLockWrapper<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self(OnceLock::new())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Enc: Encoding, T: MetricGroup<Enc>> MetricGroup<Enc> for OnceLockWrapper<T> {
|
||||||
|
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
|
||||||
|
if let Some(inner) = self.0.get() {
|
||||||
|
inner.collect_group_into(enc)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(MetricGroup)]
|
#[derive(MetricGroup)]
|
||||||
@@ -719,6 +733,7 @@ pub enum CacheKind {
|
|||||||
ProjectInfoEndpoints,
|
ProjectInfoEndpoints,
|
||||||
ProjectInfoRoles,
|
ProjectInfoRoles,
|
||||||
Schema,
|
Schema,
|
||||||
|
Pbkdf2,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||||
|
|||||||
84
proxy/src/scram/cache.rs
Normal file
84
proxy/src/scram/cache.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
use tokio::time::Instant;
|
||||||
|
use zeroize::Zeroize as _;
|
||||||
|
|
||||||
|
use super::pbkdf2;
|
||||||
|
use crate::cache::Cached;
|
||||||
|
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
|
||||||
|
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||||
|
use crate::metrics::{CacheKind, Metrics};
|
||||||
|
|
||||||
|
pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>);
|
||||||
|
pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>;
|
||||||
|
|
||||||
|
impl Cache for Pbkdf2Cache {
|
||||||
|
type Key = (EndpointIdInt, RoleNameInt);
|
||||||
|
type Value = Pbkdf2CacheEntry;
|
||||||
|
|
||||||
|
fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) {
|
||||||
|
self.0.invalidate(info);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// To speed up password hashing for more active customers, we store the tail results of the
|
||||||
|
/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store
|
||||||
|
/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16
|
||||||
|
/// to determine the final result.
|
||||||
|
///
|
||||||
|
/// The suffix alone isn't enough to crack the password. The stored_key is still required.
|
||||||
|
/// While both are cached in memory, the fact that they're in different locations makes it much
|
||||||
|
/// harder to exploit, even if any such memory exploit exists in proxy.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Pbkdf2CacheEntry {
|
||||||
|
/// corresponds to [`super::ServerSecret::cached_at`]
|
||||||
|
pub(super) cached_from: Instant,
|
||||||
|
pub(super) suffix: pbkdf2::Block,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for Pbkdf2CacheEntry {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.suffix.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Pbkdf2Cache {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
const SIZE: u64 = 100;
|
||||||
|
const TTL: std::time::Duration = std::time::Duration::from_secs(60);
|
||||||
|
|
||||||
|
let builder = moka::sync::Cache::builder()
|
||||||
|
.name("pbkdf2")
|
||||||
|
.max_capacity(SIZE)
|
||||||
|
// We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
|
||||||
|
.time_to_live(TTL);
|
||||||
|
|
||||||
|
Metrics::get()
|
||||||
|
.cache
|
||||||
|
.capacity
|
||||||
|
.set(CacheKind::Pbkdf2, SIZE as i64);
|
||||||
|
|
||||||
|
let builder =
|
||||||
|
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause));
|
||||||
|
|
||||||
|
Self(builder.build())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) {
|
||||||
|
count_cache_insert(CacheKind::Pbkdf2);
|
||||||
|
self.0.insert((endpoint, role), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option<Pbkdf2CacheEntry> {
|
||||||
|
count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_entry(
|
||||||
|
&self,
|
||||||
|
endpoint: EndpointIdInt,
|
||||||
|
role: RoleNameInt,
|
||||||
|
) -> Option<CachedPbkdf2<'_>> {
|
||||||
|
self.get(endpoint, role).map(|value| Cached {
|
||||||
|
token: Some((self, (endpoint, role))),
|
||||||
|
value,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,10 +4,8 @@ use std::convert::Infallible;
|
|||||||
|
|
||||||
use base64::Engine as _;
|
use base64::Engine as _;
|
||||||
use base64::prelude::BASE64_STANDARD;
|
use base64::prelude::BASE64_STANDARD;
|
||||||
use hmac::{Hmac, Mac};
|
use tracing::{debug, trace};
|
||||||
use sha2::Sha256;
|
|
||||||
|
|
||||||
use super::ScramKey;
|
|
||||||
use super::messages::{
|
use super::messages::{
|
||||||
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
||||||
};
|
};
|
||||||
@@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2;
|
|||||||
use super::secret::ServerSecret;
|
use super::secret::ServerSecret;
|
||||||
use super::signature::SignatureBuilder;
|
use super::signature::SignatureBuilder;
|
||||||
use super::threadpool::ThreadPool;
|
use super::threadpool::ThreadPool;
|
||||||
use crate::intern::EndpointIdInt;
|
use super::{ScramKey, pbkdf2};
|
||||||
|
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||||
|
use crate::scram::cache::Pbkdf2CacheEntry;
|
||||||
|
|
||||||
/// The only channel binding mode we currently support.
|
/// The only channel binding mode we currently support.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -77,46 +77,113 @@ impl<'a> Exchange<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
|
|
||||||
async fn derive_client_key(
|
async fn derive_client_key(
|
||||||
pool: &ThreadPool,
|
pool: &ThreadPool,
|
||||||
endpoint: EndpointIdInt,
|
endpoint: EndpointIdInt,
|
||||||
password: &[u8],
|
password: &[u8],
|
||||||
salt: &[u8],
|
salt: &[u8],
|
||||||
iterations: u32,
|
iterations: u32,
|
||||||
) -> ScramKey {
|
) -> pbkdf2::Block {
|
||||||
let salted_password = pool
|
pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
||||||
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
.await
|
||||||
.await;
|
|
||||||
|
|
||||||
let make_key = |name| {
|
|
||||||
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
|
|
||||||
.expect("HMAC is able to accept all key sizes")
|
|
||||||
.chain_update(name)
|
|
||||||
.finalize();
|
|
||||||
|
|
||||||
<[u8; 32]>::from(key.into_bytes())
|
|
||||||
};
|
|
||||||
|
|
||||||
make_key(b"Client Key").into()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// For cleartext flow, we need to derive the client key to
|
||||||
|
/// 1. authenticate the client.
|
||||||
|
/// 2. authenticate with compute.
|
||||||
pub(crate) async fn exchange(
|
pub(crate) async fn exchange(
|
||||||
pool: &ThreadPool,
|
pool: &ThreadPool,
|
||||||
endpoint: EndpointIdInt,
|
endpoint: EndpointIdInt,
|
||||||
|
role: RoleNameInt,
|
||||||
|
secret: &ServerSecret,
|
||||||
|
password: &[u8],
|
||||||
|
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||||
|
if secret.iterations > CACHED_ROUNDS {
|
||||||
|
exchange_with_cache(pool, endpoint, role, secret, password).await
|
||||||
|
} else {
|
||||||
|
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
|
||||||
|
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||||
|
Ok(validate_pbkdf2(secret, &hash))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
|
||||||
|
/// which is not enough by itself to perform an offline brute force.
|
||||||
|
async fn exchange_with_cache(
|
||||||
|
pool: &ThreadPool,
|
||||||
|
endpoint: EndpointIdInt,
|
||||||
|
role: RoleNameInt,
|
||||||
secret: &ServerSecret,
|
secret: &ServerSecret,
|
||||||
password: &[u8],
|
password: &[u8],
|
||||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||||
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
|
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
|
||||||
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
|
||||||
|
|
||||||
|
debug_assert!(
|
||||||
|
secret.iterations > CACHED_ROUNDS,
|
||||||
|
"we should not cache password data if there isn't enough rounds needed"
|
||||||
|
);
|
||||||
|
|
||||||
|
// compute the prefix of the pbkdf2 output.
|
||||||
|
let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
|
||||||
|
|
||||||
|
if let Some(entry) = pool.cache.get_entry(endpoint, role) {
|
||||||
|
// hot path: let's check the threadpool cache
|
||||||
|
if secret.cached_at == entry.cached_from {
|
||||||
|
// cache is valid. compute the full hash by adding the prefix to the suffix.
|
||||||
|
let mut hash = prefix;
|
||||||
|
pbkdf2::xor_assign(&mut hash, &entry.suffix);
|
||||||
|
let outcome = validate_pbkdf2(secret, &hash);
|
||||||
|
|
||||||
|
if matches!(outcome, sasl::Outcome::Success(_)) {
|
||||||
|
trace!("password validated from cache");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(outcome);
|
||||||
|
}
|
||||||
|
|
||||||
|
// cached key is no longer valid.
|
||||||
|
debug!("invalidating cached password");
|
||||||
|
entry.invalidate();
|
||||||
|
}
|
||||||
|
|
||||||
|
// slow path: full password hash.
|
||||||
|
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||||
|
let outcome = validate_pbkdf2(secret, &hash);
|
||||||
|
|
||||||
|
let client_key = match outcome {
|
||||||
|
sasl::Outcome::Success(client_key) => client_key,
|
||||||
|
sasl::Outcome::Failure(_) => return Ok(outcome),
|
||||||
|
};
|
||||||
|
|
||||||
|
trace!("storing cached password");
|
||||||
|
|
||||||
|
// time to cache, compute the suffix by subtracting the prefix from the hash.
|
||||||
|
let mut suffix = hash;
|
||||||
|
pbkdf2::xor_assign(&mut suffix, &prefix);
|
||||||
|
|
||||||
|
pool.cache.insert(
|
||||||
|
endpoint,
|
||||||
|
role,
|
||||||
|
Pbkdf2CacheEntry {
|
||||||
|
cached_from: secret.cached_at,
|
||||||
|
suffix,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(sasl::Outcome::Success(client_key))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
|
||||||
|
let client_key = super::ScramKey::client_key(&(*hash).into());
|
||||||
if secret.is_password_invalid(&client_key).into() {
|
if secret.is_password_invalid(&client_key).into() {
|
||||||
Ok(sasl::Outcome::Failure("password doesn't match"))
|
sasl::Outcome::Failure("password doesn't match")
|
||||||
} else {
|
} else {
|
||||||
Ok(sasl::Outcome::Success(client_key))
|
sasl::Outcome::Success(client_key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const CACHED_ROUNDS: u32 = 16;
|
||||||
|
|
||||||
impl SaslInitial {
|
impl SaslInitial {
|
||||||
fn transition(
|
fn transition(
|
||||||
&self,
|
&self,
|
||||||
|
|||||||
@@ -1,6 +1,12 @@
|
|||||||
//! Tools for client/server/stored key management.
|
//! Tools for client/server/stored key management.
|
||||||
|
|
||||||
|
use hmac::Mac as _;
|
||||||
|
use sha2::Digest as _;
|
||||||
use subtle::ConstantTimeEq;
|
use subtle::ConstantTimeEq;
|
||||||
|
use zeroize::Zeroize as _;
|
||||||
|
|
||||||
|
use crate::metrics::Metrics;
|
||||||
|
use crate::scram::pbkdf2::Prf;
|
||||||
|
|
||||||
/// Faithfully taken from PostgreSQL.
|
/// Faithfully taken from PostgreSQL.
|
||||||
pub(crate) const SCRAM_KEY_LEN: usize = 32;
|
pub(crate) const SCRAM_KEY_LEN: usize = 32;
|
||||||
@@ -14,6 +20,12 @@ pub(crate) struct ScramKey {
|
|||||||
bytes: [u8; SCRAM_KEY_LEN],
|
bytes: [u8; SCRAM_KEY_LEN],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for ScramKey {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.bytes.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl PartialEq for ScramKey {
|
impl PartialEq for ScramKey {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.ct_eq(other).into()
|
self.ct_eq(other).into()
|
||||||
@@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey {
|
|||||||
|
|
||||||
impl ScramKey {
|
impl ScramKey {
|
||||||
pub(crate) fn sha256(&self) -> Self {
|
pub(crate) fn sha256(&self) -> Self {
|
||||||
super::sha256([self.as_ref()]).into()
|
Metrics::get().proxy.sha_rounds.inc_by(1);
|
||||||
|
Self {
|
||||||
|
bytes: sha2::Sha256::digest(self.as_bytes()).into(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
|
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
|
||||||
self.bytes
|
self.bytes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn client_key(b: &[u8; 32]) -> Self {
|
||||||
|
// Prf::new_from_slice will run 2 sha256 rounds.
|
||||||
|
// Update + Finalize run 2 sha256 rounds.
|
||||||
|
Metrics::get().proxy.sha_rounds.inc_by(4);
|
||||||
|
|
||||||
|
let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes");
|
||||||
|
prf.update(b"Client Key");
|
||||||
|
let client_key: [u8; 32] = prf.finalize().into_bytes().into();
|
||||||
|
client_key.into()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
|
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
||||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
||||||
|
|
||||||
|
mod cache;
|
||||||
mod countmin;
|
mod countmin;
|
||||||
mod exchange;
|
mod exchange;
|
||||||
mod key;
|
mod key;
|
||||||
@@ -18,10 +19,8 @@ pub mod threadpool;
|
|||||||
use base64::Engine as _;
|
use base64::Engine as _;
|
||||||
use base64::prelude::BASE64_STANDARD;
|
use base64::prelude::BASE64_STANDARD;
|
||||||
pub(crate) use exchange::{Exchange, exchange};
|
pub(crate) use exchange::{Exchange, exchange};
|
||||||
use hmac::{Hmac, Mac};
|
|
||||||
pub(crate) use key::ScramKey;
|
pub(crate) use key::ScramKey;
|
||||||
pub(crate) use secret::ServerSecret;
|
pub(crate) use secret::ServerSecret;
|
||||||
use sha2::{Digest, Sha256};
|
|
||||||
|
|
||||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||||
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||||
@@ -42,29 +41,13 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
|
|||||||
Some(bytes)
|
Some(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This function essentially is `Hmac(sha256, key, input)`.
|
|
||||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
|
|
||||||
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
|
||||||
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
|
|
||||||
parts.into_iter().for_each(|s| mac.update(s));
|
|
||||||
|
|
||||||
mac.finalize().into_bytes().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
|
||||||
let mut hasher = Sha256::new();
|
|
||||||
parts.into_iter().for_each(|s| hasher.update(s));
|
|
||||||
|
|
||||||
hasher.finalize().into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::threadpool::ThreadPool;
|
use super::threadpool::ThreadPool;
|
||||||
use super::{Exchange, ServerSecret};
|
use super::{Exchange, ServerSecret};
|
||||||
use crate::intern::EndpointIdInt;
|
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||||
use crate::sasl::{Mechanism, Step};
|
use crate::sasl::{Mechanism, Step};
|
||||||
use crate::types::EndpointId;
|
use crate::types::{EndpointId, RoleName};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn snapshot() {
|
fn snapshot() {
|
||||||
@@ -114,23 +97,34 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
async fn check(
|
||||||
let pool = ThreadPool::new(1);
|
pool: &ThreadPool,
|
||||||
|
scram_secret: &ServerSecret,
|
||||||
|
password: &[u8],
|
||||||
|
) -> Result<(), &'static str> {
|
||||||
let ep = EndpointId::from("foo");
|
let ep = EndpointId::from("foo");
|
||||||
let ep = EndpointIdInt::from(ep);
|
let ep = EndpointIdInt::from(ep);
|
||||||
|
let role = RoleName::from("user");
|
||||||
|
let role = RoleNameInt::from(&role);
|
||||||
|
|
||||||
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
let outcome = super::exchange(pool, ep, role, scram_secret, password)
|
||||||
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
|
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
match outcome {
|
match outcome {
|
||||||
crate::sasl::Outcome::Success(_) => {}
|
crate::sasl::Outcome::Success(_) => Ok(()),
|
||||||
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
|
crate::sasl::Outcome::Failure(r) => Err(r),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
||||||
|
let pool = ThreadPool::new(1);
|
||||||
|
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
||||||
|
check(&pool, &scram_secret, client_password.as_bytes())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn round_trip() {
|
async fn round_trip() {
|
||||||
run_round_trip_test("pencil", "pencil").await;
|
run_round_trip_test("pencil", "pencil").await;
|
||||||
@@ -141,4 +135,27 @@ mod tests {
|
|||||||
async fn failure() {
|
async fn failure() {
|
||||||
run_round_trip_test("pencil", "eraser").await;
|
run_round_trip_test("pencil", "eraser").await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[tracing_test::traced_test]
|
||||||
|
async fn password_cache() {
|
||||||
|
let pool = ThreadPool::new(1);
|
||||||
|
let scram_secret = ServerSecret::build("password").await.unwrap();
|
||||||
|
|
||||||
|
// wrong passwords are not added to cache
|
||||||
|
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
|
||||||
|
assert!(!logs_contain("storing cached password"));
|
||||||
|
|
||||||
|
// correct passwords get cached
|
||||||
|
check(&pool, &scram_secret, b"password").await.unwrap();
|
||||||
|
assert!(logs_contain("storing cached password"));
|
||||||
|
|
||||||
|
// wrong passwords do not match the cache
|
||||||
|
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
|
||||||
|
assert!(!logs_contain("password validated from cache"));
|
||||||
|
|
||||||
|
// correct passwords match the cache
|
||||||
|
check(&pool, &scram_secret, b"password").await.unwrap();
|
||||||
|
assert!(logs_contain("password validated from cache"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +1,50 @@
|
|||||||
|
//! For postgres password authentication, we need to perform a PBKDF2 using
|
||||||
|
//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
|
||||||
|
|
||||||
|
use hmac::Mac as _;
|
||||||
use hmac::digest::consts::U32;
|
use hmac::digest::consts::U32;
|
||||||
use hmac::digest::generic_array::GenericArray;
|
use hmac::digest::generic_array::GenericArray;
|
||||||
use hmac::{Hmac, Mac};
|
use zeroize::Zeroize as _;
|
||||||
use sha2::Sha256;
|
|
||||||
|
use crate::metrics::Metrics;
|
||||||
|
|
||||||
|
/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
|
||||||
|
pub type Prf = hmac::Hmac<sha2::Sha256>;
|
||||||
|
pub(crate) type Block = GenericArray<u8, U32>;
|
||||||
|
|
||||||
pub(crate) struct Pbkdf2 {
|
pub(crate) struct Pbkdf2 {
|
||||||
hmac: Hmac<Sha256>,
|
hmac: Prf,
|
||||||
prev: GenericArray<u8, U32>,
|
/// U{r-1} for whatever iteration r we are currently on.
|
||||||
hi: GenericArray<u8, U32>,
|
prev: Block,
|
||||||
|
/// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
|
||||||
|
hi: Block,
|
||||||
|
/// number of iterations left
|
||||||
iterations: u32,
|
iterations: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for Pbkdf2 {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.prev.zeroize();
|
||||||
|
self.hi.zeroize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
|
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
|
||||||
impl Pbkdf2 {
|
impl Pbkdf2 {
|
||||||
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
|
pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
|
||||||
// key the HMAC and derive the first block in-place
|
// key the HMAC and derive the first block in-place
|
||||||
let mut hmac =
|
let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
|
||||||
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
|
|
||||||
|
// U1 = PRF(Password, Salt + INT_32_BE(i))
|
||||||
|
// i = 1 since we only need 1 block of output.
|
||||||
hmac.update(salt);
|
hmac.update(salt);
|
||||||
hmac.update(&1u32.to_be_bytes());
|
hmac.update(&1u32.to_be_bytes());
|
||||||
let init_block = hmac.finalize_reset().into_bytes();
|
let init_block = hmac.finalize_reset().into_bytes();
|
||||||
|
|
||||||
|
// Prf::new_from_slice will run 2 sha256 rounds.
|
||||||
|
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
|
||||||
|
Metrics::get().proxy.sha_rounds.inc_by(4);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
hmac,
|
hmac,
|
||||||
// one iteration spent above
|
// one iteration spent above
|
||||||
@@ -33,7 +58,11 @@ impl Pbkdf2 {
|
|||||||
(self.iterations).clamp(0, 4096)
|
(self.iterations).clamp(0, 4096)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
|
/// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
|
||||||
|
/// function that only executes a fixed number of iterations before continuing.
|
||||||
|
///
|
||||||
|
/// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
|
||||||
|
pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
|
||||||
let Self {
|
let Self {
|
||||||
hmac,
|
hmac,
|
||||||
prev,
|
prev,
|
||||||
@@ -44,25 +73,37 @@ impl Pbkdf2 {
|
|||||||
// only do up to 4096 iterations per turn for fairness
|
// only do up to 4096 iterations per turn for fairness
|
||||||
let n = (*iterations).clamp(0, 4096);
|
let n = (*iterations).clamp(0, 4096);
|
||||||
for _ in 0..n {
|
for _ in 0..n {
|
||||||
hmac.update(prev);
|
let next = single_round(hmac, prev);
|
||||||
let block = hmac.finalize_reset().into_bytes();
|
xor_assign(hi, &next);
|
||||||
|
*prev = next;
|
||||||
for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
|
|
||||||
*hi_byte ^= b;
|
|
||||||
}
|
|
||||||
|
|
||||||
*prev = block;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
|
||||||
|
Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
|
||||||
|
|
||||||
*iterations -= n;
|
*iterations -= n;
|
||||||
if *iterations == 0 {
|
if *iterations == 0 {
|
||||||
std::task::Poll::Ready((*hi).into())
|
std::task::Poll::Ready(*hi)
|
||||||
} else {
|
} else {
|
||||||
std::task::Poll::Pending
|
std::task::Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn xor_assign(x: &mut Block, y: &Block) {
|
||||||
|
for (x, &y) in std::iter::zip(x, y) {
|
||||||
|
*x ^= y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn single_round(prf: &mut Prf, ui: &Block) -> Block {
|
||||||
|
// Ui = PRF(Password, Ui-1)
|
||||||
|
prf.update(ui);
|
||||||
|
prf.finalize_reset().into_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use pbkdf2::pbkdf2_hmac_array;
|
use pbkdf2::pbkdf2_hmac_array;
|
||||||
@@ -76,11 +117,11 @@ mod tests {
|
|||||||
let pass = b"Ne0n_!5_50_C007";
|
let pass = b"Ne0n_!5_50_C007";
|
||||||
|
|
||||||
let mut job = Pbkdf2::start(pass, salt, 60000);
|
let mut job = Pbkdf2::start(pass, salt, 60000);
|
||||||
let hash = loop {
|
let hash: [u8; 32] = loop {
|
||||||
let std::task::Poll::Ready(hash) = job.turn() else {
|
let std::task::Poll::Ready(hash) = job.turn() else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
break hash;
|
break hash.into();
|
||||||
};
|
};
|
||||||
|
|
||||||
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);
|
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
use base64::Engine as _;
|
use base64::Engine as _;
|
||||||
use base64::prelude::BASE64_STANDARD;
|
use base64::prelude::BASE64_STANDARD;
|
||||||
use subtle::{Choice, ConstantTimeEq};
|
use subtle::{Choice, ConstantTimeEq};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
|
||||||
use super::base64_decode_array;
|
use super::base64_decode_array;
|
||||||
use super::key::ScramKey;
|
use super::key::ScramKey;
|
||||||
@@ -11,6 +12,9 @@ use super::key::ScramKey;
|
|||||||
/// and is used throughout the authentication process.
|
/// and is used throughout the authentication process.
|
||||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||||
pub(crate) struct ServerSecret {
|
pub(crate) struct ServerSecret {
|
||||||
|
/// When this secret was cached.
|
||||||
|
pub(crate) cached_at: Instant,
|
||||||
|
|
||||||
/// Number of iterations for `PBKDF2` function.
|
/// Number of iterations for `PBKDF2` function.
|
||||||
pub(crate) iterations: u32,
|
pub(crate) iterations: u32,
|
||||||
/// Salt used to hash user's password.
|
/// Salt used to hash user's password.
|
||||||
@@ -34,6 +38,7 @@ impl ServerSecret {
|
|||||||
params.split_once(':').zip(keys.split_once(':'))?;
|
params.split_once(':').zip(keys.split_once(':'))?;
|
||||||
|
|
||||||
let secret = ServerSecret {
|
let secret = ServerSecret {
|
||||||
|
cached_at: Instant::now(),
|
||||||
iterations: iterations.parse().ok()?,
|
iterations: iterations.parse().ok()?,
|
||||||
salt_base64: salt.into(),
|
salt_base64: salt.into(),
|
||||||
stored_key: base64_decode_array(stored_key)?.into(),
|
stored_key: base64_decode_array(stored_key)?.into(),
|
||||||
@@ -54,6 +59,7 @@ impl ServerSecret {
|
|||||||
/// See `auth-scram.c : mock_scram_secret` for details.
|
/// See `auth-scram.c : mock_scram_secret` for details.
|
||||||
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
|
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
cached_at: Instant::now(),
|
||||||
// this doesn't reveal much information as we're going to use
|
// this doesn't reveal much information as we're going to use
|
||||||
// iteration count 1 for our generated passwords going forward.
|
// iteration count 1 for our generated passwords going forward.
|
||||||
// PG16 users can set iteration count=1 already today.
|
// PG16 users can set iteration count=1 already today.
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
//! Tools for client/server signature management.
|
//! Tools for client/server signature management.
|
||||||
|
|
||||||
|
use hmac::Mac as _;
|
||||||
|
|
||||||
use super::key::{SCRAM_KEY_LEN, ScramKey};
|
use super::key::{SCRAM_KEY_LEN, ScramKey};
|
||||||
|
use crate::metrics::Metrics;
|
||||||
|
use crate::scram::pbkdf2::Prf;
|
||||||
|
|
||||||
/// A collection of message parts needed to derive the client's signature.
|
/// A collection of message parts needed to derive the client's signature.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> {
|
|||||||
|
|
||||||
impl SignatureBuilder<'_> {
|
impl SignatureBuilder<'_> {
|
||||||
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
|
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
|
||||||
let parts = [
|
// don't know exactly. this is a rough approx
|
||||||
self.client_first_message_bare.as_bytes(),
|
Metrics::get().proxy.sha_rounds.inc_by(8);
|
||||||
b",",
|
|
||||||
self.server_first_message.as_bytes(),
|
|
||||||
b",",
|
|
||||||
self.client_final_message_without_proof.as_bytes(),
|
|
||||||
];
|
|
||||||
|
|
||||||
super::hmac_sha256(key.as_ref(), parts).into()
|
let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes");
|
||||||
|
mac.update(self.client_first_message_bare.as_bytes());
|
||||||
|
mac.update(b",");
|
||||||
|
mac.update(self.server_first_message.as_bytes());
|
||||||
|
mac.update(b",");
|
||||||
|
mac.update(self.client_final_message_without_proof.as_bytes());
|
||||||
|
Signature {
|
||||||
|
bytes: mac.finalize().into_bytes().into(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ use futures::FutureExt;
|
|||||||
use rand::rngs::SmallRng;
|
use rand::rngs::SmallRng;
|
||||||
use rand::{Rng, SeedableRng};
|
use rand::{Rng, SeedableRng};
|
||||||
|
|
||||||
|
use super::cache::Pbkdf2Cache;
|
||||||
|
use super::pbkdf2;
|
||||||
use super::pbkdf2::Pbkdf2;
|
use super::pbkdf2::Pbkdf2;
|
||||||
use crate::intern::EndpointIdInt;
|
use crate::intern::EndpointIdInt;
|
||||||
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
|
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
|
||||||
@@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch;
|
|||||||
pub struct ThreadPool {
|
pub struct ThreadPool {
|
||||||
runtime: Option<tokio::runtime::Runtime>,
|
runtime: Option<tokio::runtime::Runtime>,
|
||||||
pub metrics: Arc<ThreadPoolMetrics>,
|
pub metrics: Arc<ThreadPoolMetrics>,
|
||||||
|
|
||||||
|
// we hash a lot of passwords.
|
||||||
|
// we keep a cache of partial hashes for faster validation.
|
||||||
|
pub(super) cache: Pbkdf2Cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// How often to reset the sketch values
|
/// How often to reset the sketch values
|
||||||
@@ -68,6 +74,7 @@ impl ThreadPool {
|
|||||||
Self {
|
Self {
|
||||||
runtime: Some(runtime),
|
runtime: Some(runtime),
|
||||||
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
|
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
|
||||||
|
cache: Pbkdf2Cache::new(),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -130,7 +137,7 @@ struct JobSpec {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Future for JobSpec {
|
impl Future for JobSpec {
|
||||||
type Output = [u8; 32];
|
type Output = pbkdf2::Block;
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
STATE.with_borrow_mut(|state| {
|
STATE.with_borrow_mut(|state| {
|
||||||
@@ -166,10 +173,10 @@ impl Future for JobSpec {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
|
pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);
|
||||||
|
|
||||||
impl Future for JobHandle {
|
impl Future for JobHandle {
|
||||||
type Output = [u8; 32];
|
type Output = pbkdf2::Block;
|
||||||
|
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
match self.0.poll_unpin(cx) {
|
match self.0.poll_unpin(cx) {
|
||||||
@@ -203,10 +210,10 @@ mod tests {
|
|||||||
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
|
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let expected = [
|
let expected = &[
|
||||||
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
|
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
|
||||||
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
|
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
|
||||||
];
|
];
|
||||||
assert_eq!(actual, expected);
|
assert_eq!(actual.as_slice(), expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ use crate::context::RequestContext;
|
|||||||
use crate::control_plane::client::ApiLockError;
|
use crate::control_plane::client::ApiLockError;
|
||||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||||
use crate::intern::EndpointIdInt;
|
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||||
use crate::pqproto::StartupMessageParams;
|
use crate::pqproto::StartupMessageParams;
|
||||||
use crate::proxy::{connect_auth, connect_compute};
|
use crate::proxy::{connect_auth, connect_compute};
|
||||||
use crate::rate_limiter::EndpointRateLimiter;
|
use crate::rate_limiter::EndpointRateLimiter;
|
||||||
@@ -76,9 +76,11 @@ impl PoolingBackend {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||||
|
let role = RoleNameInt::from(&user_info.user);
|
||||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||||
&self.config.authentication_config.thread_pool,
|
&self.config.authentication_config.scram_thread_pool,
|
||||||
ep,
|
ep,
|
||||||
|
role,
|
||||||
password,
|
password,
|
||||||
secret,
|
secret,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user