mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-27 10:00:38 +00:00
Compare commits
7 Commits
bodobolero
...
heikki/upd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e3a1ccbd8 | ||
|
|
d487ba2b9b | ||
|
|
e7a1d5de94 | ||
|
|
6be572177c | ||
|
|
fe7a4e1ab6 | ||
|
|
40cae8cc36 | ||
|
|
02fc8b7c70 |
6
.github/workflows/benchbase_tpcc.yml
vendored
6
.github/workflows/benchbase_tpcc.yml
vendored
@@ -31,15 +31,15 @@ jobs:
|
||||
include:
|
||||
- warehouses: 50 # defines number of warehouses and is used to compute number of terminals
|
||||
max_rate: 800 # measured max TPS at scale factor based on experiments. Adjust if performance is better/worse
|
||||
min_cu: 2 # simulate free tier plan (0.25 -2 CU)
|
||||
min_cu: 0.25 # simulate free tier plan (0.25 -2 CU)
|
||||
max_cu: 2
|
||||
- warehouses: 500 # serverless plan (2-8 CU)
|
||||
max_rate: 2000
|
||||
min_cu: 8
|
||||
min_cu: 2
|
||||
max_cu: 8
|
||||
- warehouses: 1000 # business plan (2-16 CU)
|
||||
max_rate: 2900
|
||||
min_cu: 16
|
||||
min_cu: 2
|
||||
max_cu: 16
|
||||
max-parallel: 1 # we want to run each workload size sequentially to avoid noisy neighbors
|
||||
permissions:
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -5078,7 +5078,6 @@ dependencies = [
|
||||
"criterion",
|
||||
"env_logger",
|
||||
"log",
|
||||
"memoffset 0.9.0",
|
||||
"once_cell",
|
||||
"postgres",
|
||||
"postgres_ffi_types",
|
||||
@@ -5519,6 +5518,7 @@ dependencies = [
|
||||
"workspace_hack",
|
||||
"x509-cert",
|
||||
"zerocopy 0.8.24",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -135,7 +135,6 @@ lock_api = "0.4.13"
|
||||
md5 = "0.7.0"
|
||||
measured = { version = "0.0.22", features=["lasso"] }
|
||||
measured-process = { version = "0.0.22" }
|
||||
memoffset = "0.9"
|
||||
moka = { version = "0.12", features = ["sync"] }
|
||||
nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] }
|
||||
# Do not update to >= 7.0.0, at least. The update will have a significant impact
|
||||
@@ -234,9 +233,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
|
||||
walkdir = "2.3.2"
|
||||
rustls-native-certs = "0.8"
|
||||
whoami = "1.5.1"
|
||||
zerocopy = { version = "0.8", features = ["derive", "simd"] }
|
||||
json-structural-diff = { version = "0.2.0" }
|
||||
x509-cert = { version = "0.2.5" }
|
||||
zerocopy = { version = "0.8", features = ["derive", "simd"] }
|
||||
zeroize = "1.8"
|
||||
|
||||
## TODO replace this with tracing
|
||||
env_logger = "0.11"
|
||||
|
||||
@@ -1633,6 +1633,12 @@ FROM pg-build-with-cargo AS neon-ext-build
|
||||
ARG PG_VERSION
|
||||
|
||||
USER root
|
||||
|
||||
# Update the rust toolchain. Running 'make' will do this, but better to do
|
||||
# it as a separately cacheable step.
|
||||
COPY rust-toolchain.toml .
|
||||
RUN rustup show
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN make -j $(getconf _NPROCESSORS_ONLN) -C pgxn -s install-compute \
|
||||
@@ -1731,6 +1737,12 @@ ARG BUILD_TAG
|
||||
ENV BUILD_TAG=$BUILD_TAG
|
||||
|
||||
USER nonroot
|
||||
|
||||
# Update the rust toolchain. Running 'cargo build' will do this, but
|
||||
# better to do it as a separately cacheable step.
|
||||
COPY --chown=nonroot rust-toolchain.toml .
|
||||
RUN rustup show
|
||||
|
||||
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
|
||||
COPY --chown=nonroot . .
|
||||
RUN --mount=type=cache,uid=1000,target=/home/nonroot/.cargo/registry \
|
||||
|
||||
@@ -558,11 +558,11 @@ async fn add_request_id_header_to_response(
|
||||
mut res: Response<Body>,
|
||||
req_info: RequestInfo,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
if let Some(request_id) = req_info.context::<RequestId>() {
|
||||
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
|
||||
res.headers_mut()
|
||||
.insert(&X_REQUEST_ID_HEADER, request_header_value);
|
||||
};
|
||||
if let Some(request_id) = req_info.context::<RequestId>()
|
||||
&& let Ok(request_header_value) = HeaderValue::from_str(&request_id.0)
|
||||
{
|
||||
res.headers_mut()
|
||||
.insert(&X_REQUEST_ID_HEADER, request_header_value);
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
|
||||
@@ -72,10 +72,10 @@ impl Server {
|
||||
if err.is_incomplete_message() || err.is_closed() || err.is_timeout() {
|
||||
return true;
|
||||
}
|
||||
if let Some(inner) = err.source() {
|
||||
if let Some(io) = inner.downcast_ref::<std::io::Error>() {
|
||||
return suppress_io_error(io);
|
||||
}
|
||||
if let Some(inner) = err.source()
|
||||
&& let Some(io) = inner.downcast_ref::<std::io::Error>()
|
||||
{
|
||||
return suppress_io_error(io);
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
@@ -363,7 +363,7 @@ where
|
||||
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
|
||||
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
|
||||
// If we switch to an Iterator, it must not hold the lock.
|
||||
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
|
||||
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<'_, (K, V)>> {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
if pos >= map.buckets.len() {
|
||||
return None;
|
||||
|
||||
@@ -12,7 +12,6 @@ crc32c.workspace = true
|
||||
criterion.workspace = true
|
||||
once_cell.workspace = true
|
||||
log.workspace = true
|
||||
memoffset.workspace = true
|
||||
pprof.workspace = true
|
||||
thiserror.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
@@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::<ControlFileData>();
|
||||
impl ControlFileData {
|
||||
/// Compute the offset of the `crc` field within the `ControlFileData` struct.
|
||||
/// Equivalent to offsetof(ControlFileData, crc) in C.
|
||||
// Someday this can be const when the right compiler features land.
|
||||
fn pg_control_crc_offset() -> usize {
|
||||
memoffset::offset_of!(ControlFileData, crc)
|
||||
const fn pg_control_crc_offset() -> usize {
|
||||
std::mem::offset_of!(ControlFileData, crc)
|
||||
}
|
||||
|
||||
///
|
||||
|
||||
@@ -49,7 +49,7 @@ impl PerfSpan {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enter(&self) -> PerfSpanEntered {
|
||||
pub fn enter(&self) -> PerfSpanEntered<'_> {
|
||||
if let Some(ref id) = self.inner.id() {
|
||||
self.dispatch.enter(id);
|
||||
}
|
||||
|
||||
@@ -79,10 +79,6 @@
|
||||
#include "access/xlogrecovery.h"
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM < 160000
|
||||
typedef PGAlignedBlock PGIOAlignedBlock;
|
||||
#endif
|
||||
|
||||
#define NEON_PANIC_CONNECTION_STATE(shard_no, elvl, message, ...) \
|
||||
neon_shard_log(shard_no, elvl, "Broken connection state: " message, \
|
||||
##__VA_ARGS__)
|
||||
|
||||
@@ -635,6 +635,11 @@ lfc_init(void)
|
||||
NULL);
|
||||
}
|
||||
|
||||
/*
|
||||
* Dump a list of pages that are currently in the LFC
|
||||
*
|
||||
* This is used to get a snapshot that can be used to prewarm the LFC later.
|
||||
*/
|
||||
FileCacheState*
|
||||
lfc_get_state(size_t max_entries)
|
||||
{
|
||||
@@ -2267,4 +2272,3 @@ get_prewarm_info(PG_FUNCTION_ARGS)
|
||||
|
||||
PG_RETURN_DATUM(HeapTupleGetDatum(heap_form_tuple(tupdesc, values, nulls)));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
/*-------------------------------------------------------------------------
|
||||
*
|
||||
* neon.c
|
||||
* Main entry point into the neon exension
|
||||
* Main entry point into the neon extension
|
||||
*
|
||||
*-------------------------------------------------------------------------
|
||||
*/
|
||||
@@ -508,7 +508,7 @@ _PG_init(void)
|
||||
|
||||
DefineCustomBoolVariable(
|
||||
"neon.disable_logical_replication_subscribers",
|
||||
"Disables incomming logical replication",
|
||||
"Disable incoming logical replication",
|
||||
NULL,
|
||||
&disable_logical_replication_subscribers,
|
||||
false,
|
||||
@@ -567,7 +567,7 @@ _PG_init(void)
|
||||
|
||||
DefineCustomEnumVariable(
|
||||
"neon.debug_compare_local",
|
||||
"Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk",
|
||||
"Debug mode for comparing content of pages in prefetch ring/LFC/PS and local disk",
|
||||
NULL,
|
||||
&debug_compare_local,
|
||||
DEBUG_COMPARE_LOCAL_NONE,
|
||||
@@ -735,7 +735,6 @@ neon_shmem_request_hook(void)
|
||||
static void
|
||||
neon_shmem_startup_hook(void)
|
||||
{
|
||||
/* Initialize */
|
||||
if (prev_shmem_startup_hook)
|
||||
prev_shmem_startup_hook();
|
||||
|
||||
|
||||
@@ -167,11 +167,7 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared;
|
||||
*/
|
||||
#define NUM_NEON_PERF_COUNTER_SLOTS (MaxBackends + NUM_AUXILIARY_PROCS)
|
||||
|
||||
#if PG_VERSION_NUM >= 170000
|
||||
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber])
|
||||
#else
|
||||
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProc->pgprocno])
|
||||
#endif
|
||||
|
||||
extern void inc_getpage_wait(uint64 latency);
|
||||
extern void inc_page_cache_read_wait(uint64 latency);
|
||||
|
||||
@@ -9,6 +9,10 @@
|
||||
#include "fmgr.h"
|
||||
#include "storage/buf_internals.h"
|
||||
|
||||
#if PG_MAJORVERSION_NUM < 16
|
||||
typedef PGAlignedBlock PGIOAlignedBlock;
|
||||
#endif
|
||||
|
||||
#if PG_MAJORVERSION_NUM < 17
|
||||
#define NRelFileInfoBackendIsTemp(rinfo) (rinfo.backend != InvalidBackendId)
|
||||
#else
|
||||
@@ -158,6 +162,10 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode,
|
||||
#define AmAutoVacuumWorkerProcess() (IsAutoVacuumWorkerProcess())
|
||||
#endif
|
||||
|
||||
#if PG_MAJORVERSION_NUM < 17
|
||||
#define MyProcNumber (MyProc - &ProcGlobal->allProcs[0])
|
||||
#endif
|
||||
|
||||
#if PG_MAJORVERSION_NUM < 15
|
||||
extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags);
|
||||
extern TimeLineID GetWALInsertionTimeLine(void);
|
||||
|
||||
@@ -72,10 +72,6 @@
|
||||
#include "access/xlogrecovery.h"
|
||||
#endif
|
||||
|
||||
#if PG_VERSION_NUM < 160000
|
||||
typedef PGAlignedBlock PGIOAlignedBlock;
|
||||
#endif
|
||||
|
||||
#include "access/nbtree.h"
|
||||
#include "storage/bufpage.h"
|
||||
#include "access/xlog_internal.h"
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "neon.h"
|
||||
#include "neon_pgversioncompat.h"
|
||||
|
||||
#include "miscadmin.h"
|
||||
#include "pagestore_client.h"
|
||||
#include RELFILEINFO_HDR
|
||||
#include "storage/smgr.h"
|
||||
@@ -23,10 +24,6 @@
|
||||
#include "utils/dynahash.h"
|
||||
#include "utils/guc.h"
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
#include "miscadmin.h"
|
||||
#endif
|
||||
|
||||
typedef struct
|
||||
{
|
||||
NRelFileInfo rinfo;
|
||||
|
||||
@@ -107,6 +107,7 @@ uuid.workspace = true
|
||||
x509-cert.workspace = true
|
||||
redis.workspace = true
|
||||
zerocopy.workspace = true
|
||||
zeroize.workspace = true
|
||||
# uncomment this to use the real subzero-core crate
|
||||
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
|
||||
# this is a stub for the subzero-core crate
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow};
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::sasl;
|
||||
use crate::stream::{self, Stream};
|
||||
|
||||
@@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext(
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
let role = RoleNameInt::from(&info.user);
|
||||
|
||||
let auth_flow = AuthFlow::new(
|
||||
client,
|
||||
auth::CleartextPassword {
|
||||
secret,
|
||||
endpoint: ep,
|
||||
pool: config.thread_pool.clone(),
|
||||
role,
|
||||
pool: config.scram_thread_pool.clone(),
|
||||
},
|
||||
);
|
||||
let auth_outcome = {
|
||||
|
||||
@@ -25,7 +25,7 @@ use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::{
|
||||
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
|
||||
};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||
@@ -273,9 +273,11 @@ async fn authenticate_with_secret(
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
if let Some(password) = unauthenticated_password {
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
let role = RoleNameInt::from(&info.user);
|
||||
|
||||
let auth_outcome =
|
||||
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
|
||||
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
|
||||
.await?;
|
||||
let keys = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => key,
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
@@ -499,7 +501,7 @@ mod tests {
|
||||
|
||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(1),
|
||||
scram_thread_pool: ThreadPool::new(1),
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_vpc_acccess_proxy: false,
|
||||
|
||||
@@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys;
|
||||
use super::{AuthError, PasswordHackPayload};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::sasl;
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
@@ -46,6 +46,7 @@ pub(crate) struct PasswordHack;
|
||||
pub(crate) struct CleartextPassword {
|
||||
pub(crate) pool: Arc<ThreadPool>,
|
||||
pub(crate) endpoint: EndpointIdInt,
|
||||
pub(crate) role: RoleNameInt,
|
||||
pub(crate) secret: AuthSecret,
|
||||
}
|
||||
|
||||
@@ -111,6 +112,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
let outcome = validate_password_and_exchange(
|
||||
&self.state.pool,
|
||||
self.state.endpoint,
|
||||
self.state.role,
|
||||
password,
|
||||
self.state.secret,
|
||||
)
|
||||
@@ -165,13 +167,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
pub(crate) async fn validate_password_and_exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
role: RoleNameInt,
|
||||
password: &[u8],
|
||||
secret: AuthSecret,
|
||||
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
match secret {
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
let outcome =
|
||||
crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
|
||||
@@ -29,7 +29,7 @@ use crate::config::{
|
||||
};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::http::health_server::AppMetrics;
|
||||
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
|
||||
use crate::metrics::{Metrics, ServiceInfo};
|
||||
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
@@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
||||
|
||||
// TODO: refactor these to use labels
|
||||
debug!("Version: {GIT_VERSION}");
|
||||
debug!("Build_tag: {BUILD_TAG}");
|
||||
@@ -284,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
http_config,
|
||||
authentication_config: AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(0),
|
||||
scram_thread_pool: ThreadPool::new(0),
|
||||
scram_protocol_timeout: Duration::from_secs(10),
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_vpc_acccess_proxy: false,
|
||||
|
||||
@@ -26,7 +26,7 @@ use utils::project_git_version;
|
||||
use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
|
||||
use crate::metrics::{Metrics, ServiceInfo};
|
||||
use crate::pglb::TlsRequired;
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
@@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
||||
|
||||
let args = cli().get_matches();
|
||||
let destination: String = args
|
||||
.get_one::<String>("dest")
|
||||
|
||||
@@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||
Metrics::install(thread_pool.metrics.clone());
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.scram_pool
|
||||
.0
|
||||
.set(thread_pool.metrics.clone())
|
||||
.ok();
|
||||
|
||||
let tls_config = match (&args.tls_key, &args.tls_cert) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
@@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool,
|
||||
scram_thread_pool: thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_vpc_acccess_proxy: args.is_private_access_proxy,
|
||||
|
||||
@@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
|
||||
use crate::ext::TaskExt;
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::scram;
|
||||
use crate::serverless::GlobalConnPoolOptions;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
#[cfg(feature = "rest_broker")]
|
||||
@@ -75,7 +75,7 @@ pub struct HttpConfig {
|
||||
}
|
||||
|
||||
pub struct AuthenticationConfig {
|
||||
pub thread_pool: Arc<ThreadPool>,
|
||||
pub scram_thread_pool: Arc<scram::threadpool::ThreadPool>,
|
||||
pub scram_protocol_timeout: tokio::time::Duration,
|
||||
pub ip_allowlist_check_enabled: bool,
|
||||
pub is_vpc_acccess_proxy: bool,
|
||||
|
||||
@@ -5,6 +5,7 @@ use measured::label::{
|
||||
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
|
||||
StaticLabelSet,
|
||||
};
|
||||
use measured::metric::group::Encoding;
|
||||
use measured::metric::histogram::Thresholds;
|
||||
use measured::metric::name::MetricName;
|
||||
use measured::{
|
||||
@@ -18,10 +19,10 @@ use crate::control_plane::messages::ColdStartInfo;
|
||||
use crate::error::ErrorKind;
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
#[metric(new())]
|
||||
pub struct Metrics {
|
||||
#[metric(namespace = "proxy")]
|
||||
#[metric(init = ProxyMetrics::new(thread_pool))]
|
||||
#[metric(init = ProxyMetrics::new())]
|
||||
pub proxy: ProxyMetrics,
|
||||
|
||||
#[metric(namespace = "wake_compute_lock")]
|
||||
@@ -34,34 +35,27 @@ pub struct Metrics {
|
||||
pub cache: CacheMetrics,
|
||||
}
|
||||
|
||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||
impl Metrics {
|
||||
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
|
||||
let mut metrics = Metrics::new(thread_pool);
|
||||
|
||||
metrics.proxy.errors_total.init_all_dense();
|
||||
metrics.proxy.redis_errors_total.init_all_dense();
|
||||
metrics.proxy.redis_events_count.init_all_dense();
|
||||
metrics.proxy.retries_metric.init_all_dense();
|
||||
metrics.proxy.connection_failures_total.init_all_dense();
|
||||
|
||||
SELF.set(metrics)
|
||||
.ok()
|
||||
.expect("proxy metrics must not be installed more than once");
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn get() -> &'static Self {
|
||||
#[cfg(test)]
|
||||
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
|
||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||
|
||||
#[cfg(not(test))]
|
||||
SELF.get()
|
||||
.expect("proxy metrics must be installed by the main() function")
|
||||
SELF.get_or_init(|| {
|
||||
let mut metrics = Metrics::new();
|
||||
|
||||
metrics.proxy.errors_total.init_all_dense();
|
||||
metrics.proxy.redis_errors_total.init_all_dense();
|
||||
metrics.proxy.redis_events_count.init_all_dense();
|
||||
metrics.proxy.retries_metric.init_all_dense();
|
||||
metrics.proxy.connection_failures_total.init_all_dense();
|
||||
|
||||
metrics
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
#[metric(new())]
|
||||
pub struct ProxyMetrics {
|
||||
#[metric(flatten)]
|
||||
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
|
||||
@@ -134,6 +128,9 @@ pub struct ProxyMetrics {
|
||||
/// Number of TLS handshake failures
|
||||
pub tls_handshake_failures: Counter,
|
||||
|
||||
/// Number of SHA 256 rounds executed.
|
||||
pub sha_rounds: Counter,
|
||||
|
||||
/// HLL approximate cardinality of endpoints that are connecting
|
||||
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
|
||||
|
||||
@@ -151,8 +148,25 @@ pub struct ProxyMetrics {
|
||||
pub connect_compute_lock: ApiLockMetrics,
|
||||
|
||||
#[metric(namespace = "scram_pool")]
|
||||
#[metric(init = thread_pool)]
|
||||
pub scram_pool: Arc<ThreadPoolMetrics>,
|
||||
pub scram_pool: OnceLockWrapper<Arc<ThreadPoolMetrics>>,
|
||||
}
|
||||
|
||||
/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`].
|
||||
pub struct OnceLockWrapper<T>(pub OnceLock<T>);
|
||||
|
||||
impl<T> Default for OnceLockWrapper<T> {
|
||||
fn default() -> Self {
|
||||
Self(OnceLock::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Enc: Encoding, T: MetricGroup<Enc>> MetricGroup<Enc> for OnceLockWrapper<T> {
|
||||
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
|
||||
if let Some(inner) = self.0.get() {
|
||||
inner.collect_group_into(enc)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
@@ -553,14 +567,6 @@ impl From<bool> for Bool {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = InvalidEndpointsSet)]
|
||||
pub struct InvalidEndpointsGroup {
|
||||
pub protocol: Protocol,
|
||||
pub rejected: Bool,
|
||||
pub outcome: ConnectOutcome,
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = RetriesMetricSet)]
|
||||
pub struct RetriesMetricGroup {
|
||||
@@ -727,6 +733,7 @@ pub enum CacheKind {
|
||||
ProjectInfoEndpoints,
|
||||
ProjectInfoRoles,
|
||||
Schema,
|
||||
Pbkdf2,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
|
||||
84
proxy/src/scram/cache.rs
Normal file
84
proxy/src/scram/cache.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use tokio::time::Instant;
|
||||
use zeroize::Zeroize as _;
|
||||
|
||||
use super::pbkdf2;
|
||||
use crate::cache::Cached;
|
||||
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::metrics::{CacheKind, Metrics};
|
||||
|
||||
pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>);
|
||||
pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>;
|
||||
|
||||
impl Cache for Pbkdf2Cache {
|
||||
type Key = (EndpointIdInt, RoleNameInt);
|
||||
type Value = Pbkdf2CacheEntry;
|
||||
|
||||
fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) {
|
||||
self.0.invalidate(info);
|
||||
}
|
||||
}
|
||||
|
||||
/// To speed up password hashing for more active customers, we store the tail results of the
|
||||
/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store
|
||||
/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16
|
||||
/// to determine the final result.
|
||||
///
|
||||
/// The suffix alone isn't enough to crack the password. The stored_key is still required.
|
||||
/// While both are cached in memory, given they're in different locations is makes it much
|
||||
/// harder to exploit, even if any such memory exploit exists in proxy.
|
||||
#[derive(Clone)]
|
||||
pub struct Pbkdf2CacheEntry {
|
||||
/// corresponds to [`super::ServerSecret::cached_at`]
|
||||
pub(super) cached_from: Instant,
|
||||
pub(super) suffix: pbkdf2::Block,
|
||||
}
|
||||
|
||||
impl Drop for Pbkdf2CacheEntry {
|
||||
fn drop(&mut self) {
|
||||
self.suffix.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl Pbkdf2Cache {
|
||||
pub fn new() -> Self {
|
||||
const SIZE: u64 = 100;
|
||||
const TTL: std::time::Duration = std::time::Duration::from_secs(60);
|
||||
|
||||
let builder = moka::sync::Cache::builder()
|
||||
.name("pbkdf2")
|
||||
.max_capacity(SIZE)
|
||||
// We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
|
||||
.time_to_live(TTL);
|
||||
|
||||
Metrics::get()
|
||||
.cache
|
||||
.capacity
|
||||
.set(CacheKind::Pbkdf2, SIZE as i64);
|
||||
|
||||
let builder =
|
||||
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause));
|
||||
|
||||
Self(builder.build())
|
||||
}
|
||||
|
||||
pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) {
|
||||
count_cache_insert(CacheKind::Pbkdf2);
|
||||
self.0.insert((endpoint, role), value);
|
||||
}
|
||||
|
||||
fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option<Pbkdf2CacheEntry> {
|
||||
count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role)))
|
||||
}
|
||||
|
||||
pub fn get_entry(
|
||||
&self,
|
||||
endpoint: EndpointIdInt,
|
||||
role: RoleNameInt,
|
||||
) -> Option<CachedPbkdf2<'_>> {
|
||||
self.get(endpoint, role).map(|value| Cached {
|
||||
token: Some((self, (endpoint, role))),
|
||||
value,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,10 +4,8 @@ use std::convert::Infallible;
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use super::ScramKey;
|
||||
use super::messages::{
|
||||
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
||||
};
|
||||
@@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2;
|
||||
use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use super::threadpool::ThreadPool;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use super::{ScramKey, pbkdf2};
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
use crate::scram::cache::Pbkdf2CacheEntry;
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
#[derive(Debug)]
|
||||
@@ -77,46 +77,113 @@ impl<'a> Exchange<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
|
||||
async fn derive_client_key(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
password: &[u8],
|
||||
salt: &[u8],
|
||||
iterations: u32,
|
||||
) -> ScramKey {
|
||||
let salted_password = pool
|
||||
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
||||
.await;
|
||||
|
||||
let make_key = |name| {
|
||||
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
|
||||
.expect("HMAC is able to accept all key sizes")
|
||||
.chain_update(name)
|
||||
.finalize();
|
||||
|
||||
<[u8; 32]>::from(key.into_bytes())
|
||||
};
|
||||
|
||||
make_key(b"Client Key").into()
|
||||
) -> pbkdf2::Block {
|
||||
pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
||||
.await
|
||||
}
|
||||
|
||||
/// For cleartext flow, we need to derive the client key to
|
||||
/// 1. authenticate the client.
|
||||
/// 2. authenticate with compute.
|
||||
pub(crate) async fn exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
role: RoleNameInt,
|
||||
secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
if secret.iterations > CACHED_ROUNDS {
|
||||
exchange_with_cache(pool, endpoint, role, secret, password).await
|
||||
} else {
|
||||
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
|
||||
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
Ok(validate_pbkdf2(secret, &hash))
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
|
||||
/// which is not enough by itself to perform an offline brute force.
|
||||
async fn exchange_with_cache(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
role: RoleNameInt,
|
||||
secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
|
||||
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
|
||||
debug_assert!(
|
||||
secret.iterations > CACHED_ROUNDS,
|
||||
"we should not cache password data if there isn't enough rounds needed"
|
||||
);
|
||||
|
||||
// compute the prefix of the pbkdf2 output.
|
||||
let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
|
||||
|
||||
if let Some(entry) = pool.cache.get_entry(endpoint, role) {
|
||||
// hot path: let's check the threadpool cache
|
||||
if secret.cached_at == entry.cached_from {
|
||||
// cache is valid. compute the full hash by adding the prefix to the suffix.
|
||||
let mut hash = prefix;
|
||||
pbkdf2::xor_assign(&mut hash, &entry.suffix);
|
||||
let outcome = validate_pbkdf2(secret, &hash);
|
||||
|
||||
if matches!(outcome, sasl::Outcome::Success(_)) {
|
||||
trace!("password validated from cache");
|
||||
}
|
||||
|
||||
return Ok(outcome);
|
||||
}
|
||||
|
||||
// cached key is no longer valid.
|
||||
debug!("invalidating cached password");
|
||||
entry.invalidate();
|
||||
}
|
||||
|
||||
// slow path: full password hash.
|
||||
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
let outcome = validate_pbkdf2(secret, &hash);
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
sasl::Outcome::Failure(_) => return Ok(outcome),
|
||||
};
|
||||
|
||||
trace!("storing cached password");
|
||||
|
||||
// time to cache, compute the suffix by subtracting the prefix from the hash.
|
||||
let mut suffix = hash;
|
||||
pbkdf2::xor_assign(&mut suffix, &prefix);
|
||||
|
||||
pool.cache.insert(
|
||||
endpoint,
|
||||
role,
|
||||
Pbkdf2CacheEntry {
|
||||
cached_from: secret.cached_at,
|
||||
suffix,
|
||||
},
|
||||
);
|
||||
|
||||
Ok(sasl::Outcome::Success(client_key))
|
||||
}
|
||||
|
||||
fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
|
||||
let client_key = super::ScramKey::client_key(&(*hash).into());
|
||||
if secret.is_password_invalid(&client_key).into() {
|
||||
Ok(sasl::Outcome::Failure("password doesn't match"))
|
||||
sasl::Outcome::Failure("password doesn't match")
|
||||
} else {
|
||||
Ok(sasl::Outcome::Success(client_key))
|
||||
sasl::Outcome::Success(client_key)
|
||||
}
|
||||
}
|
||||
|
||||
const CACHED_ROUNDS: u32 = 16;
|
||||
|
||||
impl SaslInitial {
|
||||
fn transition(
|
||||
&self,
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
//! Tools for client/server/stored key management.
|
||||
|
||||
use hmac::Mac as _;
|
||||
use sha2::Digest as _;
|
||||
use subtle::ConstantTimeEq;
|
||||
use zeroize::Zeroize as _;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
use crate::scram::pbkdf2::Prf;
|
||||
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub(crate) const SCRAM_KEY_LEN: usize = 32;
|
||||
@@ -14,6 +20,12 @@ pub(crate) struct ScramKey {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl Drop for ScramKey {
|
||||
fn drop(&mut self) {
|
||||
self.bytes.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ScramKey {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.ct_eq(other).into()
|
||||
@@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey {
|
||||
|
||||
impl ScramKey {
|
||||
pub(crate) fn sha256(&self) -> Self {
|
||||
super::sha256([self.as_ref()]).into()
|
||||
Metrics::get().proxy.sha_rounds.inc_by(1);
|
||||
Self {
|
||||
bytes: sha2::Sha256::digest(self.as_bytes()).into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
|
||||
self.bytes
|
||||
}
|
||||
|
||||
pub(crate) fn client_key(b: &[u8; 32]) -> Self {
|
||||
// Prf::new_from_slice will run 2 sha256 rounds.
|
||||
// Update + Finalize run 2 sha256 rounds.
|
||||
Metrics::get().proxy.sha_rounds.inc_by(4);
|
||||
|
||||
let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes");
|
||||
prf.update(b"Client Key");
|
||||
let client_key: [u8; 32] = prf.finalize().into_bytes().into();
|
||||
client_key.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
||||
|
||||
mod cache;
|
||||
mod countmin;
|
||||
mod exchange;
|
||||
mod key;
|
||||
@@ -18,10 +19,8 @@ pub mod threadpool;
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
pub(crate) use exchange::{Exchange, exchange};
|
||||
use hmac::{Hmac, Mac};
|
||||
pub(crate) use key::ScramKey;
|
||||
pub(crate) use secret::ServerSecret;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
@@ -42,29 +41,13 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
|
||||
Some(bytes)
|
||||
}
|
||||
|
||||
/// This function essentially is `Hmac(sha256, key, input)`.
|
||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
|
||||
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
|
||||
parts.into_iter().for_each(|s| mac.update(s));
|
||||
|
||||
mac.finalize().into_bytes().into()
|
||||
}
|
||||
|
||||
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
parts.into_iter().for_each(|s| hasher.update(s));
|
||||
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::threadpool::ThreadPool;
|
||||
use super::{Exchange, ServerSecret};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::sasl::{Mechanism, Step};
|
||||
use crate::types::EndpointId;
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
|
||||
#[test]
|
||||
fn snapshot() {
|
||||
@@ -114,23 +97,34 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
||||
let pool = ThreadPool::new(1);
|
||||
|
||||
async fn check(
|
||||
pool: &ThreadPool,
|
||||
scram_secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> Result<(), &'static str> {
|
||||
let ep = EndpointId::from("foo");
|
||||
let ep = EndpointIdInt::from(ep);
|
||||
let role = RoleName::from("user");
|
||||
let role = RoleNameInt::from(&role);
|
||||
|
||||
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
||||
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
|
||||
let outcome = super::exchange(pool, ep, role, scram_secret, password)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match outcome {
|
||||
crate::sasl::Outcome::Success(_) => {}
|
||||
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
|
||||
crate::sasl::Outcome::Success(_) => Ok(()),
|
||||
crate::sasl::Outcome::Failure(r) => Err(r),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
||||
let pool = ThreadPool::new(1);
|
||||
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
||||
check(&pool, &scram_secret, client_password.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn round_trip() {
|
||||
run_round_trip_test("pencil", "pencil").await;
|
||||
@@ -141,4 +135,27 @@ mod tests {
|
||||
async fn failure() {
|
||||
run_round_trip_test("pencil", "eraser").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tracing_test::traced_test]
|
||||
async fn password_cache() {
|
||||
let pool = ThreadPool::new(1);
|
||||
let scram_secret = ServerSecret::build("password").await.unwrap();
|
||||
|
||||
// wrong passwords are not added to cache
|
||||
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
|
||||
assert!(!logs_contain("storing cached password"));
|
||||
|
||||
// correct passwords get cached
|
||||
check(&pool, &scram_secret, b"password").await.unwrap();
|
||||
assert!(logs_contain("storing cached password"));
|
||||
|
||||
// wrong passwords do not match the cache
|
||||
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
|
||||
assert!(!logs_contain("password validated from cache"));
|
||||
|
||||
// correct passwords match the cache
|
||||
check(&pool, &scram_secret, b"password").await.unwrap();
|
||||
assert!(logs_contain("password validated from cache"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,50 @@
|
||||
//! For postgres password authentication, we need to perform a PBKDF2 using
|
||||
//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
|
||||
|
||||
use hmac::Mac as _;
|
||||
use hmac::digest::consts::U32;
|
||||
use hmac::digest::generic_array::GenericArray;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use zeroize::Zeroize as _;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
|
||||
pub type Prf = hmac::Hmac<sha2::Sha256>;
|
||||
pub(crate) type Block = GenericArray<u8, U32>;
|
||||
|
||||
pub(crate) struct Pbkdf2 {
|
||||
hmac: Hmac<Sha256>,
|
||||
prev: GenericArray<u8, U32>,
|
||||
hi: GenericArray<u8, U32>,
|
||||
hmac: Prf,
|
||||
/// U{r-1} for whatever iteration r we are currently on.
|
||||
prev: Block,
|
||||
/// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
|
||||
hi: Block,
|
||||
/// number of iterations left
|
||||
iterations: u32,
|
||||
}
|
||||
|
||||
impl Drop for Pbkdf2 {
|
||||
fn drop(&mut self) {
|
||||
self.prev.zeroize();
|
||||
self.hi.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
|
||||
impl Pbkdf2 {
|
||||
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
|
||||
pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
|
||||
// key the HMAC and derive the first block in-place
|
||||
let mut hmac =
|
||||
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
|
||||
|
||||
// U1 = PRF(Password, Salt + INT_32_BE(i))
|
||||
// i = 1 since we only need 1 block of output.
|
||||
hmac.update(salt);
|
||||
hmac.update(&1u32.to_be_bytes());
|
||||
let init_block = hmac.finalize_reset().into_bytes();
|
||||
|
||||
// Prf::new_from_slice will run 2 sha256 rounds.
|
||||
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
|
||||
Metrics::get().proxy.sha_rounds.inc_by(4);
|
||||
|
||||
Self {
|
||||
hmac,
|
||||
// one iteration spent above
|
||||
@@ -33,7 +58,11 @@ impl Pbkdf2 {
|
||||
(self.iterations).clamp(0, 4096)
|
||||
}
|
||||
|
||||
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
|
||||
/// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
|
||||
/// function that only executes a fixed number of iterations before continuing.
|
||||
///
|
||||
/// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
|
||||
pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
|
||||
let Self {
|
||||
hmac,
|
||||
prev,
|
||||
@@ -44,25 +73,37 @@ impl Pbkdf2 {
|
||||
// only do up to 4096 iterations per turn for fairness
|
||||
let n = (*iterations).clamp(0, 4096);
|
||||
for _ in 0..n {
|
||||
hmac.update(prev);
|
||||
let block = hmac.finalize_reset().into_bytes();
|
||||
|
||||
for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
|
||||
*hi_byte ^= b;
|
||||
}
|
||||
|
||||
*prev = block;
|
||||
let next = single_round(hmac, prev);
|
||||
xor_assign(hi, &next);
|
||||
*prev = next;
|
||||
}
|
||||
|
||||
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
|
||||
Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
|
||||
|
||||
*iterations -= n;
|
||||
if *iterations == 0 {
|
||||
std::task::Poll::Ready((*hi).into())
|
||||
std::task::Poll::Ready(*hi)
|
||||
} else {
|
||||
std::task::Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn xor_assign(x: &mut Block, y: &Block) {
|
||||
for (x, &y) in std::iter::zip(x, y) {
|
||||
*x ^= y;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn single_round(prf: &mut Prf, ui: &Block) -> Block {
|
||||
// Ui = PRF(Password, Ui-1)
|
||||
prf.update(ui);
|
||||
prf.finalize_reset().into_bytes()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pbkdf2::pbkdf2_hmac_array;
|
||||
@@ -76,11 +117,11 @@ mod tests {
|
||||
let pass = b"Ne0n_!5_50_C007";
|
||||
|
||||
let mut job = Pbkdf2::start(pass, salt, 60000);
|
||||
let hash = loop {
|
||||
let hash: [u8; 32] = loop {
|
||||
let std::task::Poll::Ready(hash) = job.turn() else {
|
||||
continue;
|
||||
};
|
||||
break hash;
|
||||
break hash.into();
|
||||
};
|
||||
|
||||
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use subtle::{Choice, ConstantTimeEq};
|
||||
use tokio::time::Instant;
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::ScramKey;
|
||||
@@ -11,6 +12,9 @@ use super::key::ScramKey;
|
||||
/// and is used throughout the authentication process.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub(crate) struct ServerSecret {
|
||||
/// When this secret was cached.
|
||||
pub(crate) cached_at: Instant,
|
||||
|
||||
/// Number of iterations for `PBKDF2` function.
|
||||
pub(crate) iterations: u32,
|
||||
/// Salt used to hash user's password.
|
||||
@@ -34,6 +38,7 @@ impl ServerSecret {
|
||||
params.split_once(':').zip(keys.split_once(':'))?;
|
||||
|
||||
let secret = ServerSecret {
|
||||
cached_at: Instant::now(),
|
||||
iterations: iterations.parse().ok()?,
|
||||
salt_base64: salt.into(),
|
||||
stored_key: base64_decode_array(stored_key)?.into(),
|
||||
@@ -54,6 +59,7 @@ impl ServerSecret {
|
||||
/// See `auth-scram.c : mock_scram_secret` for details.
|
||||
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
|
||||
Self {
|
||||
cached_at: Instant::now(),
|
||||
// this doesn't reveal much information as we're going to use
|
||||
// iteration count 1 for our generated passwords going forward.
|
||||
// PG16 users can set iteration count=1 already today.
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
//! Tools for client/server signature management.
|
||||
|
||||
use hmac::Mac as _;
|
||||
|
||||
use super::key::{SCRAM_KEY_LEN, ScramKey};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::scram::pbkdf2::Prf;
|
||||
|
||||
/// A collection of message parts needed to derive the client's signature.
|
||||
#[derive(Debug)]
|
||||
@@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> {
|
||||
|
||||
impl SignatureBuilder<'_> {
|
||||
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
|
||||
let parts = [
|
||||
self.client_first_message_bare.as_bytes(),
|
||||
b",",
|
||||
self.server_first_message.as_bytes(),
|
||||
b",",
|
||||
self.client_final_message_without_proof.as_bytes(),
|
||||
];
|
||||
// don't know exactly. this is a rough approx
|
||||
Metrics::get().proxy.sha_rounds.inc_by(8);
|
||||
|
||||
super::hmac_sha256(key.as_ref(), parts).into()
|
||||
let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes");
|
||||
mac.update(self.client_first_message_bare.as_bytes());
|
||||
mac.update(b",");
|
||||
mac.update(self.server_first_message.as_bytes());
|
||||
mac.update(b",");
|
||||
mac.update(self.client_final_message_without_proof.as_bytes());
|
||||
Signature {
|
||||
bytes: mac.finalize().into_bytes().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ use futures::FutureExt;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::cache::Pbkdf2Cache;
|
||||
use super::pbkdf2;
|
||||
use super::pbkdf2::Pbkdf2;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
|
||||
@@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch;
|
||||
pub struct ThreadPool {
|
||||
runtime: Option<tokio::runtime::Runtime>,
|
||||
pub metrics: Arc<ThreadPoolMetrics>,
|
||||
|
||||
// we hash a lot of passwords.
|
||||
// we keep a cache of partial hashes for faster validation.
|
||||
pub(super) cache: Pbkdf2Cache,
|
||||
}
|
||||
|
||||
/// How often to reset the sketch values
|
||||
@@ -68,6 +74,7 @@ impl ThreadPool {
|
||||
Self {
|
||||
runtime: Some(runtime),
|
||||
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
|
||||
cache: Pbkdf2Cache::new(),
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -130,7 +137,7 @@ struct JobSpec {
|
||||
}
|
||||
|
||||
impl Future for JobSpec {
|
||||
type Output = [u8; 32];
|
||||
type Output = pbkdf2::Block;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
STATE.with_borrow_mut(|state| {
|
||||
@@ -166,10 +173,10 @@ impl Future for JobSpec {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
|
||||
pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);
|
||||
|
||||
impl Future for JobHandle {
|
||||
type Output = [u8; 32];
|
||||
type Output = pbkdf2::Block;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.0.poll_unpin(cx) {
|
||||
@@ -203,10 +210,10 @@ mod tests {
|
||||
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
|
||||
.await;
|
||||
|
||||
let expected = [
|
||||
let expected = &[
|
||||
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
|
||||
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
|
||||
];
|
||||
assert_eq!(actual, expected);
|
||||
assert_eq!(actual.as_slice(), expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::{connect_auth, connect_compute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
@@ -76,9 +76,11 @@ impl PoolingBackend {
|
||||
};
|
||||
|
||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||
let role = RoleNameInt::from(&user_info.user);
|
||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||
&self.config.authentication_config.thread_pool,
|
||||
&self.config.authentication_config.scram_thread_pool,
|
||||
ep,
|
||||
role,
|
||||
password,
|
||||
secret,
|
||||
)
|
||||
|
||||
@@ -102,7 +102,7 @@ pub struct ReportedError {
|
||||
}
|
||||
|
||||
impl ReportedError {
|
||||
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
|
||||
pub fn new(e: impl UserFacingError + Into<anyhow::Error>) -> Self {
|
||||
let error_kind = e.get_error_kind();
|
||||
Self {
|
||||
source: e.into(),
|
||||
|
||||
@@ -298,15 +298,26 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde
|
||||
assert post_detach_samples == set()
|
||||
|
||||
|
||||
def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder):
|
||||
@pytest.mark.parametrize("compaction", ["compaction_enabled", "compaction_disabled"])
|
||||
def test_pageserver_metrics_removed_after_offload(
|
||||
neon_env_builder: NeonEnvBuilder, compaction: str
|
||||
):
|
||||
"""Tests that when a timeline is offloaded, the tenant specific metrics are not left behind"""
|
||||
|
||||
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
|
||||
|
||||
neon_env_builder.num_safekeepers = 3
|
||||
|
||||
env = neon_env_builder.init_start()
|
||||
tenant_1, _ = env.create_tenant()
|
||||
tenant_1, _ = env.create_tenant(
|
||||
conf={
|
||||
# disable background compaction and GC so that we don't have leftover tasks
|
||||
# after offloading.
|
||||
"gc_period": "0s",
|
||||
"compaction_period": "0s",
|
||||
}
|
||||
if compaction == "compaction_disabled"
|
||||
else None
|
||||
)
|
||||
|
||||
timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1)
|
||||
timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1)
|
||||
@@ -351,6 +362,23 @@ def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuild
|
||||
state=TimelineArchivalState.ARCHIVED,
|
||||
)
|
||||
env.pageserver.http_client().timeline_offload(tenant_1, timeline)
|
||||
# We need to wait until all background jobs are finished before we can check the metrics.
|
||||
# There're many of them: compaction, GC, etc.
|
||||
wait_until(
|
||||
lambda: all(
|
||||
sample.value == 0
|
||||
for sample in env.pageserver.http_client()
|
||||
.get_metrics()
|
||||
.query_all("pageserver_background_loop_semaphore_waiting_tasks")
|
||||
)
|
||||
and all(
|
||||
sample.value == 0
|
||||
for sample in env.pageserver.http_client()
|
||||
.get_metrics()
|
||||
.query_all("pageserver_background_loop_semaphore_running_tasks")
|
||||
)
|
||||
)
|
||||
|
||||
post_offload_samples = set(
|
||||
[x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user