mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-20 22:50:38 +00:00
Compare commits
4 Commits
min_inflig
...
heikki/upd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e3a1ccbd8 | ||
|
|
d487ba2b9b | ||
|
|
e7a1d5de94 | ||
|
|
6be572177c |
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -5078,7 +5078,6 @@ dependencies = [
|
||||
"criterion",
|
||||
"env_logger",
|
||||
"log",
|
||||
"memoffset 0.9.0",
|
||||
"once_cell",
|
||||
"postgres",
|
||||
"postgres_ffi_types",
|
||||
@@ -5519,6 +5518,7 @@ dependencies = [
|
||||
"workspace_hack",
|
||||
"x509-cert",
|
||||
"zerocopy 0.8.24",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -135,7 +135,6 @@ lock_api = "0.4.13"
|
||||
md5 = "0.7.0"
|
||||
measured = { version = "0.0.22", features=["lasso"] }
|
||||
measured-process = { version = "0.0.22" }
|
||||
memoffset = "0.9"
|
||||
moka = { version = "0.12", features = ["sync"] }
|
||||
nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] }
|
||||
# Do not update to >= 7.0.0, at least. The update will have a significant impact
|
||||
@@ -234,9 +233,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
|
||||
walkdir = "2.3.2"
|
||||
rustls-native-certs = "0.8"
|
||||
whoami = "1.5.1"
|
||||
zerocopy = { version = "0.8", features = ["derive", "simd"] }
|
||||
json-structural-diff = { version = "0.2.0" }
|
||||
x509-cert = { version = "0.2.5" }
|
||||
zerocopy = { version = "0.8", features = ["derive", "simd"] }
|
||||
zeroize = "1.8"
|
||||
|
||||
## TODO replace this with tracing
|
||||
env_logger = "0.11"
|
||||
|
||||
@@ -1633,6 +1633,12 @@ FROM pg-build-with-cargo AS neon-ext-build
|
||||
ARG PG_VERSION
|
||||
|
||||
USER root
|
||||
|
||||
# Update the rust toolchain. Running 'make' will do this, but better to do
|
||||
# it as a separately cacheable step.
|
||||
COPY rust-toolchain.toml .
|
||||
RUN rustup show
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN make -j $(getconf _NPROCESSORS_ONLN) -C pgxn -s install-compute \
|
||||
@@ -1731,6 +1737,12 @@ ARG BUILD_TAG
|
||||
ENV BUILD_TAG=$BUILD_TAG
|
||||
|
||||
USER nonroot
|
||||
|
||||
# Update the rust toolchain. Running 'cargo build' will do this, but
|
||||
# better to do it as a separately cacheable step.
|
||||
COPY --chown=nonroot rust-toolchain.toml .
|
||||
RUN rustup show
|
||||
|
||||
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
|
||||
COPY --chown=nonroot . .
|
||||
RUN --mount=type=cache,uid=1000,target=/home/nonroot/.cargo/registry \
|
||||
|
||||
@@ -558,11 +558,11 @@ async fn add_request_id_header_to_response(
|
||||
mut res: Response<Body>,
|
||||
req_info: RequestInfo,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
if let Some(request_id) = req_info.context::<RequestId>() {
|
||||
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
|
||||
res.headers_mut()
|
||||
.insert(&X_REQUEST_ID_HEADER, request_header_value);
|
||||
};
|
||||
if let Some(request_id) = req_info.context::<RequestId>()
|
||||
&& let Ok(request_header_value) = HeaderValue::from_str(&request_id.0)
|
||||
{
|
||||
res.headers_mut()
|
||||
.insert(&X_REQUEST_ID_HEADER, request_header_value);
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
|
||||
@@ -72,10 +72,10 @@ impl Server {
|
||||
if err.is_incomplete_message() || err.is_closed() || err.is_timeout() {
|
||||
return true;
|
||||
}
|
||||
if let Some(inner) = err.source() {
|
||||
if let Some(io) = inner.downcast_ref::<std::io::Error>() {
|
||||
return suppress_io_error(io);
|
||||
}
|
||||
if let Some(inner) = err.source()
|
||||
&& let Some(io) = inner.downcast_ref::<std::io::Error>()
|
||||
{
|
||||
return suppress_io_error(io);
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
@@ -363,7 +363,7 @@ where
|
||||
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
|
||||
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
|
||||
// If we switch to an Iterator, it must not hold the lock.
|
||||
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
|
||||
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<'_, (K, V)>> {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
if pos >= map.buckets.len() {
|
||||
return None;
|
||||
|
||||
@@ -12,7 +12,6 @@ crc32c.workspace = true
|
||||
criterion.workspace = true
|
||||
once_cell.workspace = true
|
||||
log.workspace = true
|
||||
memoffset.workspace = true
|
||||
pprof.workspace = true
|
||||
thiserror.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
@@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::<ControlFileData>();
|
||||
impl ControlFileData {
|
||||
/// Compute the offset of the `crc` field within the `ControlFileData` struct.
|
||||
/// Equivalent to offsetof(ControlFileData, crc) in C.
|
||||
// Someday this can be const when the right compiler features land.
|
||||
fn pg_control_crc_offset() -> usize {
|
||||
memoffset::offset_of!(ControlFileData, crc)
|
||||
const fn pg_control_crc_offset() -> usize {
|
||||
std::mem::offset_of!(ControlFileData, crc)
|
||||
}
|
||||
|
||||
///
|
||||
|
||||
@@ -49,7 +49,7 @@ impl PerfSpan {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enter(&self) -> PerfSpanEntered {
|
||||
pub fn enter(&self) -> PerfSpanEntered<'_> {
|
||||
if let Some(ref id) = self.inner.id() {
|
||||
self.dispatch.enter(id);
|
||||
}
|
||||
|
||||
@@ -48,8 +48,6 @@ DATA = \
|
||||
neon--1.3--1.4.sql \
|
||||
neon--1.4--1.5.sql \
|
||||
neon--1.5--1.6.sql \
|
||||
neon--1.6--1.7.sql \
|
||||
neon--1.7--1.6.sql \
|
||||
neon--1.6--1.5.sql \
|
||||
neon--1.5--1.4.sql \
|
||||
neon--1.4--1.3.sql \
|
||||
|
||||
@@ -260,7 +260,7 @@ typedef struct PrefetchState
|
||||
|
||||
/* the buffers */
|
||||
prfh_hash *prf_hash;
|
||||
int max_unflushed_shard_no;
|
||||
int max_shard_no;
|
||||
/* Mark shards involved in prefetch */
|
||||
uint8 shard_bitmap[(MAX_SHARDS + 7)/8];
|
||||
PrefetchRequest prf_buffer[]; /* prefetch buffers */
|
||||
@@ -300,7 +300,6 @@ static void prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_
|
||||
static bool prefetch_wait_for(uint64 ring_index);
|
||||
static void prefetch_cleanup_trailing_unused(void);
|
||||
static inline void prefetch_set_unused(uint64 ring_index);
|
||||
static bool prefetch_flush_requests(void);
|
||||
|
||||
static bool neon_prefetch_response_usable(neon_request_lsns *request_lsns,
|
||||
PrefetchRequest *slot);
|
||||
@@ -470,26 +469,13 @@ communicator_prefetch_pump_state(void)
|
||||
{
|
||||
START_PREFETCH_RECEIVE_WORK();
|
||||
|
||||
if (MyPState->ring_receive == MyPState->ring_flush && MyPState->ring_flush < MyPState->ring_unused)
|
||||
{
|
||||
/*
|
||||
* Flush requests to avoid requests pending for an arbitrarily long time,
|
||||
* pinning LSN and holding GC at PS.
|
||||
*/
|
||||
if (!prefetch_flush_requests())
|
||||
{
|
||||
END_PREFETCH_RECEIVE_WORK();
|
||||
return;
|
||||
}
|
||||
}
|
||||
while (MyPState->ring_receive != MyPState->ring_flush)
|
||||
{
|
||||
NeonResponse *response;
|
||||
PrefetchRequest *slot;
|
||||
MemoryContext old;
|
||||
uint64 my_ring_index = MyPState->ring_receive;
|
||||
|
||||
slot = GetPrfSlot(my_ring_index);
|
||||
slot = GetPrfSlot(MyPState->ring_receive);
|
||||
|
||||
old = MemoryContextSwitchTo(MyPState->errctx);
|
||||
response = page_server->try_receive(slot->shard_no);
|
||||
@@ -503,12 +489,12 @@ communicator_prefetch_pump_state(void)
|
||||
/* The slot should still be valid */
|
||||
if (slot->status != PRFS_REQUESTED ||
|
||||
slot->response != NULL ||
|
||||
slot->my_ring_index != my_ring_index)
|
||||
slot->my_ring_index != MyPState->ring_receive)
|
||||
{
|
||||
neon_shard_log(slot->shard_no, PANIC,
|
||||
"Incorrect prefetch slot state after receive: status=%d response=%p my=" UINT64_FORMAT " receive=" UINT64_FORMAT "",
|
||||
slot->status, slot->response,
|
||||
slot->my_ring_index, my_ring_index);
|
||||
slot->my_ring_index, MyPState->ring_receive);
|
||||
}
|
||||
/* update prefetch state */
|
||||
MyPState->n_responses_buffered += 1;
|
||||
@@ -536,19 +522,6 @@ communicator_prefetch_pump_state(void)
|
||||
|
||||
END_PREFETCH_RECEIVE_WORK();
|
||||
|
||||
if (RecoveryInProgress())
|
||||
{
|
||||
/*
|
||||
* Update backend's min in-flight prefetch LSN.
|
||||
*/
|
||||
XLogRecPtr min_backend_prefetch_lsn = last_replay_lsn != InvalidXLogRecPtr ? last_replay_lsn : GetXLogReplayRecPtr(NULL);
|
||||
for (uint64_t ring_index = MyPState->ring_receive; ring_index < MyPState->ring_unused; ring_index++)
|
||||
{
|
||||
PrefetchRequest* slot = GetPrfSlot(ring_index);
|
||||
min_backend_prefetch_lsn = Min(slot->request_lsns.request_lsn, min_backend_prefetch_lsn);
|
||||
}
|
||||
MIN_BACKEND_REQUEST_LSN = min_backend_prefetch_lsn;
|
||||
}
|
||||
communicator_reconfigure_timeout_if_needed();
|
||||
}
|
||||
|
||||
@@ -588,7 +561,7 @@ readahead_buffer_resize(int newsize, void *extra)
|
||||
newPState->ring_last = newsize;
|
||||
newPState->ring_unused = newsize;
|
||||
newPState->ring_receive = newsize;
|
||||
newPState->max_unflushed_shard_no = MyPState->max_unflushed_shard_no;
|
||||
newPState->max_shard_no = MyPState->max_shard_no;
|
||||
memcpy(newPState->shard_bitmap, MyPState->shard_bitmap, sizeof(MyPState->shard_bitmap));
|
||||
|
||||
/*
|
||||
@@ -688,7 +661,6 @@ consume_prefetch_responses(void)
|
||||
{
|
||||
if (MyPState->ring_receive < MyPState->ring_unused)
|
||||
prefetch_wait_for(MyPState->ring_unused - 1);
|
||||
|
||||
/*
|
||||
* We know for sure we're not working on any prefetch pages after
|
||||
* this.
|
||||
@@ -718,7 +690,7 @@ prefetch_cleanup_trailing_unused(void)
|
||||
static bool
|
||||
prefetch_flush_requests(void)
|
||||
{
|
||||
for (shardno_t shard_no = 0; shard_no < MyPState->max_unflushed_shard_no; shard_no++)
|
||||
for (shardno_t shard_no = 0; shard_no < MyPState->max_shard_no; shard_no++)
|
||||
{
|
||||
if (BITMAP_ISSET(MyPState->shard_bitmap, shard_no))
|
||||
{
|
||||
@@ -727,8 +699,7 @@ prefetch_flush_requests(void)
|
||||
BITMAP_CLR(MyPState->shard_bitmap, shard_no);
|
||||
}
|
||||
}
|
||||
MyPState->max_unflushed_shard_no = 0;
|
||||
MyPState->ring_flush = MyPState->ring_unused;
|
||||
MyPState->max_shard_no = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -752,6 +723,7 @@ prefetch_wait_for(uint64 ring_index)
|
||||
{
|
||||
if (!prefetch_flush_requests())
|
||||
return false;
|
||||
MyPState->ring_flush = MyPState->ring_unused;
|
||||
}
|
||||
|
||||
Assert(MyPState->ring_unused > ring_index);
|
||||
@@ -830,7 +802,6 @@ prefetch_read(PrefetchRequest *slot)
|
||||
old = MemoryContextSwitchTo(MyPState->errctx);
|
||||
response = (NeonResponse *) page_server->receive(shard_no);
|
||||
MemoryContextSwitchTo(old);
|
||||
|
||||
if (response)
|
||||
{
|
||||
check_getpage_response(slot, response);
|
||||
@@ -1039,16 +1010,11 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
|
||||
Assert(mySlotNo == MyPState->ring_unused);
|
||||
|
||||
if (force_request_lsns)
|
||||
{
|
||||
slot->request_lsns = *force_request_lsns;
|
||||
}
|
||||
else
|
||||
{
|
||||
neon_get_request_lsns(BufTagGetNRelFileInfo(slot->buftag),
|
||||
slot->buftag.forkNum, slot->buftag.blockNum,
|
||||
&slot->request_lsns, 1);
|
||||
last_replay_lsn = InvalidXLogRecPtr;
|
||||
}
|
||||
request.hdr.lsn = slot->request_lsns.request_lsn;
|
||||
request.hdr.not_modified_since = slot->request_lsns.not_modified_since;
|
||||
|
||||
@@ -1067,7 +1033,7 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
|
||||
MyPState->n_unused -= 1;
|
||||
MyPState->ring_unused += 1;
|
||||
BITMAP_SET(MyPState->shard_bitmap, slot->shard_no);
|
||||
MyPState->max_unflushed_shard_no = Max(slot->shard_no+1, MyPState->max_unflushed_shard_no);
|
||||
MyPState->max_shard_no = Max(slot->shard_no+1, MyPState->max_shard_no);
|
||||
|
||||
/* update slot state */
|
||||
slot->status = PRFS_REQUESTED;
|
||||
@@ -1075,25 +1041,6 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
|
||||
Assert(!found);
|
||||
}
|
||||
|
||||
/*
|
||||
* Check that returned page LSN is consistent with request lsns
|
||||
*/
|
||||
static void
|
||||
check_page_lsn(NeonGetPageResponse* resp)
|
||||
{
|
||||
if (neon_protocol_version < 3) /* no information to check */
|
||||
return;
|
||||
if (PageGetLSN(resp->page) > resp->req.hdr.not_modified_since)
|
||||
neon_log(PANIC, "Invalid getpage response version: %X/%08X is higher than last modified LSN %X/%08X",
|
||||
LSN_FORMAT_ARGS(PageGetLSN(resp->page)),
|
||||
LSN_FORMAT_ARGS(resp->req.hdr.not_modified_since));
|
||||
|
||||
if (PageGetLSN(resp->page) > resp->req.hdr.lsn)
|
||||
neon_log(PANIC, "Invalid getpage response version: %X/%08X is higher than request LSN %X/%08X",
|
||||
LSN_FORMAT_ARGS(PageGetLSN(resp->page)),
|
||||
LSN_FORMAT_ARGS(resp->req.hdr.lsn));
|
||||
}
|
||||
|
||||
/*
|
||||
* Lookup of already received prefetch requests. Only already received responses matching required LSNs are accepted.
|
||||
* Present pages are marked in "mask" bitmap and total number of such pages is returned.
|
||||
@@ -1117,7 +1064,7 @@ communicator_prefetch_lookupv(NRelFileInfo rinfo, ForkNumber forknum, BlockNumbe
|
||||
for (int i = 0; i < nblocks; i++)
|
||||
{
|
||||
PrfHashEntry *entry;
|
||||
NeonGetPageResponse* resp;
|
||||
|
||||
hashkey.buftag.blockNum = blocknum + i;
|
||||
entry = prfh_lookup(MyPState->prf_hash, &hashkey);
|
||||
|
||||
@@ -1150,9 +1097,8 @@ communicator_prefetch_lookupv(NRelFileInfo rinfo, ForkNumber forknum, BlockNumbe
|
||||
continue;
|
||||
}
|
||||
Assert(slot->response->tag == T_NeonGetPageResponse); /* checked by check_getpage_response when response was assigned to the slot */
|
||||
resp = (NeonGetPageResponse*)slot->response;
|
||||
check_page_lsn(resp);
|
||||
memcpy(buffers[i], resp->page, BLCKSZ);
|
||||
memcpy(buffers[i], ((NeonGetPageResponse*)slot->response)->page, BLCKSZ);
|
||||
|
||||
|
||||
/*
|
||||
* With lfc_store_prefetch_result=true prefetch result is stored in LFC in prefetch_pump_state when response is received
|
||||
@@ -1445,6 +1391,7 @@ Retry:
|
||||
*/
|
||||
goto Retry;
|
||||
}
|
||||
MyPState->ring_flush = MyPState->ring_unused;
|
||||
}
|
||||
|
||||
return last_ring_index;
|
||||
@@ -1514,12 +1461,10 @@ page_server_request(void const *req)
|
||||
MyNeonCounters->pageserver_open_requests--;
|
||||
} while (resp == NULL);
|
||||
cancel_before_shmem_exit(prefetch_on_exit, Int32GetDatum(shard_no));
|
||||
last_replay_lsn = InvalidXLogRecPtr;
|
||||
}
|
||||
PG_CATCH();
|
||||
{
|
||||
cancel_before_shmem_exit(prefetch_on_exit, Int32GetDatum(shard_no));
|
||||
last_replay_lsn = InvalidXLogRecPtr;
|
||||
/* Nothing should cancel disconnect: we should not leave connection in opaque state */
|
||||
HOLD_INTERRUPTS();
|
||||
page_server->disconnect(shard_no);
|
||||
@@ -1919,13 +1864,6 @@ nm_to_string(NeonMessage *msg)
|
||||
return s.data;
|
||||
}
|
||||
|
||||
static void
|
||||
reset_min_request_lsn(int code, Datum arg)
|
||||
{
|
||||
if (MyProcNumber != -1)
|
||||
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
|
||||
}
|
||||
|
||||
/*
|
||||
* communicator_init() -- Initialize per-backend private state
|
||||
*/
|
||||
@@ -1937,8 +1875,6 @@ communicator_init(void)
|
||||
if (MyPState != NULL)
|
||||
return;
|
||||
|
||||
before_shmem_exit(reset_min_request_lsn, 0);
|
||||
|
||||
/*
|
||||
* Sanity check that the perf counters array is sized correctly. We got
|
||||
* this wrong once, and the formula for max number of backends and aux
|
||||
@@ -1948,7 +1884,7 @@ communicator_init(void)
|
||||
* the check here. That's OK, we don't expect the logic to change in old
|
||||
* releases.
|
||||
*/
|
||||
#if PG_MAJORVERSION_NUM >= 15
|
||||
#if PG_VERSION_NUM>=150000
|
||||
if (MyNeonCounters >= &neon_per_backend_counters_shared[NUM_NEON_PERF_COUNTER_SLOTS])
|
||||
elog(ERROR, "MyNeonCounters points past end of array");
|
||||
#endif
|
||||
@@ -2287,7 +2223,6 @@ Retry:
|
||||
case T_NeonGetPageResponse:
|
||||
{
|
||||
NeonGetPageResponse* getpage_resp = (NeonGetPageResponse *) resp;
|
||||
check_page_lsn(getpage_resp);
|
||||
memcpy(buffer, getpage_resp->page, BLCKSZ);
|
||||
|
||||
/*
|
||||
@@ -2564,30 +2499,12 @@ communicator_reconfigure_timeout_if_needed(void)
|
||||
!AmPrewarmWorker && /* do not pump prefetch state in prewarm worker */
|
||||
readahead_getpage_pull_timeout_ms > 0;
|
||||
|
||||
if (!needs_set && MIN_BACKEND_REQUEST_LSN != InvalidXLogRecPtr)
|
||||
{
|
||||
if (last_replay_lsn == InvalidXLogRecPtr)
|
||||
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
|
||||
else
|
||||
needs_set = true; /* Can not reset MIN_BACKEND_REQUEST_LSN now, have to do it later */
|
||||
}
|
||||
if (needs_set != timeout_set)
|
||||
{
|
||||
/*
|
||||
* The background writer/checkpointer doesn't (shouldn't) read any pages.
|
||||
* And definitely they should not run on replica.
|
||||
* The only case when we can get here is replica promotion.
|
||||
*/
|
||||
if (AmBackgroundWriterProcess() || AmCheckpointerProcess())
|
||||
{
|
||||
MIN_BACKEND_REQUEST_LSN = InvalidXLogRecPtr;
|
||||
if (timeout_set)
|
||||
{
|
||||
disable_timeout(PS_TIMEOUT_ID, false);
|
||||
timeout_set = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
/* The background writer doesn't (shouldn't) read any pages */
|
||||
Assert(!AmBackgroundWriterProcess());
|
||||
/* The checkpointer doesn't (shouldn't) read any pages */
|
||||
Assert(!AmCheckpointerProcess());
|
||||
|
||||
if (unlikely(PS_TIMEOUT_ID == 0))
|
||||
{
|
||||
@@ -2620,6 +2537,14 @@ communicator_reconfigure_timeout_if_needed(void)
|
||||
static void
|
||||
pagestore_timeout_handler(void)
|
||||
{
|
||||
#if PG_MAJORVERSION_NUM <= 14
|
||||
/*
|
||||
* PG14: Setting a repeating timeout is not possible, so we signal here
|
||||
* that the timeout has already been reset, and by telling the system
|
||||
* that, the system will re-schedule it later if we need to.
|
||||
*/
|
||||
timeout_set = false;
|
||||
#endif
|
||||
timeout_signaled = true;
|
||||
InterruptPending = true;
|
||||
}
|
||||
@@ -2639,14 +2564,6 @@ communicator_processinterrupts(void)
|
||||
if (!readpage_reentrant_guard && readahead_getpage_pull_timeout_ms > 0)
|
||||
communicator_prefetch_pump_state();
|
||||
|
||||
#if PG_MAJORVERSION_NUM <= 14
|
||||
/*
|
||||
* PG14: Setting a repeating timeout is not possible, so we signal here
|
||||
* that the timeout has already been reset, and by telling the system
|
||||
* that, the system will re-schedule it later if we need to.
|
||||
*/
|
||||
timeout_set = false;
|
||||
#endif
|
||||
timeout_signaled = false;
|
||||
communicator_reconfigure_timeout_if_needed();
|
||||
}
|
||||
@@ -2656,28 +2573,3 @@ communicator_processinterrupts(void)
|
||||
|
||||
return prev_interrupt_cb();
|
||||
}
|
||||
|
||||
PG_FUNCTION_INFO_V1(neon_communicator_min_inflight_request_lsn);
|
||||
|
||||
Datum
|
||||
neon_communicator_min_inflight_request_lsn(PG_FUNCTION_ARGS)
|
||||
{
|
||||
if (RecoveryInProgress())
|
||||
{
|
||||
/* Do not hold GC for primary */
|
||||
PG_RETURN_INT64(UINT64_MAX);
|
||||
}
|
||||
else
|
||||
{
|
||||
XLogRecPtr min_lsn = GetXLogReplayRecPtr(NULL);
|
||||
size_t n_procs = ProcGlobal->allProcCount;
|
||||
for (size_t i = 0; i < n_procs; i++)
|
||||
{
|
||||
if (neon_per_backend_counters_shared[i].min_request_lsn != InvalidXLogRecPtr)
|
||||
{
|
||||
min_lsn = Min(min_lsn, neon_per_backend_counters_shared[i].min_request_lsn);
|
||||
}
|
||||
}
|
||||
PG_RETURN_INT64(min_lsn);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
create function neon_communicator_min_inflight_request_lsn() returns pg_catalog.pg_lsn
|
||||
AS 'MODULE_PATHNAME', 'neon_communicator_min_inflight_request_lsn'
|
||||
LANGUAGE C;
|
||||
@@ -1 +0,0 @@
|
||||
drop function neon_communicator_min_inflight_request_lsn();
|
||||
@@ -42,6 +42,7 @@ NeonPerfCountersShmemRequest(void)
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
NeonPerfCountersShmemInit(void)
|
||||
{
|
||||
|
||||
@@ -154,11 +154,6 @@ typedef struct
|
||||
* Histogram of query execution time.
|
||||
*/
|
||||
QTHistogramData query_time_hist;
|
||||
|
||||
/*
|
||||
* Minimal LSN of in-flight requests
|
||||
*/
|
||||
XLogRecPtr min_request_lsn;
|
||||
} neon_per_backend_counters;
|
||||
|
||||
/* Pointer to the shared memory array of neon_per_backend_counters structs */
|
||||
@@ -174,12 +169,6 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared;
|
||||
|
||||
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber])
|
||||
|
||||
/*
|
||||
* Backend-local minimal in-flight request LSN.
|
||||
* We store it in neon_per_backend_counters_shared and not in separate array to minimize false cache sharing
|
||||
*/
|
||||
#define MIN_BACKEND_REQUEST_LSN MyNeonCounters->min_request_lsn
|
||||
|
||||
extern void inc_getpage_wait(uint64 latency);
|
||||
extern void inc_page_cache_read_wait(uint64 latency);
|
||||
extern void inc_page_cache_write_wait(uint64 latency);
|
||||
|
||||
@@ -243,7 +243,6 @@ extern char *neon_timeline;
|
||||
extern char *neon_tenant;
|
||||
extern int32 max_cluster_size;
|
||||
extern int neon_protocol_version;
|
||||
extern XLogRecPtr last_replay_lsn;
|
||||
|
||||
extern shardno_t get_shard_number(BufferTag* tag);
|
||||
|
||||
|
||||
@@ -96,8 +96,6 @@ typedef enum
|
||||
|
||||
int debug_compare_local;
|
||||
|
||||
XLogRecPtr last_replay_lsn;
|
||||
|
||||
static NRelFileInfo unlogged_build_rel_info;
|
||||
static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
|
||||
|
||||
@@ -161,7 +159,7 @@ log_newpages_copy(NRelFileInfo * rinfo, ForkNumber forkNum, BlockNumber blkno,
|
||||
page_std);
|
||||
}
|
||||
|
||||
return GetXLogInsertRecPtr();
|
||||
return ProcLastRecPtr;
|
||||
}
|
||||
#endif /* PG_MAJORVERSION_NUM >= 17 */
|
||||
|
||||
@@ -590,17 +588,6 @@ neon_get_request_lsns(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
|
||||
/* Request the page at the end of the last fully replayed LSN. */
|
||||
XLogRecPtr replay_lsn = GetXLogReplayRecPtr(NULL);
|
||||
|
||||
if (MIN_BACKEND_REQUEST_LSN == InvalidXLogRecPtr)
|
||||
{
|
||||
/* mark the backend's replay_lsn as "we have a request ongoing", blocking the expiration of any current LSN */
|
||||
MIN_BACKEND_REQUEST_LSN = replay_lsn;
|
||||
/* make sure memory operations are in correct order, even in concurrent systems */
|
||||
pg_memory_barrier();
|
||||
/* get the current LSN to register */
|
||||
replay_lsn = GetXLogReplayRecPtr(NULL);
|
||||
MIN_BACKEND_REQUEST_LSN = replay_lsn;
|
||||
}
|
||||
last_replay_lsn = replay_lsn;
|
||||
for (int i = 0; i < nblocks; i++)
|
||||
{
|
||||
neon_request_lsns *result = &output[i];
|
||||
|
||||
@@ -107,6 +107,7 @@ uuid.workspace = true
|
||||
x509-cert.workspace = true
|
||||
redis.workspace = true
|
||||
zerocopy.workspace = true
|
||||
zeroize.workspace = true
|
||||
# uncomment this to use the real subzero-core crate
|
||||
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
|
||||
# this is a stub for the subzero-core crate
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow};
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::sasl;
|
||||
use crate::stream::{self, Stream};
|
||||
|
||||
@@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext(
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
let role = RoleNameInt::from(&info.user);
|
||||
|
||||
let auth_flow = AuthFlow::new(
|
||||
client,
|
||||
auth::CleartextPassword {
|
||||
secret,
|
||||
endpoint: ep,
|
||||
pool: config.thread_pool.clone(),
|
||||
role,
|
||||
pool: config.scram_thread_pool.clone(),
|
||||
},
|
||||
);
|
||||
let auth_outcome = {
|
||||
|
||||
@@ -25,7 +25,7 @@ use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::{
|
||||
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
|
||||
};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||
@@ -273,9 +273,11 @@ async fn authenticate_with_secret(
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
if let Some(password) = unauthenticated_password {
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
let role = RoleNameInt::from(&info.user);
|
||||
|
||||
let auth_outcome =
|
||||
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
|
||||
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
|
||||
.await?;
|
||||
let keys = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => key,
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
@@ -499,7 +501,7 @@ mod tests {
|
||||
|
||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(1),
|
||||
scram_thread_pool: ThreadPool::new(1),
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_vpc_acccess_proxy: false,
|
||||
|
||||
@@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys;
|
||||
use super::{AuthError, PasswordHackPayload};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::sasl;
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
@@ -46,6 +46,7 @@ pub(crate) struct PasswordHack;
|
||||
pub(crate) struct CleartextPassword {
|
||||
pub(crate) pool: Arc<ThreadPool>,
|
||||
pub(crate) endpoint: EndpointIdInt,
|
||||
pub(crate) role: RoleNameInt,
|
||||
pub(crate) secret: AuthSecret,
|
||||
}
|
||||
|
||||
@@ -111,6 +112,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
let outcome = validate_password_and_exchange(
|
||||
&self.state.pool,
|
||||
self.state.endpoint,
|
||||
self.state.role,
|
||||
password,
|
||||
self.state.secret,
|
||||
)
|
||||
@@ -165,13 +167,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
pub(crate) async fn validate_password_and_exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
role: RoleNameInt,
|
||||
password: &[u8],
|
||||
secret: AuthSecret,
|
||||
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
match secret {
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
let outcome =
|
||||
crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
|
||||
@@ -29,7 +29,7 @@ use crate::config::{
|
||||
};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::http::health_server::AppMetrics;
|
||||
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
|
||||
use crate::metrics::{Metrics, ServiceInfo};
|
||||
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
@@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
||||
|
||||
// TODO: refactor these to use labels
|
||||
debug!("Version: {GIT_VERSION}");
|
||||
debug!("Build_tag: {BUILD_TAG}");
|
||||
@@ -284,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
http_config,
|
||||
authentication_config: AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(0),
|
||||
scram_thread_pool: ThreadPool::new(0),
|
||||
scram_protocol_timeout: Duration::from_secs(10),
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_vpc_acccess_proxy: false,
|
||||
|
||||
@@ -26,7 +26,7 @@ use utils::project_git_version;
|
||||
use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ServiceInfo, ThreadPoolMetrics};
|
||||
use crate::metrics::{Metrics, ServiceInfo};
|
||||
use crate::pglb::TlsRequired;
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
@@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
|
||||
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
|
||||
|
||||
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
|
||||
|
||||
let args = cli().get_matches();
|
||||
let destination: String = args
|
||||
.get_one::<String>("dest")
|
||||
|
||||
@@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
/// ProxyConfig is created at proxy startup, and lives forever.
|
||||
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
|
||||
Metrics::install(thread_pool.metrics.clone());
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.scram_pool
|
||||
.0
|
||||
.set(thread_pool.metrics.clone())
|
||||
.ok();
|
||||
|
||||
let tls_config = match (&args.tls_key, &args.tls_cert) {
|
||||
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
|
||||
@@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool,
|
||||
scram_thread_pool: thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_vpc_acccess_proxy: args.is_private_access_proxy,
|
||||
|
||||
@@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
|
||||
use crate::ext::TaskExt;
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::scram;
|
||||
use crate::serverless::GlobalConnPoolOptions;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
#[cfg(feature = "rest_broker")]
|
||||
@@ -75,7 +75,7 @@ pub struct HttpConfig {
|
||||
}
|
||||
|
||||
pub struct AuthenticationConfig {
|
||||
pub thread_pool: Arc<ThreadPool>,
|
||||
pub scram_thread_pool: Arc<scram::threadpool::ThreadPool>,
|
||||
pub scram_protocol_timeout: tokio::time::Duration,
|
||||
pub ip_allowlist_check_enabled: bool,
|
||||
pub is_vpc_acccess_proxy: bool,
|
||||
|
||||
@@ -5,6 +5,7 @@ use measured::label::{
|
||||
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
|
||||
StaticLabelSet,
|
||||
};
|
||||
use measured::metric::group::Encoding;
|
||||
use measured::metric::histogram::Thresholds;
|
||||
use measured::metric::name::MetricName;
|
||||
use measured::{
|
||||
@@ -18,10 +19,10 @@ use crate::control_plane::messages::ColdStartInfo;
|
||||
use crate::error::ErrorKind;
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
#[metric(new())]
|
||||
pub struct Metrics {
|
||||
#[metric(namespace = "proxy")]
|
||||
#[metric(init = ProxyMetrics::new(thread_pool))]
|
||||
#[metric(init = ProxyMetrics::new())]
|
||||
pub proxy: ProxyMetrics,
|
||||
|
||||
#[metric(namespace = "wake_compute_lock")]
|
||||
@@ -34,34 +35,27 @@ pub struct Metrics {
|
||||
pub cache: CacheMetrics,
|
||||
}
|
||||
|
||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||
impl Metrics {
|
||||
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
|
||||
let mut metrics = Metrics::new(thread_pool);
|
||||
|
||||
metrics.proxy.errors_total.init_all_dense();
|
||||
metrics.proxy.redis_errors_total.init_all_dense();
|
||||
metrics.proxy.redis_events_count.init_all_dense();
|
||||
metrics.proxy.retries_metric.init_all_dense();
|
||||
metrics.proxy.connection_failures_total.init_all_dense();
|
||||
|
||||
SELF.set(metrics)
|
||||
.ok()
|
||||
.expect("proxy metrics must not be installed more than once");
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn get() -> &'static Self {
|
||||
#[cfg(test)]
|
||||
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
|
||||
static SELF: OnceLock<Metrics> = OnceLock::new();
|
||||
|
||||
#[cfg(not(test))]
|
||||
SELF.get()
|
||||
.expect("proxy metrics must be installed by the main() function")
|
||||
SELF.get_or_init(|| {
|
||||
let mut metrics = Metrics::new();
|
||||
|
||||
metrics.proxy.errors_total.init_all_dense();
|
||||
metrics.proxy.redis_errors_total.init_all_dense();
|
||||
metrics.proxy.redis_events_count.init_all_dense();
|
||||
metrics.proxy.retries_metric.init_all_dense();
|
||||
metrics.proxy.connection_failures_total.init_all_dense();
|
||||
|
||||
metrics
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
|
||||
#[metric(new())]
|
||||
pub struct ProxyMetrics {
|
||||
#[metric(flatten)]
|
||||
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
|
||||
@@ -134,6 +128,9 @@ pub struct ProxyMetrics {
|
||||
/// Number of TLS handshake failures
|
||||
pub tls_handshake_failures: Counter,
|
||||
|
||||
/// Number of SHA 256 rounds executed.
|
||||
pub sha_rounds: Counter,
|
||||
|
||||
/// HLL approximate cardinality of endpoints that are connecting
|
||||
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
|
||||
|
||||
@@ -151,8 +148,25 @@ pub struct ProxyMetrics {
|
||||
pub connect_compute_lock: ApiLockMetrics,
|
||||
|
||||
#[metric(namespace = "scram_pool")]
|
||||
#[metric(init = thread_pool)]
|
||||
pub scram_pool: Arc<ThreadPoolMetrics>,
|
||||
pub scram_pool: OnceLockWrapper<Arc<ThreadPoolMetrics>>,
|
||||
}
|
||||
|
||||
/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`].
|
||||
pub struct OnceLockWrapper<T>(pub OnceLock<T>);
|
||||
|
||||
impl<T> Default for OnceLockWrapper<T> {
|
||||
fn default() -> Self {
|
||||
Self(OnceLock::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Enc: Encoding, T: MetricGroup<Enc>> MetricGroup<Enc> for OnceLockWrapper<T> {
|
||||
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
|
||||
if let Some(inner) = self.0.get() {
|
||||
inner.collect_group_into(enc)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
@@ -553,14 +567,6 @@ impl From<bool> for Bool {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = InvalidEndpointsSet)]
|
||||
pub struct InvalidEndpointsGroup {
|
||||
pub protocol: Protocol,
|
||||
pub rejected: Bool,
|
||||
pub outcome: ConnectOutcome,
|
||||
}
|
||||
|
||||
#[derive(LabelGroup)]
|
||||
#[label(set = RetriesMetricSet)]
|
||||
pub struct RetriesMetricGroup {
|
||||
@@ -727,6 +733,7 @@ pub enum CacheKind {
|
||||
ProjectInfoEndpoints,
|
||||
ProjectInfoRoles,
|
||||
Schema,
|
||||
Pbkdf2,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
|
||||
|
||||
84
proxy/src/scram/cache.rs
Normal file
84
proxy/src/scram/cache.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use tokio::time::Instant;
|
||||
use zeroize::Zeroize as _;
|
||||
|
||||
use super::pbkdf2;
|
||||
use crate::cache::Cached;
|
||||
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::metrics::{CacheKind, Metrics};
|
||||
|
||||
pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>);
|
||||
pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>;
|
||||
|
||||
impl Cache for Pbkdf2Cache {
|
||||
type Key = (EndpointIdInt, RoleNameInt);
|
||||
type Value = Pbkdf2CacheEntry;
|
||||
|
||||
fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) {
|
||||
self.0.invalidate(info);
|
||||
}
|
||||
}
|
||||
|
||||
/// To speed up password hashing for more active customers, we store the tail results of the
|
||||
/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store
|
||||
/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16
|
||||
/// to determine the final result.
|
||||
///
|
||||
/// The suffix alone isn't enough to crack the password. The stored_key is still required.
|
||||
/// While both are cached in memory, given they're in different locations it makes it much
|
||||
/// harder to exploit, even if any such memory exploit exists in proxy.
|
||||
#[derive(Clone)]
|
||||
pub struct Pbkdf2CacheEntry {
|
||||
/// corresponds to [`super::ServerSecret::cached_at`]
|
||||
pub(super) cached_from: Instant,
|
||||
pub(super) suffix: pbkdf2::Block,
|
||||
}
|
||||
|
||||
impl Drop for Pbkdf2CacheEntry {
|
||||
fn drop(&mut self) {
|
||||
self.suffix.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl Pbkdf2Cache {
|
||||
pub fn new() -> Self {
|
||||
const SIZE: u64 = 100;
|
||||
const TTL: std::time::Duration = std::time::Duration::from_secs(60);
|
||||
|
||||
let builder = moka::sync::Cache::builder()
|
||||
.name("pbkdf2")
|
||||
.max_capacity(SIZE)
|
||||
// We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
|
||||
.time_to_live(TTL);
|
||||
|
||||
Metrics::get()
|
||||
.cache
|
||||
.capacity
|
||||
.set(CacheKind::Pbkdf2, SIZE as i64);
|
||||
|
||||
let builder =
|
||||
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause));
|
||||
|
||||
Self(builder.build())
|
||||
}
|
||||
|
||||
pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) {
|
||||
count_cache_insert(CacheKind::Pbkdf2);
|
||||
self.0.insert((endpoint, role), value);
|
||||
}
|
||||
|
||||
fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option<Pbkdf2CacheEntry> {
|
||||
count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role)))
|
||||
}
|
||||
|
||||
pub fn get_entry(
|
||||
&self,
|
||||
endpoint: EndpointIdInt,
|
||||
role: RoleNameInt,
|
||||
) -> Option<CachedPbkdf2<'_>> {
|
||||
self.get(endpoint, role).map(|value| Cached {
|
||||
token: Some((self, (endpoint, role))),
|
||||
value,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,10 +4,8 @@ use std::convert::Infallible;
|
||||
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use super::ScramKey;
|
||||
use super::messages::{
|
||||
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
|
||||
};
|
||||
@@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2;
|
||||
use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use super::threadpool::ThreadPool;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use super::{ScramKey, pbkdf2};
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
use crate::scram::cache::Pbkdf2CacheEntry;
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
#[derive(Debug)]
|
||||
@@ -77,46 +77,113 @@ impl<'a> Exchange<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
|
||||
async fn derive_client_key(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
password: &[u8],
|
||||
salt: &[u8],
|
||||
iterations: u32,
|
||||
) -> ScramKey {
|
||||
let salted_password = pool
|
||||
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
||||
.await;
|
||||
|
||||
let make_key = |name| {
|
||||
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
|
||||
.expect("HMAC is able to accept all key sizes")
|
||||
.chain_update(name)
|
||||
.finalize();
|
||||
|
||||
<[u8; 32]>::from(key.into_bytes())
|
||||
};
|
||||
|
||||
make_key(b"Client Key").into()
|
||||
) -> pbkdf2::Block {
|
||||
pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
||||
.await
|
||||
}
|
||||
|
||||
/// For cleartext flow, we need to derive the client key to
|
||||
/// 1. authenticate the client.
|
||||
/// 2. authenticate with compute.
|
||||
pub(crate) async fn exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
role: RoleNameInt,
|
||||
secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
if secret.iterations > CACHED_ROUNDS {
|
||||
exchange_with_cache(pool, endpoint, role, secret, password).await
|
||||
} else {
|
||||
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
|
||||
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
Ok(validate_pbkdf2(secret, &hash))
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
|
||||
/// which is not enough by itself to perform an offline brute force.
|
||||
async fn exchange_with_cache(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
role: RoleNameInt,
|
||||
secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
|
||||
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
|
||||
debug_assert!(
|
||||
secret.iterations > CACHED_ROUNDS,
|
||||
"we should not cache password data if there isn't enough rounds needed"
|
||||
);
|
||||
|
||||
// compute the prefix of the pbkdf2 output.
|
||||
let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
|
||||
|
||||
if let Some(entry) = pool.cache.get_entry(endpoint, role) {
|
||||
// hot path: let's check the threadpool cache
|
||||
if secret.cached_at == entry.cached_from {
|
||||
// cache is valid. compute the full hash by adding the prefix to the suffix.
|
||||
let mut hash = prefix;
|
||||
pbkdf2::xor_assign(&mut hash, &entry.suffix);
|
||||
let outcome = validate_pbkdf2(secret, &hash);
|
||||
|
||||
if matches!(outcome, sasl::Outcome::Success(_)) {
|
||||
trace!("password validated from cache");
|
||||
}
|
||||
|
||||
return Ok(outcome);
|
||||
}
|
||||
|
||||
// cached key is no longer valid.
|
||||
debug!("invalidating cached password");
|
||||
entry.invalidate();
|
||||
}
|
||||
|
||||
// slow path: full password hash.
|
||||
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
let outcome = validate_pbkdf2(secret, &hash);
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
sasl::Outcome::Failure(_) => return Ok(outcome),
|
||||
};
|
||||
|
||||
trace!("storing cached password");
|
||||
|
||||
// time to cache, compute the suffix by subtracting the prefix from the hash.
|
||||
let mut suffix = hash;
|
||||
pbkdf2::xor_assign(&mut suffix, &prefix);
|
||||
|
||||
pool.cache.insert(
|
||||
endpoint,
|
||||
role,
|
||||
Pbkdf2CacheEntry {
|
||||
cached_from: secret.cached_at,
|
||||
suffix,
|
||||
},
|
||||
);
|
||||
|
||||
Ok(sasl::Outcome::Success(client_key))
|
||||
}
|
||||
|
||||
fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
|
||||
let client_key = super::ScramKey::client_key(&(*hash).into());
|
||||
if secret.is_password_invalid(&client_key).into() {
|
||||
Ok(sasl::Outcome::Failure("password doesn't match"))
|
||||
sasl::Outcome::Failure("password doesn't match")
|
||||
} else {
|
||||
Ok(sasl::Outcome::Success(client_key))
|
||||
sasl::Outcome::Success(client_key)
|
||||
}
|
||||
}
|
||||
|
||||
const CACHED_ROUNDS: u32 = 16;
|
||||
|
||||
impl SaslInitial {
|
||||
fn transition(
|
||||
&self,
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
//! Tools for client/server/stored key management.
|
||||
|
||||
use hmac::Mac as _;
|
||||
use sha2::Digest as _;
|
||||
use subtle::ConstantTimeEq;
|
||||
use zeroize::Zeroize as _;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
use crate::scram::pbkdf2::Prf;
|
||||
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub(crate) const SCRAM_KEY_LEN: usize = 32;
|
||||
@@ -14,6 +20,12 @@ pub(crate) struct ScramKey {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
}
|
||||
|
||||
impl Drop for ScramKey {
|
||||
fn drop(&mut self) {
|
||||
self.bytes.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ScramKey {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.ct_eq(other).into()
|
||||
@@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey {
|
||||
|
||||
impl ScramKey {
|
||||
pub(crate) fn sha256(&self) -> Self {
|
||||
super::sha256([self.as_ref()]).into()
|
||||
Metrics::get().proxy.sha_rounds.inc_by(1);
|
||||
Self {
|
||||
bytes: sha2::Sha256::digest(self.as_bytes()).into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
|
||||
self.bytes
|
||||
}
|
||||
|
||||
pub(crate) fn client_key(b: &[u8; 32]) -> Self {
|
||||
// Prf::new_from_slice will run 2 sha256 rounds.
|
||||
// Update + Finalize run 2 sha256 rounds.
|
||||
Metrics::get().proxy.sha_rounds.inc_by(4);
|
||||
|
||||
let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes");
|
||||
prf.update(b"Client Key");
|
||||
let client_key: [u8; 32] = prf.finalize().into_bytes().into();
|
||||
client_key.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
||||
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
||||
|
||||
mod cache;
|
||||
mod countmin;
|
||||
mod exchange;
|
||||
mod key;
|
||||
@@ -18,10 +19,8 @@ pub mod threadpool;
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
pub(crate) use exchange::{Exchange, exchange};
|
||||
use hmac::{Hmac, Mac};
|
||||
pub(crate) use key::ScramKey;
|
||||
pub(crate) use secret::ServerSecret;
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
@@ -42,29 +41,13 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
|
||||
Some(bytes)
|
||||
}
|
||||
|
||||
/// This function essentially is `Hmac(sha256, key, input)`.
|
||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
|
||||
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
|
||||
parts.into_iter().for_each(|s| mac.update(s));
|
||||
|
||||
mac.finalize().into_bytes().into()
|
||||
}
|
||||
|
||||
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
let mut hasher = Sha256::new();
|
||||
parts.into_iter().for_each(|s| hasher.update(s));
|
||||
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::threadpool::ThreadPool;
|
||||
use super::{Exchange, ServerSecret};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::sasl::{Mechanism, Step};
|
||||
use crate::types::EndpointId;
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
|
||||
#[test]
|
||||
fn snapshot() {
|
||||
@@ -114,23 +97,34 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
||||
let pool = ThreadPool::new(1);
|
||||
|
||||
async fn check(
|
||||
pool: &ThreadPool,
|
||||
scram_secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> Result<(), &'static str> {
|
||||
let ep = EndpointId::from("foo");
|
||||
let ep = EndpointIdInt::from(ep);
|
||||
let role = RoleName::from("user");
|
||||
let role = RoleNameInt::from(&role);
|
||||
|
||||
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
||||
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
|
||||
let outcome = super::exchange(pool, ep, role, scram_secret, password)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match outcome {
|
||||
crate::sasl::Outcome::Success(_) => {}
|
||||
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
|
||||
crate::sasl::Outcome::Success(_) => Ok(()),
|
||||
crate::sasl::Outcome::Failure(r) => Err(r),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
||||
let pool = ThreadPool::new(1);
|
||||
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
||||
check(&pool, &scram_secret, client_password.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn round_trip() {
|
||||
run_round_trip_test("pencil", "pencil").await;
|
||||
@@ -141,4 +135,27 @@ mod tests {
|
||||
async fn failure() {
|
||||
run_round_trip_test("pencil", "eraser").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tracing_test::traced_test]
|
||||
async fn password_cache() {
|
||||
let pool = ThreadPool::new(1);
|
||||
let scram_secret = ServerSecret::build("password").await.unwrap();
|
||||
|
||||
// wrong passwords are not added to cache
|
||||
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
|
||||
assert!(!logs_contain("storing cached password"));
|
||||
|
||||
// correct passwords get cached
|
||||
check(&pool, &scram_secret, b"password").await.unwrap();
|
||||
assert!(logs_contain("storing cached password"));
|
||||
|
||||
// wrong passwords do not match the cache
|
||||
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
|
||||
assert!(!logs_contain("password validated from cache"));
|
||||
|
||||
// correct passwords match the cache
|
||||
check(&pool, &scram_secret, b"password").await.unwrap();
|
||||
assert!(logs_contain("password validated from cache"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,50 @@
|
||||
//! For postgres password authentication, we need to perform a PBKDF2 using
|
||||
//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
|
||||
|
||||
use hmac::Mac as _;
|
||||
use hmac::digest::consts::U32;
|
||||
use hmac::digest::generic_array::GenericArray;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use zeroize::Zeroize as _;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
/// The Pseudo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
|
||||
pub type Prf = hmac::Hmac<sha2::Sha256>;
|
||||
pub(crate) type Block = GenericArray<u8, U32>;
|
||||
|
||||
pub(crate) struct Pbkdf2 {
|
||||
hmac: Hmac<Sha256>,
|
||||
prev: GenericArray<u8, U32>,
|
||||
hi: GenericArray<u8, U32>,
|
||||
hmac: Prf,
|
||||
/// U{r-1} for whatever iteration r we are currently on.
|
||||
prev: Block,
|
||||
/// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
|
||||
hi: Block,
|
||||
/// number of iterations left
|
||||
iterations: u32,
|
||||
}
|
||||
|
||||
impl Drop for Pbkdf2 {
|
||||
fn drop(&mut self) {
|
||||
self.prev.zeroize();
|
||||
self.hi.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
|
||||
impl Pbkdf2 {
|
||||
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
|
||||
pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
|
||||
// key the HMAC and derive the first block in-place
|
||||
let mut hmac =
|
||||
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
|
||||
|
||||
// U1 = PRF(Password, Salt + INT_32_BE(i))
|
||||
// i = 1 since we only need 1 block of output.
|
||||
hmac.update(salt);
|
||||
hmac.update(&1u32.to_be_bytes());
|
||||
let init_block = hmac.finalize_reset().into_bytes();
|
||||
|
||||
// Prf::new_from_slice will run 2 sha256 rounds.
|
||||
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
|
||||
Metrics::get().proxy.sha_rounds.inc_by(4);
|
||||
|
||||
Self {
|
||||
hmac,
|
||||
// one iteration spent above
|
||||
@@ -33,7 +58,11 @@ impl Pbkdf2 {
|
||||
(self.iterations).clamp(0, 4096)
|
||||
}
|
||||
|
||||
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
|
||||
/// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
|
||||
/// function that only executes a fixed number of iterations before continuing.
|
||||
///
|
||||
/// Task must be rescheduled if this returns [`std::task::Poll::Pending`].
|
||||
pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
|
||||
let Self {
|
||||
hmac,
|
||||
prev,
|
||||
@@ -44,25 +73,37 @@ impl Pbkdf2 {
|
||||
// only do up to 4096 iterations per turn for fairness
|
||||
let n = (*iterations).clamp(0, 4096);
|
||||
for _ in 0..n {
|
||||
hmac.update(prev);
|
||||
let block = hmac.finalize_reset().into_bytes();
|
||||
|
||||
for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
|
||||
*hi_byte ^= b;
|
||||
}
|
||||
|
||||
*prev = block;
|
||||
let next = single_round(hmac, prev);
|
||||
xor_assign(hi, &next);
|
||||
*prev = next;
|
||||
}
|
||||
|
||||
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
|
||||
Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
|
||||
|
||||
*iterations -= n;
|
||||
if *iterations == 0 {
|
||||
std::task::Poll::Ready((*hi).into())
|
||||
std::task::Poll::Ready(*hi)
|
||||
} else {
|
||||
std::task::Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn xor_assign(x: &mut Block, y: &Block) {
|
||||
for (x, &y) in std::iter::zip(x, y) {
|
||||
*x ^= y;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn single_round(prf: &mut Prf, ui: &Block) -> Block {
|
||||
// Ui = PRF(Password, Ui-1)
|
||||
prf.update(ui);
|
||||
prf.finalize_reset().into_bytes()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pbkdf2::pbkdf2_hmac_array;
|
||||
@@ -76,11 +117,11 @@ mod tests {
|
||||
let pass = b"Ne0n_!5_50_C007";
|
||||
|
||||
let mut job = Pbkdf2::start(pass, salt, 60000);
|
||||
let hash = loop {
|
||||
let hash: [u8; 32] = loop {
|
||||
let std::task::Poll::Ready(hash) = job.turn() else {
|
||||
continue;
|
||||
};
|
||||
break hash;
|
||||
break hash.into();
|
||||
};
|
||||
|
||||
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use base64::Engine as _;
|
||||
use base64::prelude::BASE64_STANDARD;
|
||||
use subtle::{Choice, ConstantTimeEq};
|
||||
use tokio::time::Instant;
|
||||
|
||||
use super::base64_decode_array;
|
||||
use super::key::ScramKey;
|
||||
@@ -11,6 +12,9 @@ use super::key::ScramKey;
|
||||
/// and is used throughout the authentication process.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub(crate) struct ServerSecret {
|
||||
/// When this secret was cached.
|
||||
pub(crate) cached_at: Instant,
|
||||
|
||||
/// Number of iterations for `PBKDF2` function.
|
||||
pub(crate) iterations: u32,
|
||||
/// Salt used to hash user's password.
|
||||
@@ -34,6 +38,7 @@ impl ServerSecret {
|
||||
params.split_once(':').zip(keys.split_once(':'))?;
|
||||
|
||||
let secret = ServerSecret {
|
||||
cached_at: Instant::now(),
|
||||
iterations: iterations.parse().ok()?,
|
||||
salt_base64: salt.into(),
|
||||
stored_key: base64_decode_array(stored_key)?.into(),
|
||||
@@ -54,6 +59,7 @@ impl ServerSecret {
|
||||
/// See `auth-scram.c : mock_scram_secret` for details.
|
||||
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
|
||||
Self {
|
||||
cached_at: Instant::now(),
|
||||
// this doesn't reveal much information as we're going to use
|
||||
// iteration count 1 for our generated passwords going forward.
|
||||
// PG16 users can set iteration count=1 already today.
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
//! Tools for client/server signature management.
|
||||
|
||||
use hmac::Mac as _;
|
||||
|
||||
use super::key::{SCRAM_KEY_LEN, ScramKey};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::scram::pbkdf2::Prf;
|
||||
|
||||
/// A collection of message parts needed to derive the client's signature.
|
||||
#[derive(Debug)]
|
||||
@@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> {
|
||||
|
||||
impl SignatureBuilder<'_> {
|
||||
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
|
||||
let parts = [
|
||||
self.client_first_message_bare.as_bytes(),
|
||||
b",",
|
||||
self.server_first_message.as_bytes(),
|
||||
b",",
|
||||
self.client_final_message_without_proof.as_bytes(),
|
||||
];
|
||||
// don't know exactly. this is a rough approx
|
||||
Metrics::get().proxy.sha_rounds.inc_by(8);
|
||||
|
||||
super::hmac_sha256(key.as_ref(), parts).into()
|
||||
let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes");
|
||||
mac.update(self.client_first_message_bare.as_bytes());
|
||||
mac.update(b",");
|
||||
mac.update(self.server_first_message.as_bytes());
|
||||
mac.update(b",");
|
||||
mac.update(self.client_final_message_without_proof.as_bytes());
|
||||
Signature {
|
||||
bytes: mac.finalize().into_bytes().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ use futures::FutureExt;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use super::cache::Pbkdf2Cache;
|
||||
use super::pbkdf2;
|
||||
use super::pbkdf2::Pbkdf2;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
|
||||
@@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch;
|
||||
pub struct ThreadPool {
|
||||
runtime: Option<tokio::runtime::Runtime>,
|
||||
pub metrics: Arc<ThreadPoolMetrics>,
|
||||
|
||||
// we hash a lot of passwords.
|
||||
// we keep a cache of partial hashes for faster validation.
|
||||
pub(super) cache: Pbkdf2Cache,
|
||||
}
|
||||
|
||||
/// How often to reset the sketch values
|
||||
@@ -68,6 +74,7 @@ impl ThreadPool {
|
||||
Self {
|
||||
runtime: Some(runtime),
|
||||
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
|
||||
cache: Pbkdf2Cache::new(),
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -130,7 +137,7 @@ struct JobSpec {
|
||||
}
|
||||
|
||||
impl Future for JobSpec {
|
||||
type Output = [u8; 32];
|
||||
type Output = pbkdf2::Block;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
STATE.with_borrow_mut(|state| {
|
||||
@@ -166,10 +173,10 @@ impl Future for JobSpec {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
|
||||
pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);
|
||||
|
||||
impl Future for JobHandle {
|
||||
type Output = [u8; 32];
|
||||
type Output = pbkdf2::Block;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.0.poll_unpin(cx) {
|
||||
@@ -203,10 +210,10 @@ mod tests {
|
||||
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
|
||||
.await;
|
||||
|
||||
let expected = [
|
||||
let expected = &[
|
||||
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
|
||||
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
|
||||
];
|
||||
assert_eq!(actual, expected);
|
||||
assert_eq!(actual.as_slice(), expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::intern::{EndpointIdInt, RoleNameInt};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::{connect_auth, connect_compute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
@@ -76,9 +76,11 @@ impl PoolingBackend {
|
||||
};
|
||||
|
||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||
let role = RoleNameInt::from(&user_info.user);
|
||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||
&self.config.authentication_config.thread_pool,
|
||||
&self.config.authentication_config.scram_thread_pool,
|
||||
ep,
|
||||
role,
|
||||
password,
|
||||
secret,
|
||||
)
|
||||
|
||||
@@ -102,7 +102,7 @@ pub struct ReportedError {
|
||||
}
|
||||
|
||||
impl ReportedError {
|
||||
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
|
||||
pub fn new(e: impl UserFacingError + Into<anyhow::Error>) -> Self {
|
||||
let error_kind = e.get_error_kind();
|
||||
Self {
|
||||
source: e.into(),
|
||||
|
||||
@@ -55,7 +55,7 @@ def test_neon_extension_compatibility(neon_env_builder: NeonEnvBuilder):
|
||||
# Ensure that the default version is also updated in the neon.control file
|
||||
assert cur.fetchone() == ("1.6",)
|
||||
cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE")
|
||||
all_versions = ["1.7", "1.6", "1.5", "1.4", "1.3", "1.2", "1.1", "1.0"]
|
||||
all_versions = ["1.6", "1.5", "1.4", "1.3", "1.2", "1.1", "1.0"]
|
||||
current_version = "1.6"
|
||||
for idx, begin_version in enumerate(all_versions):
|
||||
for target_version in all_versions[idx + 1 :]:
|
||||
|
||||
Reference in New Issue
Block a user