Compare commits

..

4 Commits

Author SHA1 Message Date
Heikki Linnakangas
9e3a1ccbd8 Update rust toolchain as a separately cacheable step
Avoids repeatedly downloading the toolchain when iterating on compute
parts locally, and probably helps on the CI too.
2025-07-29 16:17:03 +03:00
Heikki Linnakangas
d487ba2b9b Replace 'memoffset' crate with core functionality (#12761)
The `std::mem::offset_of` macro was introduced in Rust 1.77.0.

In the passing, mark the function as `const`, as suggested in the
comment. Not sure which compiler version that requires, but it works
with what have currently.
2025-07-29 08:01:31 +00:00
Conrad Ludgate
e7a1d5de94 proxy: cache for password hashing (#12011)
## Problem

Password hashing for sql-over-http takes up a lot of CPU. Perhaps we can
get away with temporarily caching some steps so we only need fewer
rounds, which will save some CPU time.

## Summary of changes

The output of pbkdf2 is the XOR of the outputs of each iteration round,
eg `U1 ^ U2 ^ ... U15 ^ U16 ^ U17 ^ ... ^ Un`. We cache the suffix of
the expression `U16 ^ U17 ^ ... ^ Un`. To compute the result from the
cached suffix, we only need to compute the prefix `U1 ^ U2 ^ ... U15`.
The suffix by itself is useless, which prevent's its use in brute-force
attacks should this cached memory leak.

We are also caching the full 4096 round hash in memory, which can be
used for brute-force attacks, where this suffix could be used to speed
it up. My hope/expectation is that since these will be in different
allocations, it makes any such memory exploitation much much harder.
Since the full hash cache might be invalidated while the suffix is
cached, I'm storing the timestamp of the computation as a way to
identity the match.

I also added `zeroize()` to clear the sensitive state from the
stack/heap.

For the most security conscious customers, we hope to roll out OIDC
soon, so they can disable passwords entirely.

---

The numbers for the threadpool were pretty random, but according to our
busiest region for sql-over-http, we only see about 150 unique endpoints
every minute. So storing ~100 of the most common endpoints for that
minute should be the vast majority of requests.

1 minute was chosen so we don't keep data in memory for too long.
2025-07-29 06:48:14 +00:00
Ivan Efremov
6be572177c chore: Fix nightly lints (#12746)
- Remove some unused code
- Use `is_multiple_of()` instead of '%'
- Collapse consecuative "if let" statements
- Elided lifetime fixes

It is enough just to review the code of your team
2025-07-28 21:36:30 +00:00
37 changed files with 469 additions and 323 deletions

2
Cargo.lock generated
View File

@@ -5078,7 +5078,6 @@ dependencies = [
"criterion",
"env_logger",
"log",
"memoffset 0.9.0",
"once_cell",
"postgres",
"postgres_ffi_types",
@@ -5519,6 +5518,7 @@ dependencies = [
"workspace_hack",
"x509-cert",
"zerocopy 0.8.24",
"zeroize",
]
[[package]]

View File

@@ -135,7 +135,6 @@ lock_api = "0.4.13"
md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
memoffset = "0.9"
moka = { version = "0.12", features = ["sync"] }
nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] }
# Do not update to >= 7.0.0, at least. The update will have a significant impact
@@ -234,9 +233,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
rustls-native-certs = "0.8"
whoami = "1.5.1"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
zerocopy = { version = "0.8", features = ["derive", "simd"] }
zeroize = "1.8"
## TODO replace this with tracing
env_logger = "0.11"

View File

@@ -1633,6 +1633,12 @@ FROM pg-build-with-cargo AS neon-ext-build
ARG PG_VERSION
USER root
# Update the rust toolchain. Running 'make' will do this, but better to do
# it as a separately cacheable step.
COPY rust-toolchain.toml .
RUN rustup show
COPY . .
RUN make -j $(getconf _NPROCESSORS_ONLN) -C pgxn -s install-compute \
@@ -1731,6 +1737,12 @@ ARG BUILD_TAG
ENV BUILD_TAG=$BUILD_TAG
USER nonroot
# Update the rust toolchain. Running 'cargo build' will do this, but
# better to do it as a separately cacheable step.
COPY --chown=nonroot rust-toolchain.toml .
RUN rustup show
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN --mount=type=cache,uid=1000,target=/home/nonroot/.cargo/registry \

View File

@@ -558,11 +558,11 @@ async fn add_request_id_header_to_response(
mut res: Response<Body>,
req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
if let Some(request_id) = req_info.context::<RequestId>() {
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
if let Some(request_id) = req_info.context::<RequestId>()
&& let Ok(request_header_value) = HeaderValue::from_str(&request_id.0)
{
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
Ok(res)

View File

@@ -72,10 +72,10 @@ impl Server {
if err.is_incomplete_message() || err.is_closed() || err.is_timeout() {
return true;
}
if let Some(inner) = err.source() {
if let Some(io) = inner.downcast_ref::<std::io::Error>() {
return suppress_io_error(io);
}
if let Some(inner) = err.source()
&& let Some(io) = inner.downcast_ref::<std::io::Error>()
{
return suppress_io_error(io);
}
false
}

View File

@@ -363,7 +363,7 @@ where
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<'_, (K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;

View File

@@ -12,7 +12,6 @@ crc32c.workspace = true
criterion.workspace = true
once_cell.workspace = true
log.workspace = true
memoffset.workspace = true
pprof.workspace = true
thiserror.workspace = true
serde.workspace = true

View File

@@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::<ControlFileData>();
impl ControlFileData {
/// Compute the offset of the `crc` field within the `ControlFileData` struct.
/// Equivalent to offsetof(ControlFileData, crc) in C.
// Someday this can be const when the right compiler features land.
fn pg_control_crc_offset() -> usize {
memoffset::offset_of!(ControlFileData, crc)
const fn pg_control_crc_offset() -> usize {
std::mem::offset_of!(ControlFileData, crc)
}
///

View File

@@ -49,7 +49,7 @@ impl PerfSpan {
}
}
pub fn enter(&self) -> PerfSpanEntered {
pub fn enter(&self) -> PerfSpanEntered<'_> {
if let Some(ref id) = self.inner.id() {
self.dispatch.enter(id);
}

View File

@@ -48,8 +48,6 @@ DATA = \
neon--1.3--1.4.sql \
neon--1.4--1.5.sql \
neon--1.5--1.6.sql \
neon--1.6--1.7.sql \
neon--1.7--1.6.sql \
neon--1.6--1.5.sql \
neon--1.5--1.4.sql \
neon--1.4--1.3.sql \

View File

@@ -260,7 +260,7 @@ typedef struct PrefetchState
/* the buffers */
prfh_hash *prf_hash;
int max_unflushed_shard_no;
int max_shard_no;
/* Mark shards involved in prefetch */
uint8 shard_bitmap[(MAX_SHARDS + 7)/8];
PrefetchRequest prf_buffer[]; /* prefetch buffers */
@@ -300,7 +300,6 @@ static void prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_
static bool prefetch_wait_for(uint64 ring_index);
static void prefetch_cleanup_trailing_unused(void);
static inline void prefetch_set_unused(uint64 ring_index);
static bool prefetch_flush_requests(void);
static bool neon_prefetch_response_usable(neon_request_lsns *request_lsns,
PrefetchRequest *slot);
@@ -470,26 +469,13 @@ communicator_prefetch_pump_state(void)
{
START_PREFETCH_RECEIVE_WORK();
if (MyPState->ring_receive == MyPState->ring_flush && MyPState->ring_flush < MyPState->ring_unused)
{
/*
* Flush request to avoid requests pending for arbitrary long time,
* pinning LSN and holding GC at PS.
*/
if (!prefetch_flush_requests())
{
END_PREFETCH_RECEIVE_WORK();
return;
}
}
while (MyPState->ring_receive != MyPState->ring_flush)
{
NeonResponse *response;
PrefetchRequest *slot;
MemoryContext old;
uint64 my_ring_index = MyPState->ring_receive;
slot = GetPrfSlot(my_ring_index);
slot = GetPrfSlot(MyPState->ring_receive);
old = MemoryContextSwitchTo(MyPState->errctx);
response = page_server->try_receive(slot->shard_no);
@@ -503,12 +489,12 @@ communicator_prefetch_pump_state(void)
/* The slot should still be valid */
if (slot->status != PRFS_REQUESTED ||
slot->response != NULL ||
slot->my_ring_index != my_ring_index)
slot->my_ring_index != MyPState->ring_receive)
{
neon_shard_log(slot->shard_no, PANIC,
"Incorrect prefetch slot state after receive: status=%d response=%p my=" UINT64_FORMAT " receive=" UINT64_FORMAT "",
slot->status, slot->response,
slot->my_ring_index, my_ring_index);
slot->my_ring_index, MyPState->ring_receive);
}
/* update prefetch state */
MyPState->n_responses_buffered += 1;
@@ -536,19 +522,6 @@ communicator_prefetch_pump_state(void)
END_PREFETCH_RECEIVE_WORK();
if (RecoveryInProgress())
{
/*
* Update backend's min in-flight prefetch LSN.
*/
XLogRecPtr min_backend_prefetch_lsn = last_replay_lsn != InvalidXLogRecPtr ? last_replay_lsn : GetXLogReplayRecPtr(NULL);
for (uint64_t ring_index = MyPState->ring_receive; ring_index < MyPState->ring_unused; ring_index++)
{
PrefetchRequest* slot = GetPrfSlot(ring_index);
min_backend_prefetch_lsn = Min(slot->request_lsns.request_lsn, min_backend_prefetch_lsn);
}
MIN_BACKEND_REQUEST_LSN = min_backend_prefetch_lsn;
}
communicator_reconfigure_timeout_if_needed();
}
@@ -588,7 +561,7 @@ readahead_buffer_resize(int newsize, void *extra)
newPState->ring_last = newsize;
newPState->ring_unused = newsize;
newPState->ring_receive = newsize;
newPState->max_unflushed_shard_no = MyPState->max_unflushed_shard_no;
newPState->max_shard_no = MyPState->max_shard_no;
memcpy(newPState->shard_bitmap, MyPState->shard_bitmap, sizeof(MyPState->shard_bitmap));
/*
@@ -688,7 +661,6 @@ consume_prefetch_responses(void)
{
if (MyPState->ring_receive < MyPState->ring_unused)
prefetch_wait_for(MyPState->ring_unused - 1);
/*
* We know for sure we're not working on any prefetch pages after
* this.
@@ -718,7 +690,7 @@ prefetch_cleanup_trailing_unused(void)
static bool
prefetch_flush_requests(void)
{
for (shardno_t shard_no = 0; shard_no < MyPState->max_unflushed_shard_no; shard_no++)
for (shardno_t shard_no = 0; shard_no < MyPState->max_shard_no; shard_no++)
{
if (BITMAP_ISSET(MyPState->shard_bitmap, shard_no))
{
@@ -727,8 +699,7 @@ prefetch_flush_requests(void)
BITMAP_CLR(MyPState->shard_bitmap, shard_no);
}
}
MyPState->max_unflushed_shard_no = 0;
MyPState->ring_flush = MyPState->ring_unused;
MyPState->max_shard_no = 0;
return true;
}
@@ -752,6 +723,7 @@ prefetch_wait_for(uint64 ring_index)
{
if (!prefetch_flush_requests())
return false;
MyPState->ring_flush = MyPState->ring_unused;
}
Assert(MyPState->ring_unused > ring_index);
@@ -830,7 +802,6 @@ prefetch_read(PrefetchRequest *slot)
old = MemoryContextSwitchTo(MyPState->errctx);
response = (NeonResponse *) page_server->receive(shard_no);
MemoryContextSwitchTo(old);
if (response)
{
check_getpage_response(slot, response);
@@ -1039,16 +1010,11 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
Assert(mySlotNo == MyPState->ring_unused);
if (force_request_lsns)
{
slot->request_lsns = *force_request_lsns;
}
else
{
neon_get_request_lsns(BufTagGetNRelFileInfo(slot->buftag),
slot->buftag.forkNum, slot->buftag.blockNum,
&slot->request_lsns, 1);
last_replay_lsn = InvalidXLogRecPtr;
}
request.hdr.lsn = slot->request_lsns.request_lsn;
request.hdr.not_modified_since = slot->request_lsns.not_modified_since;
@@ -1067,7 +1033,7 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
MyPState->n_unused -= 1;
MyPState->ring_unused += 1;
BITMAP_SET(MyPState->shard_bitmap, slot->shard_no);
MyPState->max_unflushed_shard_no = Max(slot->shard_no+1, MyPState->max_unflushed_shard_no);
MyPState->max_shard_no = Max(slot->shard_no+1, MyPState->max_shard_no);
/* update slot state */
slot->status = PRFS_REQUESTED;
@@ -1075,25 +1041,6 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
Assert(!found);
}
/*
* Check that returned page LSN is consistent with request lsns
*/
static void
check_page_lsn(NeonGetPageResponse* resp)
{
if (neon_protocol_version < 3) /* no information to check */
return;
if (PageGetLSN(resp->page) > resp->req.hdr.not_modified_since)
neon_log(PANIC, "Invalid getpage response version: %X/%08X is higher than last modified LSN %X/%08X",
LSN_FORMAT_ARGS(PageGetLSN(resp->page)),
LSN_FORMAT_ARGS(resp->req.hdr.not_modified_since));
if (PageGetLSN(resp->page) > resp->req.hdr.lsn)
neon_log(PANIC, "Invalid getpage response version: %X/%08X is higher than request LSN %X/%08X",
LSN_FORMAT_ARGS(PageGetLSN(resp->page)),
LSN_FORMAT_ARGS(resp->req.hdr.lsn));
}
/*
* Lookup of already received prefetch requests. Only already received responses matching required LSNs are accepted.
* Present pages are marked in "mask" bitmap and total number of such pages is returned.
@@ -1117,7 +1064,7 @@ communicator_prefetch_lookupv(NRelFileInfo rinfo, ForkNumber forknum, BlockNumbe
for (int i = 0; i < nblocks; i++)
{
PrfHashEntry *entry;
NeonGetPageResponse* resp;
hashkey.buftag.blockNum = blocknum + i;
entry = prfh_lookup(MyPState->prf_hash, &hashkey);
@@ -1150,9 +1097,8 @@ communicator_prefetch_lookupv(NRelFileInfo rinfo, ForkNumber forknum, BlockNumbe
continue;
}
Assert(slot->response->tag == T_NeonGetPageResponse); /* checked by check_getpage_response when response was assigned to the slot */
resp = (NeonGetPageResponse*)slot->response;
check_page_lsn(resp);
memcpy(buffers[i], resp->page, BLCKSZ);
memcpy(buffers[i], ((NeonGetPageResponse*)slot->response)->page, BLCKSZ);
/*
* With lfc_store_prefetch_result=true prefetch result is stored in LFC in prefetch_pump_state when response is received
@@ -1445,6 +1391,7 @@ Retry:
*/
goto Retry;
}
MyPState->ring_flush = MyPState->ring_unused;
}
return last_ring_index;
@@ -1514,12 +1461,10 @@ page_server_request(void const *req)
MyNeonCounters->pageserver_open_requests--;
} while (resp == NULL);
cancel_before_shmem_exit(prefetch_on_exit, Int32GetDatum(shard_no));
last_replay_lsn = InvalidXLogRecPtr;
}
PG_CATCH();
{
cancel_before_shmem_exit(prefetch_on_exit, Int32GetDatum(shard_no));
last_replay_lsn = InvalidXLogRecPtr;
/* Nothing should cancel disconnect: we should not leave connection in opaque state */
HOLD_INTERRUPTS();
page_server->disconnect(shard_no);
@@ -1919,13 +1864,6 @@ nm_to_string(NeonMessage *msg)
return s.data;
}
static void
reset_min_request_lsn(int code, Datum arg)
{
if (MyProcNumber != -1)
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
}
/*
* communicator_init() -- Initialize per-backend private state
*/
@@ -1937,8 +1875,6 @@ communicator_init(void)
if (MyPState != NULL)
return;
before_shmem_exit(reset_min_request_lsn, 0);
/*
* Sanity check that theperf counters array is sized correctly. We got
* this wrong once, and the formula for max number of backends and aux
@@ -1948,7 +1884,7 @@ communicator_init(void)
* the check here. That's OK, we don't expect the logic to change in old
* releases.
*/
#if PG_MAJORVERSION_NUM >= 15
#if PG_VERSION_NUM>=150000
if (MyNeonCounters >= &neon_per_backend_counters_shared[NUM_NEON_PERF_COUNTER_SLOTS])
elog(ERROR, "MyNeonCounters points past end of array");
#endif
@@ -2287,7 +2223,6 @@ Retry:
case T_NeonGetPageResponse:
{
NeonGetPageResponse* getpage_resp = (NeonGetPageResponse *) resp;
check_page_lsn(getpage_resp);
memcpy(buffer, getpage_resp->page, BLCKSZ);
/*
@@ -2564,30 +2499,12 @@ communicator_reconfigure_timeout_if_needed(void)
!AmPrewarmWorker && /* do not pump prefetch state in prewarm worker */
readahead_getpage_pull_timeout_ms > 0;
if (!needs_set && MIN_BACKEND_REQUEST_LSN != InvalidXLogRecPtr)
{
if (last_replay_lsn == InvalidXLogRecPtr)
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
else
needs_set = true; /* Can not reset MIN_BACKEND_REQUEST_LSN now, have to do it later */
}
if (needs_set != timeout_set)
{
/*
* The background writer/checkpointer doens't (shouldn't) read any pages.
* And definitely they should not run on replica.
* The only case when we can get here is replica promotion.
*/
if (AmBackgroundWriterProcess() || AmCheckpointerProcess())
{
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
if (timeout_set)
{
disable_timeout(PS_TIMEOUT_ID, false);
timeout_set = false;
}
return;
}
/* The background writer doens't (shouldn't) read any pages */
Assert(!AmBackgroundWriterProcess());
/* The checkpointer doens't (shouldn't) read any pages */
Assert(!AmCheckpointerProcess());
if (unlikely(PS_TIMEOUT_ID == 0))
{
@@ -2620,6 +2537,14 @@ communicator_reconfigure_timeout_if_needed(void)
static void
pagestore_timeout_handler(void)
{
#if PG_MAJORVERSION_NUM <= 14
/*
* PG14: Setting a repeating timeout is not possible, so we signal here
* that the timeout has already been reset, and by telling the system
* that system will re-schedule it later if we need to.
*/
timeout_set = false;
#endif
timeout_signaled = true;
InterruptPending = true;
}
@@ -2639,14 +2564,6 @@ communicator_processinterrupts(void)
if (!readpage_reentrant_guard && readahead_getpage_pull_timeout_ms > 0)
communicator_prefetch_pump_state();
#if PG_MAJORVERSION_NUM <= 14
/*
* PG14: Setting a repeating timeout is not possible, so we signal here
* that the timeout has already been reset, and by telling the system
* that system will re-schedule it later if we need to.
*/
timeout_set = false;
#endif
timeout_signaled = false;
communicator_reconfigure_timeout_if_needed();
}
@@ -2656,28 +2573,3 @@ communicator_processinterrupts(void)
return prev_interrupt_cb();
}
PG_FUNCTION_INFO_V1(neon_communicator_min_inflight_request_lsn);
Datum
neon_communicator_min_inflight_request_lsn(PG_FUNCTION_ARGS)
{
if (RecoveryInProgress())
{
/* Do not hold GC for primary */
PG_RETURN_INT64(UINT64_MAX);
}
else
{
XLogRecPtr min_lsn = GetXLogReplayRecPtr(NULL);
size_t n_procs = ProcGlobal->allProcCount;
for (size_t i = 0; i < n_procs; i++)
{
if (neon_per_backend_counters_shared[i].min_request_lsn != InvalidXLogRecPtr)
{
min_lsn = Min(min_lsn, neon_per_backend_counters_shared[i].min_request_lsn);
}
}
PG_RETURN_INT64(min_lsn);
}
}

View File

@@ -1,3 +0,0 @@
create function neon_communicator_min_inflight_request_lsn() returns pg_catalog.pg_lsn
AS 'MODULE_PATHNAME', 'neon_communicator_min_inflight_request_lsn'
LANGUAGE C;

View File

@@ -1 +0,0 @@
drop function neon_communicator_min_inflight_request_lsn();

View File

@@ -42,6 +42,7 @@ NeonPerfCountersShmemRequest(void)
}
void
NeonPerfCountersShmemInit(void)
{

View File

@@ -154,11 +154,6 @@ typedef struct
* Histogram of query execution time.
*/
QTHistogramData query_time_hist;
/*
* Minimal LSN of in-fligth request requests
*/
XLogRecPtr min_request_lsn;
} neon_per_backend_counters;
/* Pointer to the shared memory array of neon_per_backend_counters structs */
@@ -174,12 +169,6 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared;
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber])
/*
* Backend-local minimal in-flight request LSN.
* We store it in neon_per_backend_counters_shared and not in separate array to minimize false cache sharing
*/
#define MIN_BACKEND_REQUEST_LSN MyNeonCounters->min_request_lsn
extern void inc_getpage_wait(uint64 latency);
extern void inc_page_cache_read_wait(uint64 latency);
extern void inc_page_cache_write_wait(uint64 latency);

View File

@@ -243,7 +243,6 @@ extern char *neon_timeline;
extern char *neon_tenant;
extern int32 max_cluster_size;
extern int neon_protocol_version;
extern XLogRecPtr last_replay_lsn;
extern shardno_t get_shard_number(BufferTag* tag);

View File

@@ -96,8 +96,6 @@ typedef enum
int debug_compare_local;
XLogRecPtr last_replay_lsn;
static NRelFileInfo unlogged_build_rel_info;
static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
@@ -161,7 +159,7 @@ log_newpages_copy(NRelFileInfo * rinfo, ForkNumber forkNum, BlockNumber blkno,
page_std);
}
return GetXLogInsertRecPtr();
return ProcLastRecPtr;
}
#endif /* PG_MAJORVERSION_NUM >= 17 */
@@ -590,17 +588,6 @@ neon_get_request_lsns(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
/* Request the page at the end of the last fully replayed LSN. */
XLogRecPtr replay_lsn = GetXLogReplayRecPtr(NULL);
if (MIN_BACKEND_REQUEST_LSN == InvalidXLogRecPtr)
{
/* mark the backend's replay_lsn as "we have a request ongoing", blocking the expiration of any current LSN */
MIN_BACKEND_REQUEST_LSN = replay_lsn;
/* make sure memory operations are in correct order, even in concurrent systems */
pg_memory_barrier();
/* get the current LSN to register */
replay_lsn = GetXLogReplayRecPtr(NULL);
MIN_BACKEND_REQUEST_LSN = replay_lsn;
}
last_replay_lsn = replay_lsn;
for (int i = 0; i < nblocks; i++)
{
neon_request_lsns *result = &output[i];

View File

@@ -107,6 +107,7 @@ uuid.workspace = true
x509-cert.workspace = true
redis.workspace = true
zerocopy.workspace = true
zeroize.workspace = true
# uncomment this to use the real subzero-core crate
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
# this is a stub for the subzero-core crate

View File

@@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl;
use crate::stream::{self, Stream};
@@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_flow = AuthFlow::new(
client,
auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
role,
pool: config.scram_thread_pool.clone(),
},
);
let auth_outcome = {

View File

@@ -25,7 +25,7 @@ use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
@@ -273,9 +273,11 @@ async fn authenticate_with_secret(
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
.await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
@@ -499,7 +501,7 @@ mod tests {
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
@@ -46,6 +46,7 @@ pub(crate) struct PasswordHack;
pub(crate) struct CleartextPassword {
pub(crate) pool: Arc<ThreadPool>,
pub(crate) endpoint: EndpointIdInt,
pub(crate) role: RoleNameInt,
pub(crate) secret: AuthSecret,
}
@@ -111,6 +112,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
self.state.role,
password,
self.state.secret,
)
@@ -165,13 +167,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let outcome =
crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -29,7 +29,7 @@ use crate::config::{
};
use crate::control_plane::locks::ApiLocks;
use crate::http::health_server::AppMetrics;
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
use crate::metrics::{Metrics, ServiceInfo};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
@@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
// TODO: refactor these to use labels
debug!("Version: {GIT_VERSION}");
debug!("Build_tag: {BUILD_TAG}");
@@ -284,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -26,7 +26,7 @@ use utils::project_git_version;
use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
use crate::metrics::{Metrics, ServiceInfo};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
@@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
let args = cli().get_matches();
let destination: String = args
.get_one::<String>("dest")

View File

@@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> {
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
Metrics::get()
.proxy
.scram_pool
.0
.set(thread_pool.metrics.clone())
.ok();
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
@@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_thread_pool: thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,

View File

@@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::scram;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
@@ -75,7 +75,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_thread_pool: Arc<scram::threadpool::ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub ip_allowlist_check_enabled: bool,
pub is_vpc_acccess_proxy: bool,

View File

@@ -5,6 +5,7 @@ use measured::label::{
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
StaticLabelSet,
};
use measured::metric::group::Encoding;
use measured::metric::histogram::Thresholds;
use measured::metric::name::MetricName;
use measured::{
@@ -18,10 +19,10 @@ use crate::control_plane::messages::ColdStartInfo;
use crate::error::ErrorKind;
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new())]
pub struct Metrics {
#[metric(namespace = "proxy")]
#[metric(init = ProxyMetrics::new(thread_pool))]
#[metric(init = ProxyMetrics::new())]
pub proxy: ProxyMetrics,
#[metric(namespace = "wake_compute_lock")]
@@ -34,34 +35,27 @@ pub struct Metrics {
pub cache: CacheMetrics,
}
static SELF: OnceLock<Metrics> = OnceLock::new();
impl Metrics {
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
let mut metrics = Metrics::new(thread_pool);
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
SELF.set(metrics)
.ok()
.expect("proxy metrics must not be installed more than once");
}
#[track_caller]
pub fn get() -> &'static Self {
#[cfg(test)]
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
static SELF: OnceLock<Metrics> = OnceLock::new();
#[cfg(not(test))]
SELF.get()
.expect("proxy metrics must be installed by the main() function")
SELF.get_or_init(|| {
let mut metrics = Metrics::new();
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
metrics
})
}
}
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new())]
pub struct ProxyMetrics {
#[metric(flatten)]
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
@@ -134,6 +128,9 @@ pub struct ProxyMetrics {
/// Number of TLS handshake failures
pub tls_handshake_failures: Counter,
/// Number of SHA 256 rounds executed.
pub sha_rounds: Counter,
/// HLL approximate cardinality of endpoints that are connecting
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
@@ -151,8 +148,25 @@ pub struct ProxyMetrics {
pub connect_compute_lock: ApiLockMetrics,
#[metric(namespace = "scram_pool")]
#[metric(init = thread_pool)]
pub scram_pool: Arc<ThreadPoolMetrics>,
pub scram_pool: OnceLockWrapper<Arc<ThreadPoolMetrics>>,
}
/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`].
pub struct OnceLockWrapper<T>(pub OnceLock<T>);
impl<T> Default for OnceLockWrapper<T> {
fn default() -> Self {
Self(OnceLock::new())
}
}
impl<Enc: Encoding, T: MetricGroup<Enc>> MetricGroup<Enc> for OnceLockWrapper<T> {
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
if let Some(inner) = self.0.get() {
inner.collect_group_into(enc)?;
}
Ok(())
}
}
#[derive(MetricGroup)]
@@ -553,14 +567,6 @@ impl From<bool> for Bool {
}
}
#[derive(LabelGroup)]
#[label(set = InvalidEndpointsSet)]
pub struct InvalidEndpointsGroup {
pub protocol: Protocol,
pub rejected: Bool,
pub outcome: ConnectOutcome,
}
#[derive(LabelGroup)]
#[label(set = RetriesMetricSet)]
pub struct RetriesMetricGroup {
@@ -727,6 +733,7 @@ pub enum CacheKind {
ProjectInfoEndpoints,
ProjectInfoRoles,
Schema,
Pbkdf2,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]

84
proxy/src/scram/cache.rs Normal file
View File

@@ -0,0 +1,84 @@
use tokio::time::Instant;
use zeroize::Zeroize as _;
use super::pbkdf2;
use crate::cache::Cached;
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::metrics::{CacheKind, Metrics};
pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>);
pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>;
impl Cache for Pbkdf2Cache {
type Key = (EndpointIdInt, RoleNameInt);
type Value = Pbkdf2CacheEntry;
fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) {
self.0.invalidate(info);
}
}
/// To speed up password hashing for more active customers, we store the tail results of the
/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store
/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16
/// to determine the final result.
///
/// The suffix alone isn't enough to crack the password. The stored_key is still required.
/// While both are cached in memory, given they're in different locations is makes it much
/// harder to exploit, even if any such memory exploit exists in proxy.
#[derive(Clone)]
pub struct Pbkdf2CacheEntry {
/// corresponds to [`super::ServerSecret::cached_at`]
pub(super) cached_from: Instant,
pub(super) suffix: pbkdf2::Block,
}
impl Drop for Pbkdf2CacheEntry {
fn drop(&mut self) {
self.suffix.zeroize();
}
}
impl Pbkdf2Cache {
pub fn new() -> Self {
const SIZE: u64 = 100;
const TTL: std::time::Duration = std::time::Duration::from_secs(60);
let builder = moka::sync::Cache::builder()
.name("pbkdf2")
.max_capacity(SIZE)
// We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
.time_to_live(TTL);
Metrics::get()
.cache
.capacity
.set(CacheKind::Pbkdf2, SIZE as i64);
let builder =
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause));
Self(builder.build())
}
pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) {
count_cache_insert(CacheKind::Pbkdf2);
self.0.insert((endpoint, role), value);
}
fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option<Pbkdf2CacheEntry> {
count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role)))
}
pub fn get_entry(
&self,
endpoint: EndpointIdInt,
role: RoleNameInt,
) -> Option<CachedPbkdf2<'_>> {
self.get(endpoint, role).map(|value| Cached {
token: Some((self, (endpoint, role))),
value,
})
}
}

View File

@@ -4,10 +4,8 @@ use std::convert::Infallible;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tracing::{debug, trace};
use super::ScramKey;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
@@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2;
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use crate::intern::EndpointIdInt;
use super::{ScramKey, pbkdf2};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl::{self, ChannelBinding, Error as SaslError};
use crate::scram::cache::Pbkdf2CacheEntry;
/// The only channel binding mode we currently support.
#[derive(Debug)]
@@ -77,46 +77,113 @@ impl<'a> Exchange<'a> {
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await;
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes")
.chain_update(name)
.finalize();
<[u8; 32]>::from(key.into_bytes())
};
make_key(b"Client Key").into()
) -> pbkdf2::Block {
pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await
}
/// For cleartext flow, we need to derive the client key to
/// 1. authenticate the client.
/// 2. authenticate with compute.
pub(crate) async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
if secret.iterations > CACHED_ROUNDS {
exchange_with_cache(pool, endpoint, role, secret, password).await
} else {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
Ok(validate_pbkdf2(secret, &hash))
}
}
/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
/// which is not enough by itself to perform an offline brute force.
async fn exchange_with_cache(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
debug_assert!(
secret.iterations > CACHED_ROUNDS,
"we should not cache password data if there isn't enough rounds needed"
);
// compute the prefix of the pbkdf2 output.
let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
if let Some(entry) = pool.cache.get_entry(endpoint, role) {
// hot path: let's check the threadpool cache
if secret.cached_at == entry.cached_from {
// cache is valid. compute the full hash by adding the prefix to the suffix.
let mut hash = prefix;
pbkdf2::xor_assign(&mut hash, &entry.suffix);
let outcome = validate_pbkdf2(secret, &hash);
if matches!(outcome, sasl::Outcome::Success(_)) {
trace!("password validated from cache");
}
return Ok(outcome);
}
// cached key is no longer valid.
debug!("invalidating cached password");
entry.invalidate();
}
// slow path: full password hash.
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
let outcome = validate_pbkdf2(secret, &hash);
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
sasl::Outcome::Failure(_) => return Ok(outcome),
};
trace!("storing cached password");
// time to cache, compute the suffix by subtracting the prefix from the hash.
let mut suffix = hash;
pbkdf2::xor_assign(&mut suffix, &prefix);
pool.cache.insert(
endpoint,
role,
Pbkdf2CacheEntry {
cached_from: secret.cached_at,
suffix,
},
);
Ok(sasl::Outcome::Success(client_key))
}
fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
let client_key = super::ScramKey::client_key(&(*hash).into());
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
sasl::Outcome::Failure("password doesn't match")
} else {
Ok(sasl::Outcome::Success(client_key))
sasl::Outcome::Success(client_key)
}
}
const CACHED_ROUNDS: u32 = 16;
impl SaslInitial {
fn transition(
&self,

View File

@@ -1,6 +1,12 @@
//! Tools for client/server/stored key management.
use hmac::Mac as _;
use sha2::Digest as _;
use subtle::ConstantTimeEq;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_KEY_LEN: usize = 32;
@@ -14,6 +20,12 @@ pub(crate) struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl Drop for ScramKey {
fn drop(&mut self) {
self.bytes.zeroize();
}
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
@@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey {
impl ScramKey {
pub(crate) fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()
Metrics::get().proxy.sha_rounds.inc_by(1);
Self {
bytes: sha2::Sha256::digest(self.as_bytes()).into(),
}
}
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
self.bytes
}
pub(crate) fn client_key(b: &[u8; 32]) -> Self {
// Prf::new_from_slice will run 2 sha256 rounds.
// Update + Finalize run 2 sha256 rounds.
Metrics::get().proxy.sha_rounds.inc_by(4);
let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes");
prf.update(b"Client Key");
let client_key: [u8; 32] = prf.finalize().into_bytes().into();
client_key.into()
}
}
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {

View File

@@ -6,6 +6,7 @@
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod cache;
mod countmin;
mod exchange;
mod key;
@@ -18,10 +19,8 @@ pub mod threadpool;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
pub(crate) use exchange::{Exchange, exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
@@ -42,29 +41,13 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
Some(bytes)
}
/// This function essentially is `Hmac(sha256, key, input)`.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
parts.into_iter().for_each(|s| mac.update(s));
mac.finalize().into_bytes().into()
}
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut hasher = Sha256::new();
parts.into_iter().for_each(|s| hasher.update(s));
hasher.finalize().into()
}
#[cfg(test)]
mod tests {
use super::threadpool::ThreadPool;
use super::{Exchange, ServerSecret};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl::{Mechanism, Step};
use crate::types::EndpointId;
use crate::types::{EndpointId, RoleName};
#[test]
fn snapshot() {
@@ -114,23 +97,34 @@ mod tests {
);
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
async fn check(
pool: &ThreadPool,
scram_secret: &ServerSecret,
password: &[u8],
) -> Result<(), &'static str> {
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let role = RoleName::from("user");
let role = RoleNameInt::from(&role);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
let outcome = super::exchange(pool, ep, role, scram_secret, password)
.await
.unwrap();
match outcome {
crate::sasl::Outcome::Success(_) => {}
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
crate::sasl::Outcome::Success(_) => Ok(()),
crate::sasl::Outcome::Failure(r) => Err(r),
}
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
check(&pool, &scram_secret, client_password.as_bytes())
.await
.unwrap();
}
#[tokio::test]
async fn round_trip() {
run_round_trip_test("pencil", "pencil").await;
@@ -141,4 +135,27 @@ mod tests {
async fn failure() {
run_round_trip_test("pencil", "eraser").await;
}
#[tokio::test]
#[tracing_test::traced_test]
async fn password_cache() {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build("password").await.unwrap();
// wrong passwords are not added to cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("storing cached password"));
// correct passwords get cached
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("storing cached password"));
// wrong passwords do not match the cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("password validated from cache"));
// correct passwords match the cache
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("password validated from cache"));
}
}

View File

@@ -1,25 +1,50 @@
//! For postgres password authentication, we need to perform a PBKDF2 using
//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
use hmac::Mac as _;
use hmac::digest::consts::U32;
use hmac::digest::generic_array::GenericArray;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
pub type Prf = hmac::Hmac<sha2::Sha256>;
pub(crate) type Block = GenericArray<u8, U32>;
pub(crate) struct Pbkdf2 {
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
hmac: Prf,
/// U{r-1} for whatever iteration r we are currently on.
prev: Block,
/// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
hi: Block,
/// number of iterations left
iterations: u32,
}
impl Drop for Pbkdf2 {
fn drop(&mut self) {
self.prev.zeroize();
self.hi.zeroize();
}
}
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
// key the HMAC and derive the first block in-place
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
// U1 = PRF(Password, Salt + INT_32_BE(i))
// i = 1 since we only need 1 block of output.
hmac.update(salt);
hmac.update(&1u32.to_be_bytes());
let init_block = hmac.finalize_reset().into_bytes();
// Prf::new_from_slice will run 2 sha256 rounds.
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(4);
Self {
hmac,
// one iteration spent above
@@ -33,7 +58,11 @@ impl Pbkdf2 {
(self.iterations).clamp(0, 4096)
}
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
/// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
/// function that only executes a fixed number of iterations before continuing.
///
/// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
let Self {
hmac,
prev,
@@ -44,25 +73,37 @@ impl Pbkdf2 {
// only do up to 4096 iterations per turn for fairness
let n = (*iterations).clamp(0, 4096);
for _ in 0..n {
hmac.update(prev);
let block = hmac.finalize_reset().into_bytes();
for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
*hi_byte ^= b;
}
*prev = block;
let next = single_round(hmac, prev);
xor_assign(hi, &next);
*prev = next;
}
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
*iterations -= n;
if *iterations == 0 {
std::task::Poll::Ready((*hi).into())
std::task::Poll::Ready(*hi)
} else {
std::task::Poll::Pending
}
}
}
#[inline(always)]
pub fn xor_assign(x: &mut Block, y: &Block) {
for (x, &y) in std::iter::zip(x, y) {
*x ^= y;
}
}
#[inline(always)]
fn single_round(prf: &mut Prf, ui: &Block) -> Block {
// Ui = PRF(Password, Ui-1)
prf.update(ui);
prf.finalize_reset().into_bytes()
}
#[cfg(test)]
mod tests {
use pbkdf2::pbkdf2_hmac_array;
@@ -76,11 +117,11 @@ mod tests {
let pass = b"Ne0n_!5_50_C007";
let mut job = Pbkdf2::start(pass, salt, 60000);
let hash = loop {
let hash: [u8; 32] = loop {
let std::task::Poll::Ready(hash) = job.turn() else {
continue;
};
break hash;
break hash.into();
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);

View File

@@ -3,6 +3,7 @@
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use subtle::{Choice, ConstantTimeEq};
use tokio::time::Instant;
use super::base64_decode_array;
use super::key::ScramKey;
@@ -11,6 +12,9 @@ use super::key::ScramKey;
/// and is used throughout the authentication process.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct ServerSecret {
/// When this secret was cached.
pub(crate) cached_at: Instant,
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
@@ -34,6 +38,7 @@ impl ServerSecret {
params.split_once(':').zip(keys.split_once(':'))?;
let secret = ServerSecret {
cached_at: Instant::now(),
iterations: iterations.parse().ok()?,
salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
@@ -54,6 +59,7 @@ impl ServerSecret {
/// See `auth-scram.c : mock_scram_secret` for details.
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
Self {
cached_at: Instant::now(),
// this doesn't reveal much information as we're going to use
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.

View File

@@ -1,6 +1,10 @@
//! Tools for client/server signature management.
use hmac::Mac as _;
use super::key::{SCRAM_KEY_LEN, ScramKey};
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// A collection of message parts needed to derive the client's signature.
#[derive(Debug)]
@@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> {
impl SignatureBuilder<'_> {
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
let parts = [
self.client_first_message_bare.as_bytes(),
b",",
self.server_first_message.as_bytes(),
b",",
self.client_final_message_without_proof.as_bytes(),
];
// don't know exactly. this is a rough approx
Metrics::get().proxy.sha_rounds.inc_by(8);
super::hmac_sha256(key.as_ref(), parts).into()
let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes");
mac.update(self.client_first_message_bare.as_bytes());
mac.update(b",");
mac.update(self.server_first_message.as_bytes());
mac.update(b",");
mac.update(self.client_final_message_without_proof.as_bytes());
Signature {
bytes: mac.finalize().into_bytes().into(),
}
}
}

View File

@@ -15,6 +15,8 @@ use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::cache::Pbkdf2Cache;
use super::pbkdf2;
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
@@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch;
pub struct ThreadPool {
runtime: Option<tokio::runtime::Runtime>,
pub metrics: Arc<ThreadPoolMetrics>,
// we hash a lot of passwords.
// we keep a cache of partial hashes for faster validation.
pub(super) cache: Pbkdf2Cache,
}
/// How often to reset the sketch values
@@ -68,6 +74,7 @@ impl ThreadPool {
Self {
runtime: Some(runtime),
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
cache: Pbkdf2Cache::new(),
}
})
}
@@ -130,7 +137,7 @@ struct JobSpec {
}
impl Future for JobSpec {
type Output = [u8; 32];
type Output = pbkdf2::Block;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
STATE.with_borrow_mut(|state| {
@@ -166,10 +173,10 @@ impl Future for JobSpec {
}
}
pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);
impl Future for JobHandle {
type Output = [u8; 32];
type Output = pbkdf2::Block;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.0.poll_unpin(cx) {
@@ -203,10 +210,10 @@ mod tests {
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
.await;
let expected = [
let expected = &[
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
];
assert_eq!(actual, expected);
assert_eq!(actual.as_slice(), expected);
}
}

View File

@@ -26,7 +26,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::StartupMessageParams;
use crate::proxy::{connect_auth, connect_compute};
use crate::rate_limiter::EndpointRateLimiter;
@@ -76,9 +76,11 @@ impl PoolingBackend {
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let role = RoleNameInt::from(&user_info.user);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
&self.config.authentication_config.scram_thread_pool,
ep,
role,
password,
secret,
)

View File

@@ -102,7 +102,7 @@ pub struct ReportedError {
}
impl ReportedError {
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
pub fn new(e: impl UserFacingError + Into<anyhow::Error>) -> Self {
let error_kind = e.get_error_kind();
Self {
source: e.into(),

View File

@@ -55,7 +55,7 @@ def test_neon_extension_compatibility(neon_env_builder: NeonEnvBuilder):
# Ensure that the default version is also updated in the neon.control file
assert cur.fetchone() == ("1.6",)
cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE")
all_versions = ["1.7", "1.6", "1.5", "1.4", "1.3", "1.2", "1.1", "1.0"]
all_versions = ["1.6", "1.5", "1.4", "1.3", "1.2", "1.1", "1.0"]
current_version = "1.6"
for idx, begin_version in enumerate(all_versions):
for target_version in all_versions[idx + 1 :]: