Compare commits

..

64 Commits

Author SHA1 Message Date
Ruslan Talpa
ee7d8e4512 revert pg-14 submodule changes 2025-07-04 13:55:35 +03:00
Ruslan Talpa
6549708b44 change subzero dep sha 2025-07-04 13:39:49 +03:00
Ruslan Talpa
45631bf2e5 add line to remove from diff 2025-07-04 13:28:52 +03:00
Ruslan Talpa
5dbca8c756 revert changes from original hack branch 2025-07-04 13:27:59 +03:00
Ruslan Talpa
9f46ca5eb1 Merge branch 'main' into ruslan/subzero-integration 2025-07-04 13:03:55 +03:00
Ruslan Talpa
54da030a2d place the entire rest_broker code under a feature flag 2025-07-04 12:46:48 +03:00
Ruslan Talpa
afa4e48071 put subzero dependency under a feature flag 2025-07-04 11:27:36 +03:00
Alex Chi Z.
cc699f6f85 fix(pageserver): do not log no-route-to-host errors (#12468)
## Problem

close https://github.com/neondatabase/neon/issues/12344

## Summary of changes

Add `HostUnreachable` and `NetworkUnreachable` to expected I/O error.
This was new in Rust 1.83.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-03 21:57:42 +00:00
Konstantin Knizhnik
495112ca50 Add GUC for dynamically enable compare local mode (#12424)
## Problem

DEBUG_LOCAL_COMPARE mode allows to detect data corruption.
But it requires rebuild of neon extension (and so requires special
image) and significantly slowdown execution because always fetch pages
from page server.

## Summary of changes

Introduce new GUC `neon.debug_compare_local`, accepting the following
values: " none", "prefetch", "lfc", "all" (by default it is definitely
disabled).
In mode less than "all", neon SMGR will not fetch page from PS if it is
found in local caches.

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2025-07-03 17:37:05 +00:00
Suhas Thalanki
46158ee63f fix(compute): background installed extensions worker would collect data without waiting for interval (#12465)
## Problem

The background installed extensions worker relied on `interval.tick()`
to go to sleep for a period of time. This can lead to bugs due to the
interval being updated at the end of the loop as the first tick is
[instantaneous](https://docs.rs/tokio/latest/tokio/time/struct.Interval.html#method.tick).

## Summary of changes

Changed it to a `tokio::time::sleep` to prevent this issue. Now it puts
the thread to sleep and only wakes up after the specified duration
2025-07-03 17:10:30 +00:00
Alex Chi Z.
305fe61ac1 fix(pageserver): also print open layer size in backpressure (#12440)
## Problem

Better investigate memory usage during backpressure

## Summary of changes

Print open layer size if backpressure is activated

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-03 16:37:11 +00:00
Vlad Lazar
f95fdf5b44 pageserver: fix duplicate tombstones in ancestor detach (#12460)
## Problem

Ancestor detach from a previously detached parent when there were no
writes panics since it tries to upload the tombstone layer twice.

## Summary of Changes

If we're gonna copy the tombstone from the ancestor, don't bother
creating it.

Fixes https://github.com/neondatabase/neon/issues/12458
2025-07-03 16:35:46 +00:00
Arpad Müller
a852bc5e39 Add new activating scheduling policy for safekeepers (#12441)
When deploying new safekeepers, we don't immediately want to send
traffic to them. Maybe they are not ready yet by the time the deploy
script is registering them with the storage controller.

For pageservers, the storcon solves the problem by not scheduling stuff
to them unless there has been a positive heartbeat response. We can't do
the same for safekeepers though, otherwise a single down safekeeper
would mean we can't create new timelines in smaller regions where there
is only three safekeepers in total.

So far we have created safekeepers as `pause` but this adds a manual
step to safekeeper deployment which is prone to oversight. We want
things to be automatted. So we introduce a new state `activating` that
acts just like `pause`, except that we automatically transition the
policy to `active` once we get a positive heartbeat from the safekeeper.
For `pause`, we always keep the safekeeper paused.
2025-07-03 16:27:43 +00:00
Aleksandr Sarantsev
b96983a31c storcon: Ignore keep-failing reconciles (#12391)
## Problem

Currently, if `storcon` (storage controller) reconciliations repeatedly
fail, the system will indefinitely freeze optimizations. This can result
in optimization starvation for several days until the reconciliation
issues are manually resolved. To mitigate this, we should detect
persistently failing reconciliations and exclude them from influencing
the optimization decision.

## Summary of Changes

- A tenant shard reconciliation is now considered "keep-failing" if it
fails 5 consecutive times. These failures are excluded from the
optimization readiness check.
- Added a new metric: `storage_controller_keep_failing_reconciles` to
monitor such cases.
- Added a warning log message when a reconciliation is marked as
"keep-failing".

---------

Co-authored-by: Aleksandr Sarantsev <aleksandr.sarantsev@databricks.com>
2025-07-03 16:21:36 +00:00
Dmitrii Kovalkov
3ed28661b1 storcon: remote feature testing safekeeper quorum checks (#12459)
## Problem
Previous PR didn't fix the creation of timeline in neon_local with <3
safekeepers because there is one more check down the stack.

- Closes: https://github.com/neondatabase/neon/issues/12298
- Follow up on https://github.com/neondatabase/neon/pull/12378

## Summary of changes
- Remove feature `testing` safekeeper quorum checks from storcon

---------

Co-authored-by: Arpad Müller <arpad-m@users.noreply.github.com>
2025-07-03 15:02:30 +00:00
Conrad Ludgate
03e604e432 Nightly lints and small tweaks (#12456)
Let chains available in 1.88 :D new clippy lints coming up in future
releases.
2025-07-03 14:47:12 +00:00
HaoyuHuang
4db934407a SK changes #1 (#12448)
## TLDR
This PR is a no-op. The changes are disabled by default. 

## Problem
I. Currently we don't have a way to detect disk I/O failures from WAL
operations.

II.
We observe that the offloader fails to upload a segment due to race
conditions on XLOG SWITCH and PG start streaming WALs. wal_backup task
continously failing to upload a full segment while the segment remains
partial on the disk.

The consequence is that commit_lsn for all SKs move forward but
backup_lsn stays the same. Then, all SKs run out of disk space.

III.
We have discovered SK bugs where the WAL offload owner cannot keep up
with WAL backup/upload to S3, which results in an unbounded accumulation
of WAL segment files on the Safekeeper's disk until the disk becomes
full. This is a somewhat dangerous operation that is hard to recover
from because the Safekeeper cannot write its control files when it is
out of disk space. There are actually 2 problems here:

1. A single problematic timeline can take over the entire disk for the
SK
2. Once out of disk, it's difficult to recover SK


IV. 
Neon reports certain storage errors as "critical" errors using a marco,
which will increment a counter/metric that can be used to raise alerts.
However, this metric isn't sliced by tenant and/or timeline today. We
need the tenant/timeline dimension to better respond to incidents and
for blast radius analysis.

## Summary of changes
I. 
The PR adds a `safekeeper_wal_disk_io_errors ` which is incremented when
SK fails to create or flush WALs.

II. 
To mitigate this issue, we will re-elect a new offloader if the current
offloader is lagging behind too much.
Each SK makes the decision locally but they are aware of each other's
commit and backup lsns.

The new algorithm is
- determine_offloader will pick a SK. say SK-1.
- Each SK checks
-- if commit_lsn - back_lsn > threshold,
-- -- remove SK-1 from the candidate and call determine_offloader again.

SK-1 will step down and all SKs will elect the same leader again.
After the backup is caught up, the leader will become SK-1 again.

This also helps when SK-1 is slow to backup. 

I'll set the reelect backup lag to 4 GB later. Setting to 128 MB in dev
to trigger the code more frequently.

III. 
This change addresses problem no. 1 by having the Safekeeper perform a
timeline disk utilization check check when processing WAL proposal
messages from Postgres/compute. The Safekeeper now rejects the WAL
proposal message, effectively stops writing more WAL for the timeline to
disk, if the existing WAL files for the timeline on the SK disk exceeds
a certain size (the default threshold is 100GB). The disk utilization is
calculated based on a `last_removed_segno` variable tracked by the
background task removing WAL files, which produces an accurate and
conservative estimate (>= than actual disk usage) of the actual disk
usage.


IV.
* Add a new metric `hadron_critical_storage_event_count` that has the
`tenant_shard_id` and `timeline_id` as dimensions.
* Modified the `crtitical!` marco to include tenant_id and timeline_id
as additional arguments and adapted existing call sites to populate the
tenant shard and timeline ID fields. The `critical!` marco invocation
now increments the `hadron_critical_storage_event_count` with the extra
dimensions. (In SK there isn't the notion of a tenant-shard, so just the
tenant ID is recorded in lieu of tenant shard ID.)

I considered adding a separate marco to avoid merge conflicts, but I
think in this case (detecting critical errors) conflicts are probably
more desirable so that we can be aware whenever Neon adds another
`critical!` invocation in their code.

---------

Co-authored-by: Chen Luo <chen.luo@databricks.com>
Co-authored-by: Haoyu Huang <haoyu.huang@databricks.com>
Co-authored-by: William Huang <william.huang@databricks.com>
2025-07-03 14:32:53 +00:00
Ruslan Talpa
b54872a4dc fix error after merging latest master 2025-07-03 17:24:13 +03:00
Ruslan Talpa
486829f875 Merge branch 'main' into ruslan/subzero-integration 2025-07-03 17:10:43 +03:00
Ruslan Talpa
95e1011cd6 subzero pre-integration refactor (#12416)
## Problem
integrating subzero requires a bit of refactoring. To make the
integration PR a bit more manageable, the refactoring is done in this
separate PR.
 
## Summary of changes
* move common types/functions used in sql_over_http to errors.rs and
http_util.rs
* add the "Local" auth backend to proxy (similar to local_proxy), useful
in local testing
* change the Connect and Send type for the http client to allow for
custom body when making post requests to local_proxy from the proxy

---------

Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
2025-07-03 11:04:08 +00:00
Conrad Ludgate
1bc1eae5e8 fix redis credentials check (#12455)
## Problem

`keep_connection` does not exit, so it was never setting
`credentials_refreshed`.

## Summary of changes

Set `credentials_refreshed` to true when we first establish a
connection, and after we re-authenticate the connection.
2025-07-03 09:51:35 +00:00
Matthias van de Meent
e12d4f356a Work around Clap's incorrect usage of Display for default_value_t (#12454)
## Problem

#12450 

## Summary of changes

Instead of `#[arg(default_value_t = typed_default_value)]`, we use
`#[arg(default_value = "str that deserializes into the value")]`,
because apparently you can't convince clap to _not_ deserialize from the
Display implementation of an imported enum.
2025-07-03 09:41:09 +00:00
Folke Behrens
3415b90e88 proxy/logging: Add "ep" and "query_id" to list of extracted fields (#12437)
Extract two more interesting fields from spans: ep (endpoint) and
query_id.
Useful for reliable filtering in logging.
2025-07-03 08:09:10 +00:00
Conrad Ludgate
e01c8f238c [proxy] update noisy error logging (#12438)
Health checks for pg-sni-router open a TCP connection and immediately
close it again. This is noisy. We will filter out any EOF errors on the
first message.

"acquired permit" debug log is incorrect since it logs when we timedout
as well. This fixes the debug log.
2025-07-03 07:46:48 +00:00
Conrad Ludgate
45607cbe0c [local_proxy]: ignore TLS for endpoint (#12316)
## Problem

When local proxy is configured with TLS, the certificate does not match
the endpoint string. This currently returns an error.

## Summary of changes

I don't think this code is necessary anymore, taking the prefix from the
hostname is good enough (and is equivalent to what `endpoint_sni` was
doing) and we ignore checking the domain suffix.
2025-07-03 07:35:57 +00:00
Tristan Partin
8b4fbefc29 Patch pgaudit to disable logging in parallel workers (#12325)
We want to turn logging in parallel workers off to reduce log
amplification in queries which use parallel workers.

Part-of: https://github.com/neondatabase/cloud/issues/28483

Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
2025-07-02 19:54:47 +00:00
Alex Chi Z.
a9a51c038b rfc: storage feature flags (#11805)
## Problem

Part of https://github.com/neondatabase/neon/issues/11813

## Summary of changes

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-02 17:41:36 +00:00
Alexey Kondratov
44121cc175 docs(compute): RFC for compute rolling restart with prewarm (#11294)
## Problem

Neon currently implements several features that guarantee high uptime of
compute nodes:

1. Storage high-availability (HA), i.e. each tenant shard has a
secondary pageserver location, so we can quickly switch over compute to
it in case of primary pageserver failure.
2. Fast compute provisioning, i.e. we have a fleet of pre-created empty
computes, that are ready to serve workload, so restarting unresponsive
compute is very fast.
3. Preemptive NeonVM compute provisioning in case of k8s node
unavailability.

This helps us to be well-within the uptime SLO of 99.95% most of the
time. Problems begin when we go up to multi-TB workloads and 32-64 CU
computes. During restart, compute looses all caches: LFC, shared
buffers, file system cache. Depending on the workload, it can take a lot
of time to warm up the caches, so that performance could be degraded and
might be even unacceptable for certain workloads. The latter means that
although current approach works well for small to
medium workloads, we still have to do some additional work to avoid
performance degradation after restart of large instances.

[Rendered
version](https://github.com/neondatabase/neon/blob/alexk/pg-prewarm-rfc/docs/rfcs/2025-03-17-compute-prewarm.md)

Part of https://github.com/neondatabase/cloud/issues/19011
2025-07-02 17:16:00 +00:00
Dmitry Savelev
0429a0db16 Switch the billing metrics storage format to ndjson. (#12427)
## Problem
The billing team wants to change the billing events pipeline and use a
common events format in S3 buckets across different event producers.

## Summary of changes
Change the events storage format for billing events from JSON to NDJSON.
Also partition files by hours, rather than days.

Resolves: https://github.com/neondatabase/cloud/issues/29995
2025-07-02 16:30:47 +00:00
Conrad Ludgate
d6beb3ffbb [proxy] rewrite pg-text to json routines (#12413)
We would like to move towards an arena system for JSON encoding the
responses. This change pushes an "out" parameter into the pg-test to
json routines to make swapping in an arena system easier in the future.
(see #11992)

This additionally removes the redundant `column: &[Type]` argument, as
well as rewriting the pg_array parser.

---

I rewrote the pg_array parser since while making these changes I found
it hard to reason about. I went back to the specification and rewrote it
from scratch. There's 4 separate routines:
1. pg_array_parse - checks for any prelude (multidimensional array
ranges)
2. pg_array_parse_inner - only deals with the arrays themselves
3. pg_array_parse_item - parses a single item from the array, this might
be quoted, unquoted, or another nested array.
4. pg_array_parse_quoted - parses a quoted string, following the
relevant string escaping rules.
2025-07-02 12:46:11 +00:00
Arpad Müller
efd7e52812 Don't error if timeline offload is already in progress (#12428)
Don't print errors like:
```
Compaction failed 1 times, retrying in 2s: Failed to offload timeline: Unexpected offload error: Timeline deletion is already in progress
```

Print it at info log level instead.

https://github.com/neondatabase/cloud/issues/30666
2025-07-02 12:06:55 +00:00
Ivan Efremov
0f879a2e8f [proxy]: Fix redis IRSA expiration failure errors (#12430)
Relates to the
[#30688](https://github.com/neondatabase/cloud/issues/30688)
2025-07-02 08:55:44 +00:00
Dmitrii Kovalkov
8e7ce42229 tests: start primary compute on not-readonly branches (#12408)
## Problem

https://github.com/neondatabase/neon/pull/11712 changed how computes are
started in the test: the lsn is specified, making them read-only static
replicas. Lsn is `last_record_lsn` from pageserver. It works fine with
read-only branches (because their `last_record_lsn` is equal to
`start_lsn` and always valid). But with writable timelines, the
`last_record_lsn` on the pageserver might be stale.

Particularly in this test, after the `detach_branch` operation, the
tenant is reset on the pagesever. It leads to `last_record_lsn` going
back to `disk_consistent_lsn`, so basically rolling back some recent
writes.

If we start a primary compute, it will start at safekeepers' commit Lsn,
which is the correct one , and will wait till pageserver catches up with
this Lsn after reset.

- Closes: https://github.com/neondatabase/neon/issues/12365

## Summary of changes
- Start `primary` compute for writable timelines.
2025-07-02 05:41:17 +00:00
Ruslan Talpa
4775aa3e01 Merge branch 'main' into ruslan/subzero-integration 2025-07-01 13:46:57 +03:00
Ruslan Talpa
1785f856b6 Move the local auth backend under the "testing" feature 2025-06-30 16:52:45 +03:00
Ruslan Talpa
69b22b05da add in readme the way to run auth/rest broker locally 2025-06-30 16:17:35 +03:00
Ruslan Talpa
bf0007fa96 add note about local confir read code 2025-06-30 16:12:04 +03:00
Ruslan Talpa
a9bbe7b00b remove unused imports 2025-06-30 16:02:30 +03:00
Ruslan Talpa
7e3f64b309 implement local auth backend for proxy and remove control plane hacks 2025-06-30 16:00:43 +03:00
Ruslan Talpa
9480d17de7 fix bug in pickcurrent_chema 2025-06-30 12:53:47 +03:00
Ruslan Talpa
424004ec95 apply cargo fmt 2025-06-30 12:32:47 +03:00
Ruslan Talpa
88d1a78260 cleanup the rest path code 2025-06-30 12:30:33 +03:00
Ruslan Talpa
8e544c7f99 import introspection queries instead of loading from files 2025-06-27 14:40:02 +03:00
Ruslan Talpa
4f49fc5b79 move common error types and http realted functions to error.rs and http_util.rs 2025-06-27 13:37:29 +03:00
Ruslan Talpa
5461039c3f implement remote config fetch from the db and cache introspected schema 2025-06-27 10:20:26 +03:00
Ruslan Talpa
d6c36d103e subzero integration WIP6
beginning work on introspection of config and schema shape from the database
2025-06-26 14:42:31 +03:00
Ruslan Talpa
fbb2416685 pg 14 vendor commit changed 2025-06-26 10:25:54 +03:00
Ruslan Talpa
8072fae2fe Merge branch 'main' into ruslan/subzero-integration 2025-06-26 10:19:16 +03:00
Ruslan Talpa
3869d680f9 use a global parsed/cached schema 2025-06-26 10:14:28 +03:00
Ruslan Talpa
d3fa228d92 move subzero local test files to a "dot" folder 2025-06-25 16:43:50 +03:00
Ruslan Talpa
be6a259b85 subzero integration WIP5
cleanup and postprocess the response and set the correct headers/status and handle errors
2025-06-25 14:33:45 +03:00
Ruslan Talpa
af3ca24a5e remove unused enum values 2025-06-25 14:32:46 +03:00
Ruslan Talpa
8b44f5b479 subzero integration WIP5
extract the response body from the local proxy response
2025-06-24 17:03:42 +03:00
Ruslan Talpa
d1445cf3eb subzero integration WIP4
queries generated by subzero reach database and execute succesfully
2025-06-24 15:33:51 +03:00
Ruslan Talpa
67d3026fc4 subzero integration WIP3
* query makes it to the database
2025-06-23 11:48:55 +03:00
Ruslan Talpa
09e62e9b98 subzero integration WIP2 2025-06-23 10:11:06 +03:00
Ruslan Talpa
e121da4bfc subzero integration WIP1 2025-06-20 15:10:45 +03:00
Ruslan Talpa
4a948c9781 add note about ICU lib missing on macs and the fix 2025-06-20 10:58:38 +03:00
Ruslan Talpa
b39f04ab99 add missing parts to make disable_pg_session_jwt flag work 2025-06-20 10:23:42 +03:00
Ruslan Talpa
6bd15908fb make pg_session_jwt instalation optional with a cli flag 2025-06-20 10:17:32 +03:00
Ruslan Talpa
3e36d516c2 vanilla pg dokcer image setup 2025-06-20 09:37:39 +03:00
Conrad Ludgate
cc3af6f7dd code for local setup of auth-broker 2025-06-20 09:37:39 +03:00
Conrad Ludgate
5badc7a3fb code for local setup of auth-broker 2025-06-19 10:34:09 +01:00
Conrad Ludgate
3a73644308 use cargo-chef for compute-tools 2025-06-19 09:24:53 +01:00
92 changed files with 6771 additions and 2513 deletions

5
.gitignore vendored
View File

@@ -25,6 +25,11 @@ compaction-suite-results.*
*.o
*.so
*.Po
*.pid
# pgindent typedef lists
*.list
# various files for local testing
/proxy/.subzero
local_proxy.json

3241
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1572,6 +1572,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) && \
FROM build-deps AS pgaudit-src
ARG PG_VERSION
WORKDIR /ext-src
COPY "compute/patches/pgaudit-parallel_workers-${PG_VERSION}.patch" .
RUN case "${PG_VERSION}" in \
"v14") \
export PGAUDIT_VERSION=1.6.3 \
@@ -1594,7 +1595,8 @@ RUN case "${PG_VERSION}" in \
esac && \
wget https://github.com/pgaudit/pgaudit/archive/refs/tags/${PGAUDIT_VERSION}.tar.gz -O pgaudit.tar.gz && \
echo "${PGAUDIT_CHECKSUM} pgaudit.tar.gz" | sha256sum --check && \
mkdir pgaudit-src && cd pgaudit-src && tar xzf ../pgaudit.tar.gz --strip-components=1 -C .
mkdir pgaudit-src && cd pgaudit-src && tar xzf ../pgaudit.tar.gz --strip-components=1 -C . && \
patch -p1 < "/ext-src/pgaudit-parallel_workers-${PG_VERSION}.patch"
FROM pg-build AS pgaudit-build
COPY --from=pgaudit-src /ext-src/ /ext-src/

View File

@@ -0,0 +1,143 @@
commit 7220bb3a3f23fa27207d77562dcc286f9a123313
Author: Tristan Partin <tristan.partin@databricks.com>
Date: 2025-06-23 02:09:31 +0000
Disable logging in parallel workers
When a query uses parallel workers, pgaudit will log the same query for
every parallel worker. This is undesireable since it can result in log
amplification for queries that use parallel workers.
Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
diff --git a/expected/pgaudit.out b/expected/pgaudit.out
index baa8011..a601375 100644
--- a/expected/pgaudit.out
+++ b/expected/pgaudit.out
@@ -2563,6 +2563,37 @@ COMMIT;
NOTICE: AUDIT: SESSION,12,4,MISC,COMMIT,,,COMMIT;,<not logged>
DROP TABLE part_test;
NOTICE: AUDIT: SESSION,13,1,DDL,DROP TABLE,,,DROP TABLE part_test;,<not logged>
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+SELECT count(*) FROM parallel_test;
+NOTICE: AUDIT: SESSION,14,1,READ,SELECT,,,SELECT count(*) FROM parallel_test;,<not logged>
+ count
+-------
+ 1000
+(1 row)
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';
diff --git a/pgaudit.c b/pgaudit.c
index 5e6fd38..ac9ded2 100644
--- a/pgaudit.c
+++ b/pgaudit.c
@@ -11,6 +11,7 @@
#include "postgres.h"
#include "access/htup_details.h"
+#include "access/parallel.h"
#include "access/sysattr.h"
#include "access/xact.h"
#include "access/relation.h"
@@ -1303,7 +1304,7 @@ pgaudit_ExecutorStart_hook(QueryDesc *queryDesc, int eflags)
{
AuditEventStackItem *stackItem = NULL;
- if (!internalStatement)
+ if (!internalStatement && !IsParallelWorker())
{
/* Push the audit even onto the stack */
stackItem = stack_push();
@@ -1384,7 +1385,7 @@ pgaudit_ExecutorCheckPerms_hook(List *rangeTabls, bool abort)
/* Log DML if the audit role is valid or session logging is enabled */
if ((auditOid != InvalidOid || auditLogBitmap != 0) &&
- !IsAbortedTransactionBlockState())
+ !IsAbortedTransactionBlockState() && !IsParallelWorker())
{
/* If auditLogRows is on, wait for rows processed to be set */
if (auditLogRows && auditEventStack != NULL)
@@ -1438,7 +1439,7 @@ pgaudit_ExecutorRun_hook(QueryDesc *queryDesc, ScanDirection direction, uint64 c
else
standard_ExecutorRun(queryDesc, direction, count, execute_once);
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
@@ -1458,7 +1459,7 @@ pgaudit_ExecutorEnd_hook(QueryDesc *queryDesc)
AuditEventStackItem *stackItem = NULL;
AuditEventStackItem *auditEventStackFull = NULL;
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
diff --git a/sql/pgaudit.sql b/sql/pgaudit.sql
index cc1374a..1870a60 100644
--- a/sql/pgaudit.sql
+++ b/sql/pgaudit.sql
@@ -1612,6 +1612,36 @@ COMMIT;
DROP TABLE part_test;
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+
+SELECT count(*) FROM parallel_test;
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
+
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';

View File

@@ -0,0 +1,143 @@
commit 29dc2847f6255541992f18faf8a815dfab79631a
Author: Tristan Partin <tristan.partin@databricks.com>
Date: 2025-06-23 02:09:31 +0000
Disable logging in parallel workers
When a query uses parallel workers, pgaudit will log the same query for
every parallel worker. This is undesireable since it can result in log
amplification for queries that use parallel workers.
Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
diff --git a/expected/pgaudit.out b/expected/pgaudit.out
index b22560b..73f0327 100644
--- a/expected/pgaudit.out
+++ b/expected/pgaudit.out
@@ -2563,6 +2563,37 @@ COMMIT;
NOTICE: AUDIT: SESSION,12,4,MISC,COMMIT,,,COMMIT;,<not logged>
DROP TABLE part_test;
NOTICE: AUDIT: SESSION,13,1,DDL,DROP TABLE,,,DROP TABLE part_test;,<not logged>
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+SELECT count(*) FROM parallel_test;
+NOTICE: AUDIT: SESSION,14,1,READ,SELECT,,,SELECT count(*) FROM parallel_test;,<not logged>
+ count
+-------
+ 1000
+(1 row)
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';
diff --git a/pgaudit.c b/pgaudit.c
index 5e6fd38..ac9ded2 100644
--- a/pgaudit.c
+++ b/pgaudit.c
@@ -11,6 +11,7 @@
#include "postgres.h"
#include "access/htup_details.h"
+#include "access/parallel.h"
#include "access/sysattr.h"
#include "access/xact.h"
#include "access/relation.h"
@@ -1303,7 +1304,7 @@ pgaudit_ExecutorStart_hook(QueryDesc *queryDesc, int eflags)
{
AuditEventStackItem *stackItem = NULL;
- if (!internalStatement)
+ if (!internalStatement && !IsParallelWorker())
{
/* Push the audit even onto the stack */
stackItem = stack_push();
@@ -1384,7 +1385,7 @@ pgaudit_ExecutorCheckPerms_hook(List *rangeTabls, bool abort)
/* Log DML if the audit role is valid or session logging is enabled */
if ((auditOid != InvalidOid || auditLogBitmap != 0) &&
- !IsAbortedTransactionBlockState())
+ !IsAbortedTransactionBlockState() && !IsParallelWorker())
{
/* If auditLogRows is on, wait for rows processed to be set */
if (auditLogRows && auditEventStack != NULL)
@@ -1438,7 +1439,7 @@ pgaudit_ExecutorRun_hook(QueryDesc *queryDesc, ScanDirection direction, uint64 c
else
standard_ExecutorRun(queryDesc, direction, count, execute_once);
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
@@ -1458,7 +1459,7 @@ pgaudit_ExecutorEnd_hook(QueryDesc *queryDesc)
AuditEventStackItem *stackItem = NULL;
AuditEventStackItem *auditEventStackFull = NULL;
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
diff --git a/sql/pgaudit.sql b/sql/pgaudit.sql
index 8052426..7f0667b 100644
--- a/sql/pgaudit.sql
+++ b/sql/pgaudit.sql
@@ -1612,6 +1612,36 @@ COMMIT;
DROP TABLE part_test;
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+
+SELECT count(*) FROM parallel_test;
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
+
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';

View File

@@ -0,0 +1,143 @@
commit cc708dde7ef2af2a8120d757102d2e34c0463a0f
Author: Tristan Partin <tristan.partin@databricks.com>
Date: 2025-06-23 02:09:31 +0000
Disable logging in parallel workers
When a query uses parallel workers, pgaudit will log the same query for
every parallel worker. This is undesireable since it can result in log
amplification for queries that use parallel workers.
Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
diff --git a/expected/pgaudit.out b/expected/pgaudit.out
index 8772054..9b66ac6 100644
--- a/expected/pgaudit.out
+++ b/expected/pgaudit.out
@@ -2556,6 +2556,37 @@ DROP SERVER fdw_server;
NOTICE: AUDIT: SESSION,11,1,DDL,DROP SERVER,,,DROP SERVER fdw_server;,<not logged>
DROP EXTENSION postgres_fdw;
NOTICE: AUDIT: SESSION,12,1,DDL,DROP EXTENSION,,,DROP EXTENSION postgres_fdw;,<not logged>
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+SELECT count(*) FROM parallel_test;
+NOTICE: AUDIT: SESSION,13,1,READ,SELECT,,,SELECT count(*) FROM parallel_test;,<not logged>
+ count
+-------
+ 1000
+(1 row)
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';
diff --git a/pgaudit.c b/pgaudit.c
index 004d1f9..f061164 100644
--- a/pgaudit.c
+++ b/pgaudit.c
@@ -11,6 +11,7 @@
#include "postgres.h"
#include "access/htup_details.h"
+#include "access/parallel.h"
#include "access/sysattr.h"
#include "access/xact.h"
#include "access/relation.h"
@@ -1339,7 +1340,7 @@ pgaudit_ExecutorStart_hook(QueryDesc *queryDesc, int eflags)
{
AuditEventStackItem *stackItem = NULL;
- if (!internalStatement)
+ if (!internalStatement && !IsParallelWorker())
{
/* Push the audit even onto the stack */
stackItem = stack_push();
@@ -1420,7 +1421,7 @@ pgaudit_ExecutorCheckPerms_hook(List *rangeTabls, List *permInfos, bool abort)
/* Log DML if the audit role is valid or session logging is enabled */
if ((auditOid != InvalidOid || auditLogBitmap != 0) &&
- !IsAbortedTransactionBlockState())
+ !IsAbortedTransactionBlockState() && !IsParallelWorker())
{
/* If auditLogRows is on, wait for rows processed to be set */
if (auditLogRows && auditEventStack != NULL)
@@ -1475,7 +1476,7 @@ pgaudit_ExecutorRun_hook(QueryDesc *queryDesc, ScanDirection direction, uint64 c
else
standard_ExecutorRun(queryDesc, direction, count, execute_once);
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
@@ -1495,7 +1496,7 @@ pgaudit_ExecutorEnd_hook(QueryDesc *queryDesc)
AuditEventStackItem *stackItem = NULL;
AuditEventStackItem *auditEventStackFull = NULL;
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
diff --git a/sql/pgaudit.sql b/sql/pgaudit.sql
index 6aae88b..de6d7fd 100644
--- a/sql/pgaudit.sql
+++ b/sql/pgaudit.sql
@@ -1631,6 +1631,36 @@ DROP USER MAPPING FOR regress_user1 SERVER fdw_server;
DROP SERVER fdw_server;
DROP EXTENSION postgres_fdw;
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+
+SELECT count(*) FROM parallel_test;
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
+
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';

View File

@@ -0,0 +1,143 @@
commit 8d02e4c6c5e1e8676251b0717a46054267091cb4
Author: Tristan Partin <tristan.partin@databricks.com>
Date: 2025-06-23 02:09:31 +0000
Disable logging in parallel workers
When a query uses parallel workers, pgaudit will log the same query for
every parallel worker. This is undesireable since it can result in log
amplification for queries that use parallel workers.
Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
diff --git a/expected/pgaudit.out b/expected/pgaudit.out
index d696287..4b1059a 100644
--- a/expected/pgaudit.out
+++ b/expected/pgaudit.out
@@ -2568,6 +2568,37 @@ DROP SERVER fdw_server;
NOTICE: AUDIT: SESSION,11,1,DDL,DROP SERVER,,,DROP SERVER fdw_server,<not logged>
DROP EXTENSION postgres_fdw;
NOTICE: AUDIT: SESSION,12,1,DDL,DROP EXTENSION,,,DROP EXTENSION postgres_fdw,<not logged>
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+SELECT count(*) FROM parallel_test;
+NOTICE: AUDIT: SESSION,13,1,READ,SELECT,,,SELECT count(*) FROM parallel_test,<not logged>
+ count
+-------
+ 1000
+(1 row)
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';
diff --git a/pgaudit.c b/pgaudit.c
index 1764af1..0e48875 100644
--- a/pgaudit.c
+++ b/pgaudit.c
@@ -11,6 +11,7 @@
#include "postgres.h"
#include "access/htup_details.h"
+#include "access/parallel.h"
#include "access/sysattr.h"
#include "access/xact.h"
#include "access/relation.h"
@@ -1406,7 +1407,7 @@ pgaudit_ExecutorStart_hook(QueryDesc *queryDesc, int eflags)
{
AuditEventStackItem *stackItem = NULL;
- if (!internalStatement)
+ if (!internalStatement && !IsParallelWorker())
{
/* Push the audit event onto the stack */
stackItem = stack_push();
@@ -1489,7 +1490,7 @@ pgaudit_ExecutorCheckPerms_hook(List *rangeTabls, List *permInfos, bool abort)
/* Log DML if the audit role is valid or session logging is enabled */
if ((auditOid != InvalidOid || auditLogBitmap != 0) &&
- !IsAbortedTransactionBlockState())
+ !IsAbortedTransactionBlockState() && !IsParallelWorker())
{
/* If auditLogRows is on, wait for rows processed to be set */
if (auditLogRows && auditEventStack != NULL)
@@ -1544,7 +1545,7 @@ pgaudit_ExecutorRun_hook(QueryDesc *queryDesc, ScanDirection direction, uint64 c
else
standard_ExecutorRun(queryDesc, direction, count, execute_once);
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
@@ -1564,7 +1565,7 @@ pgaudit_ExecutorEnd_hook(QueryDesc *queryDesc)
AuditEventStackItem *stackItem = NULL;
AuditEventStackItem *auditEventStackFull = NULL;
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
diff --git a/sql/pgaudit.sql b/sql/pgaudit.sql
index e161f01..c873098 100644
--- a/sql/pgaudit.sql
+++ b/sql/pgaudit.sql
@@ -1637,6 +1637,36 @@ DROP USER MAPPING FOR regress_user1 SERVER fdw_server;
DROP SERVER fdw_server;
DROP EXTENSION postgres_fdw;
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+
+SELECT count(*) FROM parallel_test;
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
+
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';

View File

@@ -2371,24 +2371,23 @@ LIMIT 100",
installed_extensions_collection_interval
);
let handle = tokio::spawn(async move {
// An initial sleep is added to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
));
loop {
interval.tick().await;
info!(
"[NEON_EXT_INT_SLEEP]: Interval: {}",
installed_extensions_collection_interval
);
// Sleep at the start of the loop to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let _ = installed_extensions(conf.clone()).await;
// Acquire a read lock on the compute spec and then update the interval if necessary
interval = tokio::time::interval(tokio::time::Duration::from_secs(std::cmp::max(
installed_extensions_collection_interval = std::cmp::max(
installed_extensions_collection_interval,
2 * atomic_interval.load(std::sync::atomic::Ordering::SeqCst),
)));
installed_extensions_collection_interval = interval.period().as_secs();
);
}
});

View File

@@ -64,7 +64,9 @@ const DEFAULT_PAGESERVER_ID: NodeId = NodeId(1);
const DEFAULT_BRANCH_NAME: &str = "main";
project_git_version!(GIT_VERSION);
#[allow(dead_code)]
const DEFAULT_PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
const DEFAULT_PG_VERSION_NUM: &str = "17";
const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/";
@@ -167,7 +169,7 @@ struct TenantCreateCmdArgs {
#[clap(short = 'c')]
config: Vec<String>,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version to use for the initial timeline")]
pg_version: PgMajorVersion,
@@ -290,7 +292,7 @@ struct TimelineCreateCmdArgs {
#[clap(long, help = "Human-readable alias for the new timeline")]
branch_name: String,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version")]
pg_version: PgMajorVersion,
}
@@ -322,7 +324,7 @@ struct TimelineImportCmdArgs {
#[clap(long, help = "Lsn the basebackup ends at")]
end_lsn: Option<Lsn>,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version of the backup being imported")]
pg_version: PgMajorVersion,
}
@@ -601,7 +603,7 @@ struct EndpointCreateCmdArgs {
)]
config_only: bool,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version")]
pg_version: PgMajorVersion,

View File

@@ -0,0 +1,179 @@
# Storage Feature Flags
In this RFC, we will describe how we will implement per-tenant feature flags.
## PostHog as Feature Flag Service
Before we start, let's talk about how current feature flag services work. PostHog is the feature flag service we are currently using across multiple user-facing components in the company. PostHog has two modes of operation: HTTP evaluation and server-side local evaluation.
Let's assume we have a storage feature flag called gc-compaction and we want to roll it out to scale-tier users with resident size >= 10GB and <= 100GB.
### Define User Profiles
The first step is to synchronize our user profiles to the PostHog service. We can simply assume that each tenant is a user in PostHog. Each user profile has some properties associated with it. In our case, it will be: plan type (free, scale, enterprise, etc); resident size (in bytes); primary pageserver (string); region (string).
### Define Feature Flags
We would create a feature flag called gc-compaction in PostHog with 4 variants: disabled, stage-1, stage-2, fully-enabled. We will flip the feature flags from disabled to fully-enabled stage by stage for some percentage of our users.
### Option 1: HTTP Evaluation Mode
When using PostHog's HTTP evaluation mode, the client will make request to the PostHog service, asking for the value of a feature flag for a specific user.
* Control plane will report the plan type to PostHog each time it attaches a tenant to the storcon or when the user upgrades/downgrades. It calls the PostHog profile API to associate tenant ID with the plan type. Assume we have X active tenants and such attach or plan change event happens each week, that would be 4X profile update requests per month.
* Pageservers will report the resident size and the primary pageserver to the PostHog service. Assume we report resident size every 24 hours, that would be 30X requests per month.
* Each tenant will request the state of the feature flag every 1 hour, that's 720X requests per month.
* The Rust client would be easy to implement as we only need to call the `/decide` API on PostHog.
Using the HTTP evaluation mode we will issue 754X requests a month.
### Option 2: Local Evaluation Mode
When using PostHog's HTTP evaluation mode, the client (usually the server in a browser/server architecture) will poll the feature flag configuration every 30s (default in the Python client) from PostHog. Such configuration contains data like:
<details>
<summary>Example JSON response from the PostHog local evaluation API</summary>
```
[
{
"id": 1,
"name": "Beta Feature",
"key": "person-flag",
"is_simple_flag": True,
"active": True,
"filters": {
"groups": [
{
"properties": [
{
"key": "location",
"operator": "exact",
"value": ["Straße"],
"type": "person",
}
],
"rollout_percentage": 100,
},
{
"properties": [
{
"key": "star",
"operator": "exact",
"value": ["ſun"],
"type": "person",
}
],
"rollout_percentage": 100,
},
],
},
}
]
```
</details>
Note that the API only contains information like "under what condition => rollout percentage". The user is responsible to provide the properties required to the client for local evaluation, and the PostHog service (web UI) cannot know if a feature is enabled for the tenant or not until the client uses the `capture` API to report the result back. To control the rollout percentage, the user ID gets mapped to a float number in `[0, 1)` on a consistent hash ring. All values <= the percentage will get the feature enabled or set to the desired value.
To use the local evaluation mode, the system needs:
* Assume each pageserver will poll PostHog for the local evaluation JSON every 5 minutes (instead of the 30s default as it's too frequent). That's 8640Y per month, Y is the number of pageservers. Local evaluation requests cost 10x more than the normal decide request, so that's 86400Y request units to bill.
* Storcon needs to store the plan type in the database and pass that information to the pageserver when attaching the tenant.
* Storcon also needs to update PostHog with the active tenants, for example, when the tenant gets detached/attached. Assume each active tenant gets detached/attached every week, that would be 4X requests per month.
* We do not need to update bill type or resident size to PostHog as all these are evaluated locally.
* After each local evaluation of the feature flag, we need to call PostHog's capture event API to update the result of the evaluation that the feature is enabled. We can do this when the flag gets changed compared with the last cached state in memory. That would be at least 4X (assume we do deployment every week so the cache gets cleared) and maybe an additional multiplifier of 10 assume we have 10 active features.
In this case, we will issue 86400Y + 40X requests per month.
Assume X = 1,000,000 and Y = 100,
| | HTTP Evaluation | Local Evaluation |
|---|---|---|
| Latency of propagating the conditions/properties for feature flag | 24 hours | available locally |
| Latency of applying the feature flag | 1 hour | 5 minutes |
| Can properties be reported from different services | Yes | No |
| Do we need to sync billing info etc to pageserver | No | Yes |
| Cost | 75400$ / month | 4864$ / month |
# Our Solution
We will use PostHog _only_ as an UI to configure the feature flags. Whether a feature is enabled or not can only be queried through storcon/pageserver instead of using the PostHog UI. (We could report it back to PostHog via `capture_event` but it costs $$$.) This allows us to ramp up the feature flag functionality fast at first. At the same time, it would also give us the option to migrate to our own solution once we want to have more properties and more complex evaluation rules in our system.
* We will create several fake users (tenants) in PostHog that contains all the properties we will use for evaluating a feature flag (i.e., resident size, billing type, pageserver id, etc.)
* We will use PostHog's local evaluation API to poll the configuration of the feature flags and evaluate them locally on each of the pageserver.
* The evaluation result will not be reported back to PostHog.
* Storcon needs to pull some information from cplane database.
* To know if a feature is currently enabled or not, we need to call the storcon/pageserver API; and we won't be able to know if a feature has been enabled on a tenant before easily: we need to look at the Grafana logs.
We only need to pay for the 86400Y local evaluation requests (that would be setting Y=0 in solution 2 => $864/month, and even less if we proxy it through storcon).
## Implementation
* Pageserver: implement a PostHog local evaluation client. The client will be shared across all tenants on the pageserver with a single API: `evaluate(tenant_id, feature_flag, properties) -> json`.
* Storcon: if we need plan type as the evaluation condition, pull it from cplane database.
* Storcon/Pageserver: implement an HTTP API `:tenant_id/feature/:feature` to retrieve the current feature flag status.
* Storcon/Pageserver: a loop to update the feature flag spec on both storcon and pageserver. Pageserver loop will only be activated if storcon does not push the specs to the pageserver.
## Difference from Tenant Config
* Feature flags can be modified by percentage, and the default config for each feature flag can be modified in UI without going through the release process.
* Feature flags are more flexible and won't be persisted anywhere and will be passed as plain JSON over the wire so that do not need to handle backward/forward compatibility as in tenant config.
* The expectation of tenant config is that once we add a flag we cannot remove it (or it will be hard to remove), but feature flags are more flexible.
# Final Implementation
* We added a new crate `posthog_lite_client` that supports local feature evaluations.
* We set up two projects "Storage (staging)" and "Storage (production)" in the PostHog console.
* Each pageserver reports 10 fake tenants to PostHog so that we can get all combinations of regions (and other properties) in the PostHog UI.
* Supported properties: AZ, neon_region, pageserver, tenant_id.
* You may use "Pageserver Feature Flags" dashboard to see the evaluation status.
* The feature flag spec is polled on storcon every 30s (in each of the region) and storcon will propagate the spec to the pageservers.
* The pageserver housekeeping loop updates the tenant-specific properties (e.g., remote size) for evaluation.
Each tenant has a `feature_resolver` object. After you add a feature flag in the PostHog console, you can retrieve it with:
```rust
// Boolean flag
self
.feature_resolver
.evaluate_boolean("flag")
.is_ok()
// Multivariate flag
self
.feature_resolver
.evaluate_multivariate("gc-comapction-strategy")
.ok();
```
The user needs to handle the case where the evaluation result is an error. This can occur in a variety of cases:
* During the pageserver start, the feature flag spec has not been retrieved.
* No condition group is matched.
* The feature flag spec contains an operand/operation not supported by the lite PostHog library.
For boolean flags, the return value is `Result<(), Error>`. `Ok(())` means the flag is evaluated to true. Otherwise,
there is either an error in evaluation or it does not match any groups.
For multivariate flags, the return value is `Result<String, Error>`. `Ok(variant)` indicates the flag is evaluated
to a variant. Otherwise, there is either an error in evaluation or it does not match any groups.
The evaluation logic is documented in the PostHog lite library. It compares the consistent hash of a flag key + tenant_id
with the rollout percentage and determines which tenant to roll out a specific feature.
Users can use the feature flag evaluation API to get the flag evaluation result of a specific tenant for debugging purposes.
```
curl http://localhost:9898/v1/tenant/:tenant_id/feature_flag?flag=:key&as=multivariate/boolean"
```
By default, the storcon pushes the feature flag specs to the pageservers every 30 seconds, which means that a change in feature flag in the
PostHog UI will propagate to the pageservers within 30 seconds.
# Future Works
* Support dynamic tenant properties like logical size as the evaluation condition.
* Support properties like `plan_type` (needs cplane to pass it down).
* Report feature flag evaluation result back to PostHog (if the cost is okay).
* Fast feature flag evaluation cache on critical paths (e.g., cache a feature flag result in `AtomicBool` and use it on the read path).

View File

@@ -0,0 +1,399 @@
# Compute rolling restart with prewarm
Created on 2025-03-17
Implemented on _TBD_
Author: Alexey Kondratov (@ololobus)
## Summary
This RFC describes an approach to reduce performance degradation due to missing caches after compute node restart, i.e.:
1. Rolling restart of the running instance via 'warm' replica.
2. Auto-prewarm compute caches after unplanned restart or scale-to-zero.
## Motivation
Neon currently implements several features that guarantee high uptime of compute nodes:
1. Storage high-availability (HA), i.e. each tenant shard has a secondary pageserver location, so we can quickly switch over compute to it in case of primary pageserver failure.
2. Fast compute provisioning, i.e. we have a fleet of pre-created empty computes, that are ready to serve workload, so restarting unresponsive compute is very fast.
3. Preemptive NeonVM compute provisioning in case of k8s node unavailability.
This helps us to be well-within the uptime SLO of 99.95% most of the time. Problems begin when we go up to multi-TB workloads and 32-64 CU computes.
During restart, compute loses all caches: LFC, shared buffers, file system cache. Depending on the workload, it can take a lot of time to warm up the caches,
so that performance could be degraded and might be even unacceptable for certain workloads. The latter means that although current approach works well for small to
medium workloads, we still have to do some additional work to avoid performance degradation after restart of large instances.
## Non Goals
- Details of the persistence storage for prewarm data are out of scope, there is a separate RFC for that: <https://github.com/neondatabase/neon/pull/9661>.
- Complete compute/Postgres HA setup and flow. Although it was originally in scope of this RFC, during preliminary research it appeared to be a rabbit hole, so it's worth of a separate RFC.
- Low-level implementation details for Postgres replica-to-primary promotion. There are a lot of things to think and care about: how to start walproposer, [logical replication failover](https://www.postgresql.org/docs/current/logical-replication-failover.html), and so on, but it's worth of at least a separate one-pager design document if not RFC.
## Impacted components
Postgres, compute_ctl, Control plane, Endpoint storage for unlogged storage of compute files.
For the latter, we will need to implement a uniform abstraction layer on top of S3, ABS, etc., but
S3 is used in text interchangeably with 'endpoint storage' for simplicity.
## Proposed implementation
### compute_ctl spec changes and auto-prewarm
We are going to extend the current compute spec with the following attributes
```rust
struct ComputeSpec {
/// [All existing attributes]
...
/// Whether to do auto-prewarm at start or not.
/// Default to `false`.
pub lfc_auto_prewarm: bool
/// Interval in seconds between automatic dumps of
/// LFC state into S3. Default `None`, which means 'off'.
pub lfc_dump_interval_sec: Option<i32>
}
```
When `lfc_dump_interval_sec` is set to `N`, `compute_ctl` will periodically dump the LFC state
and store it in S3, so that it could be used either for auto-prewarm after restart or by replica
during the rolling restart. For enabling periodic dumping, we should consider the following value
`lfc_dump_interval_sec=300` (5 minutes), same as in the upstream's `pg_prewarm.autoprewarm_interval`.
When `lfc_auto_prewarm` is set to `true`, `compute_ctl` will start prewarming the LFC upon restart
iif some of the previous states is present in S3.
### compute_ctl API
1. `POST /store_lfc_state` -- dump LFC state using Postgres SQL interface and store result in S3.
This has to be a blocking call, i.e. it will return only after the state is stored in S3.
If there is any concurrent request in progress, we should return `429 Too Many Requests`,
and let the caller to retry.
2. `GET /dump_lfc_state` -- dump LFC state using Postgres SQL interface and return it as is
in text format suitable for the future restore/prewarm. This API is not strictly needed at
the end state, but could be useful for a faster prototyping of a complete rolling restart flow
with prewarm, as it doesn't require persistent for LFC state storage.
3. `POST /restore_lfc_state` -- restore/prewarm LFC state with request
```yaml
RestoreLFCStateRequest:
oneOf:
- type: object
required:
- lfc_state
properties:
lfc_state:
type: string
description: Raw LFC content dumped with GET `/dump_lfc_state`
- type: object
required:
- lfc_cache_key
properties:
lfc_cache_key:
type: string
description: |
endpoint_id of the source endpoint on the same branch
to use as a 'donor' for LFC content. Compute will look up
LFC content dump in S3 using this key and do prewarm.
```
where `lfc_state` and `lfc_cache_key` are mutually exclusive.
The actual prewarming will happen asynchronously, so the caller need to check the
prewarm status using the compute's standard `GET /status` API.
4. `GET /status` -- extend existing API with following attributes
```rust
struct ComputeStatusResponse {
// [All existing attributes]
...
pub prewarm_state: PrewarmState
}
/// Compute prewarm state. Will be stored in the shared Compute state
/// in compute_ctl
struct PrewarmState {
pub status: PrewarmStatus
/// Total number of pages to prewarm
pub pages_total: i64
/// Number of pages prewarmed so far
pub pages_processed: i64
/// Optional prewarm error
pub error: Option<String>
}
pub enum PrewarmStatus {
/// Prewarming was never requested on this compute
Off,
/// Prewarming was requested, but not started yet
Pending,
/// Prewarming is in progress. The caller should follow
/// `PrewarmState::progress`.
InProgress,
/// Prewarming has been successfully completed
Completed,
/// Prewarming failed. The caller should look at
/// `PrewarmState::error` for the reason.
Failed,
/// It is intended to be used by auto-prewarm if none of
/// the previous LFC states is available in S3.
/// This is a distinct state from the `Failed` because
/// technically it's not a failure and could happen if
/// compute was restart before it dumped anything into S3,
/// or just after the initial rollout of the feature.
Skipped,
}
```
5. `POST /promote` -- this is a **blocking** API call to promote compute replica into primary.
This API should be very similar to the existing `POST /configure` API, i.e. accept the
spec (primary spec, because originally compute was started as replica). It's a distinct
API method because semantics and response codes are different:
- If promotion is done successfully, it will return `200 OK`.
- If compute is already primary, the call will be no-op and `compute_ctl`
will return `412 Precondition Failed`.
- If, for some reason, second request reaches compute that is in progress of promotion,
it will respond with `429 Too Many Requests`.
- If compute hit any permanent failure during promotion `500 Internal Server Error`
will be returned.
### Control plane operations
The complete flow will be present as a sequence diagram in the next section, but here
we just want to list some important steps that have to be done by control plane during
the rolling restart via warm replica, but without much of low-level implementation details.
1. Register the 'intent' of the instance restart, but not yet interrupt any workload at
primary and also accept new connections. This may require some endpoint state machine
changes, e.g. introduction of the `pending_restart` state. Being in this state also
**mustn't prevent any other operations except restart**: suspend, live-reconfiguration
(e.g. due to notify-attach call from the storage controller), deletion.
2. Start new replica compute on the same timeline and start prewarming it. This process
may take quite a while, so the same concurrency considerations as in 1. should be applied
here as well.
3. When warm replica is ready, control plane should:
3.1. Terminate the primary compute. Starting from here, **this is a critical section**,
if anything goes off, the only option is to start the primary normally and proceed
with auto-prewarm.
3.2. Send cache invalidation message to all proxies, notifying them that all new connections
should request and wait for the new connection details. At this stage, proxy has to also
drop any existing connections to the old primary, so they didn't do stale reads.
3.3. Attach warm replica compute to the primary endpoint inside control plane metadata
database.
3.4. Promote replica to primary.
3.5. When everything is done, finalize the endpoint state to be just `active`.
### Complete rolling restart flow
```mermaid
sequenceDiagram
autonumber
participant proxy as Neon proxy
participant cplane as Control plane
participant primary as Compute (primary)
box Compute (replica)
participant ctl as compute_ctl
participant pg as Postgres
end
box Endpoint unlogged storage
participant s3proxy as Endpoint storage service
participant s3 as S3/ABS/etc.
end
cplane ->> primary: POST /store_lfc_state
primary -->> cplane: 200 OK
cplane ->> ctl: POST /restore_lfc_state
activate ctl
ctl -->> cplane: 202 Accepted
activate cplane
cplane ->> ctl: GET /status: poll prewarm status
ctl ->> s3proxy: GET /read_file
s3proxy ->> s3: read file
s3 -->> s3proxy: file content
s3proxy -->> ctl: 200 OK: file content
proxy ->> cplane: GET /proxy_wake_compute
cplane -->> proxy: 200 OK: old primary conninfo
ctl ->> pg: prewarm LFC
activate pg
pg -->> ctl: prewarm is completed
deactivate pg
ctl -->> cplane: 200 OK: prewarm is completed
deactivate ctl
deactivate cplane
cplane -->> cplane: reassign replica compute to endpoint,<br>start terminating the old primary compute
activate cplane
cplane ->> proxy: invalidate caches
proxy ->> cplane: GET /proxy_wake_compute
cplane -x primary: POST /terminate
primary -->> cplane: 200 OK
note over primary: old primary<br>compute terminated
cplane ->> ctl: POST /promote
activate ctl
ctl ->> pg: pg_ctl promote
activate pg
pg -->> ctl: done
deactivate pg
ctl -->> cplane: 200 OK
deactivate ctl
cplane -->> cplane: finalize operation
cplane -->> proxy: 200 OK: new primary conninfo
deactivate cplane
```
### Network bandwidth and prewarm speed
It's currently known that pageserver can sustain about 3000 RPS per shard for a few running computes.
Large tenants are usually split into 8 shards, so the final formula may look like this:
```text
8 shards * 3000 RPS * 8 KB =~ 190 MB/s
```
so depending on the LFC size, prewarming will take at least:
- ~5s for 1 GB
- ~50s for 10 GB
- ~5m for 100 GB
- \>1h for 1 TB
In total, one pageserver is normally capped by 30k RPS, so it obviously can't sustain many computes
doing prewarm at the same time. Later, we may need an additional mechanism for computes to throttle
the prewarming requests gracefully.
### Reliability, failure modes and corner cases
We consider following failures while implementing this RFC:
1. Compute got interrupted/crashed/restarted during prewarm. The caller -- control plane -- should
detect that and start prewarm from the beginning.
2. Control plane promotion request timed out or hit network issues. If it never reached the
compute, control plane should just repeat it. If it did reach the compute, then during
retry control plane can hit `409` as previous request triggered the promotion already.
In this case, control plane need to retry until either `200` or
permanent error `500` is returned.
3. Compute got interrupted/crashed/restarted during promotion. At restart it will ask for
a spec from control plane, and its content should signal compute to start as **primary**,
so it's expected that control plane will continue polling for certain period of time and
will discover that compute is ready to accept connections if restart is fast enough.
4. Any other unexpected failure or timeout during prewarming. This **failure mustn't be fatal**,
control plane has to report failure, terminate replica and keep primary running.
5. Any other unexpected failure or timeout during promotion. Unfortunately, at this moment
we already have the primary node stopped, so the only option is to start primary again
and proceed with auto-prewarm.
6. Any unexpected failure during auto-prewarm. This **failure mustn't be fatal**,
`compute_ctl` has to report the failure, but do not crash the compute.
7. Control plane failed to confirm that old primary has terminated. This can happen, especially
in the future HA setup. In this case, control plane has to ensure that it sent VM deletion
and pod termination requests to k8s, so long-term we do not have two running primaries
on the same timeline.
### Security implications
There are two security implications to consider:
1. Access to `compute_ctl` API. It has to be accessible from the outside of compute, so all
new API methods have to be exposed on the **external** HTTP port and **must** be authenticated
with JWT.
2. Read/write only your own LFC state data in S3. Although it's not really a security concern,
since LFC state is just a mapping of blocks present in LFC at certain moment in time;
it still has to be highly restricted, so that i) only computes on the same timeline can
read S3 state; ii) each compute can only write to the path that contains it's `endpoint_id`.
Both of this must be validated by Endpoint storage service using the JWT token provided by `compute_ctl`.
### Unresolved questions
#### Billing, metrics and monitoring
Currently, we only label computes with `endpoint_id` after attaching them to the endpoint.
In this proposal, this means that temporary replica will remain unlabelled until it's promoted
to primary. We can also hide it from users in the control plane API, but what to do with
billing and monitoring is still unclear.
We can probably mark it as 'billable' and tag with `project_id`, so it will be billed, but
not interfere in any way with the current primary monitoring.
Another thing to consider is how logs and metrics export will switch to the new compute.
It's expected that OpenTelemetry collector will auto-discover the new compute and start
scraping metrics from it.
#### Auto-prewarm
It's still an open question whether we need auto-prewarm at all. The author's gut-feeling is
that yes, we need it, but might be not for all workloads, so it could end up exposed as a
user-controllable knob on the endpoint. There are two arguments for that:
1. Auto-prewarm existing in upstream's `pg_prewarm`, _probably for a reason_.
2. There are still could be 2 flows when we cannot perform the rolling restart via the warm
replica: i) any failure or interruption during promotion; ii) wake up after scale-to-zero.
The latter might be challenged as well, i.e. one can argue that auto-prewarm may and will
compete with user-workload for storage resources. This is correct, but it might as well
reduce the time to get warm LFC and good performance.
#### Low-level details of the replica promotion
There are many things to consider here, but three items just off the top of my head:
1. How to properly start the `walproposer` inside Postgres.
2. What to do with logical replication. Currently, we do not include logical replication slots
inside basebackup, because nobody advances them at replica, so they just prevent the WAL
deletion. Yet, we do need to have them at primary after promotion. Starting with Postgres 17,
there is a new feature called
[logical replication failover](https://www.postgresql.org/docs/current/logical-replication-failover.html)
and `synchronized_standby_slots` setting, but we need a plan for the older versions. Should we
request a new basebackup during promotion?
3. How do we guarantee that replica will receive all the latest WAL from safekeepers? Do some
'shallow' version of sync safekeepers without data copying? Or just a standard version of
sync safekeepers?
## Alternative implementation
The proposal already assumes one of the alternatives -- do not have any persistent storage for
LFC state. This is possible to implement faster with the proposed API, but it means that
we do not implement auto-prewarm yet.
## Definition of Done
At the end of implementing this RFC we should have two high-level settings that enable:
1. Auto-prewarm of user computes upon restart.
2. Perform primary compute restart via the warm replica promotion.
It also has to be decided what's the criteria for enabling one or both of these flows for
certain clients.

View File

@@ -420,6 +420,7 @@ impl From<NodeSchedulingPolicy> for String {
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum SkSchedulingPolicy {
Active,
Activating,
Pause,
Decomissioned,
}
@@ -430,6 +431,7 @@ impl FromStr for SkSchedulingPolicy {
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"active" => Self::Active,
"activating" => Self::Activating,
"pause" => Self::Pause,
"decomissioned" => Self::Decomissioned,
_ => {
@@ -446,6 +448,7 @@ impl From<SkSchedulingPolicy> for String {
use SkSchedulingPolicy::*;
match value {
Active => "active",
Activating => "activating",
Pause => "pause",
Decomissioned => "decomissioned",
}

View File

@@ -78,7 +78,13 @@ pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
BrokenPipe | ConnectionRefused | ConnectionAborted | ConnectionReset | TimedOut
HostUnreachable
| NetworkUnreachable
| BrokenPipe
| ConnectionRefused
| ConnectionAborted
| ConnectionReset
| TimedOut,
)
}

View File

@@ -52,7 +52,7 @@ pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
if i.is_multiple_of(1024) {
yield_now().await
}
}

View File

@@ -90,7 +90,7 @@ pub struct InnerClient {
}
impl InnerClient {
pub fn start(&mut self) -> Result<PartialQuery, Error> {
pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
self.responses.waiting += 1;
Ok(PartialQuery(Some(self)))
}
@@ -227,7 +227,7 @@ impl Client {
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
@@ -262,7 +262,7 @@ impl Client {
pub(crate) async fn simple_query_raw(
&mut self,
query: &str,
) -> Result<SimpleQueryStream, Error> {
) -> Result<SimpleQueryStream<'_>, Error> {
simple_query::simple_query(self.inner_mut(), query).await
}

View File

@@ -12,7 +12,11 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -22,7 +26,11 @@ pub trait GenericClient: private::Sealed {
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -35,7 +43,11 @@ impl GenericClient for Client {
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,

View File

@@ -47,7 +47,7 @@ impl<'a> Transaction<'a> {
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,

View File

@@ -24,12 +24,28 @@ macro_rules! critical {
if cfg!(debug_assertions) {
panic!($($arg)*);
}
// Increment both metrics
$crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical();
let backtrace = std::backtrace::Backtrace::capture();
tracing::error!("CRITICAL: {}\n{backtrace}", format!($($arg)*));
}};
}
#[macro_export]
macro_rules! critical_timeline {
($tenant_shard_id:expr, $timeline_id:expr, $($arg:tt)*) => {{
if cfg!(debug_assertions) {
panic!($($arg)*);
}
// Increment both metrics
$crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical();
$crate::logging::HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC.inc(&$tenant_shard_id.to_string(), &$timeline_id.to_string());
let backtrace = std::backtrace::Backtrace::capture();
tracing::error!("CRITICAL: [tenant_shard_id: {}, timeline_id: {}] {}\n{backtrace}",
$tenant_shard_id, $timeline_id, format!($($arg)*));
}};
}
#[derive(EnumString, strum_macros::Display, VariantNames, Eq, PartialEq, Debug, Clone, Copy)]
#[strum(serialize_all = "snake_case")]
pub enum LogFormat {
@@ -61,6 +77,36 @@ pub struct TracingEventCountMetric {
trace: IntCounter,
}
// Begin Hadron: Add a HadronCriticalStorageEventCountMetric metric that is sliced by tenant_id and timeline_id
pub struct HadronCriticalStorageEventCountMetric {
critical: IntCounterVec,
}
pub static HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC: Lazy<HadronCriticalStorageEventCountMetric> =
Lazy::new(|| {
let vec = metrics::register_int_counter_vec!(
"hadron_critical_storage_event_count",
"Number of critical storage events, by tenant_id and timeline_id",
&["tenant_shard_id", "timeline_id"]
)
.expect("failed to define metric");
HadronCriticalStorageEventCountMetric::new(vec)
});
impl HadronCriticalStorageEventCountMetric {
fn new(vec: IntCounterVec) -> Self {
Self { critical: vec }
}
// Allow public access from `critical!` macro.
pub fn inc(&self, tenant_shard_id: &str, timeline_id: &str) {
self.critical
.with_label_values(&[tenant_shard_id, timeline_id])
.inc();
}
}
// End Hadron
pub static TRACING_EVENT_COUNT_METRIC: Lazy<TracingEventCountMetric> = Lazy::new(|| {
let vec = metrics::register_int_counter_vec!(
"libmetrics_tracing_event_count",

View File

@@ -99,7 +99,7 @@ pub(super) async fn upload_metrics_bucket(
// Compose object path
let datetime: DateTime<Utc> = SystemTime::now().into();
let ts_prefix = datetime.format("year=%Y/month=%m/day=%d/%H:%M:%SZ");
let ts_prefix = datetime.format("year=%Y/month=%m/day=%d/hour=%H/%H:%M:%SZ");
let path = RemotePath::from_string(&format!("{ts_prefix}_{node_id}.ndjson.gz"))?;
// Set up a gzip writer into a buffer
@@ -109,7 +109,7 @@ pub(super) async fn upload_metrics_bucket(
// Serialize and write into compressed buffer
let started_at = std::time::Instant::now();
for res in serialize_in_chunks(CHUNK_SIZE, metrics, idempotency_keys) {
for res in serialize_in_chunks_ndjson(CHUNK_SIZE, metrics, idempotency_keys) {
let (_chunk, body) = res?;
gzip_writer.write_all(&body).await?;
}
@@ -216,6 +216,86 @@ fn serialize_in_chunks<'a>(
}
}
/// Serializes the input metrics as NDJSON in chunks of chunk_size. Each event
/// is serialized as a separate JSON object on its own line. The provided
/// idempotency keys are injected into the corresponding metric events (reused
/// across different metrics sinks), and must have the same length as input.
fn serialize_in_chunks_ndjson<'a>(
chunk_size: usize,
input: &'a [NewRawMetric],
idempotency_keys: &'a [IdempotencyKey<'a>],
) -> impl ExactSizeIterator<Item = Result<(&'a [NewRawMetric], bytes::Bytes), serde_json::Error>> + 'a
{
use bytes::BufMut;
assert_eq!(input.len(), idempotency_keys.len());
struct Iter<'a> {
inner: std::slice::Chunks<'a, NewRawMetric>,
idempotency_keys: std::slice::Iter<'a, IdempotencyKey<'a>>,
chunk_size: usize,
// write to a BytesMut so that we can cheaply clone the frozen Bytes for retries
buffer: bytes::BytesMut,
// chunk amount of events are reused to produce the serialized document
scratch: Vec<Event<Ids, Name>>,
}
impl<'a> Iterator for Iter<'a> {
type Item = Result<(&'a [NewRawMetric], bytes::Bytes), serde_json::Error>;
fn next(&mut self) -> Option<Self::Item> {
let chunk = self.inner.next()?;
if self.scratch.is_empty() {
// first round: create events with N strings
self.scratch.extend(
chunk
.iter()
.zip(&mut self.idempotency_keys)
.map(|(raw_metric, key)| raw_metric.as_event(key)),
);
} else {
// next rounds: update_in_place to reuse allocations
assert_eq!(self.scratch.len(), self.chunk_size);
itertools::izip!(self.scratch.iter_mut(), chunk, &mut self.idempotency_keys)
.for_each(|(slot, raw_metric, key)| raw_metric.update_in_place(slot, key));
}
// Serialize each event as NDJSON (one JSON object per line)
for event in self.scratch[..chunk.len()].iter() {
let res = serde_json::to_writer((&mut self.buffer).writer(), event);
if let Err(e) = res {
return Some(Err(e));
}
// Add newline after each event to follow NDJSON format
self.buffer.put_u8(b'\n');
}
Some(Ok((chunk, self.buffer.split().freeze())))
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl ExactSizeIterator for Iter<'_> {}
let buffer = bytes::BytesMut::new();
let inner = input.chunks(chunk_size);
let idempotency_keys = idempotency_keys.iter();
let scratch = Vec::new();
Iter {
inner,
idempotency_keys,
chunk_size,
buffer,
scratch,
}
}
trait RawMetricExt {
fn as_event(&self, key: &IdempotencyKey<'_>) -> Event<Ids, Name>;
fn update_in_place(&self, event: &mut Event<Ids, Name>, key: &IdempotencyKey<'_>);
@@ -479,6 +559,43 @@ mod tests {
}
}
#[test]
fn chunked_serialization_ndjson() {
let examples = metric_samples();
assert!(examples.len() > 1);
let now = Utc::now();
let idempotency_keys = (0..examples.len())
.map(|i| FixedGen::new(now, "1", i as u16).generate())
.collect::<Vec<_>>();
// Parse NDJSON format - each line is a separate JSON object
let parse_ndjson = |body: &[u8]| -> Vec<Event<Ids, Name>> {
let body_str = std::str::from_utf8(body).unwrap();
body_str
.trim_end_matches('\n')
.lines()
.filter(|line| !line.is_empty())
.map(|line| serde_json::from_str::<Event<Ids, Name>>(line).unwrap())
.collect()
};
let correct = serialize_in_chunks_ndjson(examples.len(), &examples, &idempotency_keys)
.map(|res| res.unwrap().1)
.flat_map(|body| parse_ndjson(&body))
.collect::<Vec<_>>();
for chunk_size in 1..examples.len() {
let actual = serialize_in_chunks_ndjson(chunk_size, &examples, &idempotency_keys)
.map(|res| res.unwrap().1)
.flat_map(|body| parse_ndjson(&body))
.collect::<Vec<_>>();
// if these are equal, it means that multi-chunking version works as well
assert_eq!(correct, actual);
}
}
#[derive(Clone, Copy)]
struct FixedGen<'a>(chrono::DateTime<chrono::Utc>, &'a str, u16);

View File

@@ -2438,6 +2438,7 @@ async fn timeline_offload_handler(
.map_err(|e| {
match e {
OffloadError::Cancelled => ApiError::ResourceUnavailable("Timeline shutting down".into()),
OffloadError::AlreadyInProgress => ApiError::Conflict("Timeline already being offloaded or deleted".into()),
_ => ApiError::InternalServerError(anyhow!(e))
}
})?;

View File

@@ -3285,6 +3285,7 @@ impl TenantShard {
.or_else(|err| match err {
// Ignore this, we likely raced with unarchival.
OffloadError::NotArchived => Ok(()),
OffloadError::AlreadyInProgress => Ok(()),
err => Err(err),
})?;
}

View File

@@ -78,7 +78,7 @@ use utils::rate_limit::RateLimit;
use utils::seqwait::SeqWait;
use utils::simple_rcu::{Rcu, RcuReadGuard};
use utils::sync::gate::{Gate, GateGuard};
use utils::{completion, critical, fs_ext, pausable_failpoint};
use utils::{completion, critical_timeline, fs_ext, pausable_failpoint};
#[cfg(test)]
use wal_decoder::models::value::Value;
use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta};
@@ -4729,7 +4729,7 @@ impl Timeline {
}
// Fetch the next layer to flush, if any.
let (layer, l0_count, frozen_count, frozen_size) = {
let (layer, l0_count, frozen_count, frozen_size, open_layer_size) = {
let layers = self.layers.read(LayerManagerLockHolder::FlushLoop).await;
let Ok(lm) = layers.layer_map() else {
info!("dropping out of flush loop for timeline shutdown");
@@ -4742,8 +4742,13 @@ impl Timeline {
.iter()
.map(|l| l.estimated_in_mem_size())
.sum();
let open_layer_size: u64 = lm
.open_layer
.as_ref()
.map(|l| l.estimated_in_mem_size())
.unwrap_or(0);
let layer = lm.frozen_layers.front().cloned();
(layer, l0_count, frozen_count, frozen_size)
(layer, l0_count, frozen_count, frozen_size, open_layer_size)
// drop 'layers' lock
};
let Some(layer) = layer else {
@@ -4756,7 +4761,7 @@ impl Timeline {
if l0_count >= stall_threshold {
warn!(
"stalling layer flushes for compaction backpressure at {l0_count} \
L0 layers ({frozen_count} frozen layers with {frozen_size} bytes)"
L0 layers ({frozen_count} frozen layers with {frozen_size} bytes, {open_layer_size} bytes in open layer)"
);
let stall_timer = self
.metrics
@@ -4809,7 +4814,7 @@ impl Timeline {
let delay = flush_duration.as_secs_f64();
info!(
"delaying layer flush by {delay:.3}s for compaction backpressure at \
{l0_count} L0 layers ({frozen_count} frozen layers with {frozen_size} bytes)"
{l0_count} L0 layers ({frozen_count} frozen layers with {frozen_size} bytes, {open_layer_size} bytes in open layer)"
);
let _delay_timer = self
.metrics
@@ -6819,7 +6824,11 @@ impl Timeline {
Err(walredo::Error::Cancelled) => return Err(PageReconstructError::Cancelled),
Err(walredo::Error::Other(err)) => {
if fire_critical_error {
critical!("walredo failure during page reconstruction: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"walredo failure during page reconstruction: {err:?}"
);
}
return Err(PageReconstructError::WalRedo(
err.context("reconstruct a page image"),

View File

@@ -36,7 +36,7 @@ use serde::Serialize;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, trace, warn};
use utils::critical;
use utils::critical_timeline;
use utils::id::TimelineId;
use utils::lsn::Lsn;
use wal_decoder::models::record::NeonWalRecord;
@@ -1390,7 +1390,11 @@ impl Timeline {
GetVectoredError::MissingKey(_),
) = err
{
critical!("missing key during compaction: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"missing key during compaction: {err:?}"
);
}
})?;
@@ -1418,7 +1422,11 @@ impl Timeline {
// Alert on critical errors that indicate data corruption.
Err(err) if err.is_critical() => {
critical!("could not compact, repartitioning keyspace failed: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"could not compact, repartitioning keyspace failed: {err:?}"
);
}
// Log other errors. No partitioning? This is normal, if the timeline was just created

View File

@@ -182,6 +182,7 @@ pub(crate) async fn generate_tombstone_image_layer(
detached: &Arc<Timeline>,
ancestor: &Arc<Timeline>,
ancestor_lsn: Lsn,
historic_layers_to_copy: &Vec<Layer>,
ctx: &RequestContext,
) -> Result<Option<ResidentLayer>, Error> {
tracing::info!(
@@ -199,6 +200,20 @@ pub(crate) async fn generate_tombstone_image_layer(
let image_lsn = ancestor_lsn;
{
for layer in historic_layers_to_copy {
let desc = layer.layer_desc();
if !desc.is_delta
&& desc.lsn_range.start == image_lsn
&& overlaps_with(&key_range, &desc.key_range)
{
tracing::info!(
layer=%layer, "will copy tombstone from ancestor instead of creating a new one"
);
return Ok(None);
}
}
let layers = detached
.layers
.read(LayerManagerLockHolder::DetachAncestor)
@@ -450,7 +465,8 @@ pub(super) async fn prepare(
Vec::with_capacity(straddling_branchpoint.len() + rest_of_historic.len() + 1);
if let Some(tombstone_layer) =
generate_tombstone_image_layer(detached, &ancestor, ancestor_lsn, ctx).await?
generate_tombstone_image_layer(detached, &ancestor, ancestor_lsn, &rest_of_historic, ctx)
.await?
{
new_layers.push(tombstone_layer.into());
}

View File

@@ -19,6 +19,8 @@ pub(crate) enum OffloadError {
NotArchived,
#[error(transparent)]
RemoteStorage(anyhow::Error),
#[error("Offload or deletion already in progress")]
AlreadyInProgress,
#[error("Unexpected offload error: {0}")]
Other(anyhow::Error),
}
@@ -44,20 +46,26 @@ pub(crate) async fn offload_timeline(
timeline.timeline_id,
TimelineDeleteGuardKind::Offload,
);
if let Err(DeleteTimelineError::HasChildren(children)) = delete_guard_res {
let is_archived = timeline.is_archived();
if is_archived == Some(true) {
tracing::error!("timeline is archived but has non-archived children: {children:?}");
let (timeline, guard) = match delete_guard_res {
Ok(timeline_and_guard) => timeline_and_guard,
Err(DeleteTimelineError::HasChildren(children)) => {
let is_archived = timeline.is_archived();
if is_archived == Some(true) {
tracing::error!("timeline is archived but has non-archived children: {children:?}");
return Err(OffloadError::NotArchived);
}
tracing::info!(
?is_archived,
"timeline is not archived and has unarchived children"
);
return Err(OffloadError::NotArchived);
}
tracing::info!(
?is_archived,
"timeline is not archived and has unarchived children"
);
return Err(OffloadError::NotArchived);
Err(DeleteTimelineError::AlreadyInProgress(_)) => {
tracing::info!("timeline offload or deletion already in progress");
return Err(OffloadError::AlreadyInProgress);
}
Err(e) => return Err(OffloadError::Other(anyhow::anyhow!(e))),
};
let (timeline, guard) =
delete_guard_res.map_err(|e| OffloadError::Other(anyhow::anyhow!(e)))?;
let TimelineOrOffloaded::Timeline(timeline) = timeline else {
tracing::error!("timeline already offloaded, but given timeline object");

View File

@@ -25,7 +25,7 @@ use tokio_postgres::replication::ReplicationStream;
use tokio_postgres::{Client, SimpleQueryMessage, SimpleQueryRow};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, trace, warn};
use utils::critical;
use utils::critical_timeline;
use utils::id::NodeId;
use utils::lsn::Lsn;
use utils::pageserver_feedback::PageserverFeedback;
@@ -368,9 +368,13 @@ pub(super) async fn handle_walreceiver_connection(
match raw_wal_start_lsn.cmp(&expected_wal_start) {
std::cmp::Ordering::Greater => {
let msg = format!(
"Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn})"
"Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn}"
);
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{msg}"
);
critical!("{msg}");
return Err(WalReceiverError::Other(anyhow!(msg)));
}
std::cmp::Ordering::Less => {
@@ -383,7 +387,11 @@ pub(super) async fn handle_walreceiver_connection(
"Received record with next_record_lsn multiple times ({} < {})",
first_rec.next_record_lsn, expected_wal_start
);
critical!("{msg}");
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{msg}"
);
return Err(WalReceiverError::Other(anyhow!(msg)));
}
}
@@ -452,7 +460,11 @@ pub(super) async fn handle_walreceiver_connection(
// TODO: we can't differentiate cancellation errors with
// anyhow::Error, so just ignore it if we're cancelled.
if !cancellation.is_cancelled() && !timeline.is_stopping() {
critical!("{err:?}")
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{err:?}"
);
}
})?;

View File

@@ -40,7 +40,7 @@ use tracing::*;
use utils::bin_ser::{DeserializeError, SerializeError};
use utils::lsn::Lsn;
use utils::rate_limit::RateLimit;
use utils::{critical, failpoint_support};
use utils::{critical_timeline, failpoint_support};
use wal_decoder::models::record::NeonWalRecord;
use wal_decoder::models::*;
@@ -418,18 +418,30 @@ impl WalIngest {
// as there has historically been cases where PostgreSQL has cleared spurious VM pages. See:
// https://github.com/neondatabase/neon/pull/10634.
let Some(vm_size) = get_relsize(modification, vm_rel, ctx).await? else {
critical!("clear_vm_bits for unknown VM relation {vm_rel}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"clear_vm_bits for unknown VM relation {vm_rel}"
);
return Ok(());
};
if let Some(blknum) = new_vm_blk {
if blknum >= vm_size {
critical!("new_vm_blk {blknum} not in {vm_rel} of size {vm_size}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"new_vm_blk {blknum} not in {vm_rel} of size {vm_size}"
);
new_vm_blk = None;
}
}
if let Some(blknum) = old_vm_blk {
if blknum >= vm_size {
critical!("old_vm_blk {blknum} not in {vm_rel} of size {vm_size}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"old_vm_blk {blknum} not in {vm_rel} of size {vm_size}"
);
old_vm_blk = None;
}
}

View File

@@ -87,6 +87,14 @@ static const struct config_enum_entry running_xacts_overflow_policies[] = {
{NULL, 0, false}
};
static const struct config_enum_entry debug_compare_local_modes[] = {
{"none", DEBUG_COMPARE_LOCAL_NONE, false},
{"prefetch", DEBUG_COMPARE_LOCAL_PREFETCH, false},
{"lfc", DEBUG_COMPARE_LOCAL_LFC, false},
{"all", DEBUG_COMPARE_LOCAL_ALL, false},
{NULL, 0, false}
};
/*
* XXX: These private to procarray.c, but we need them here.
*/
@@ -519,6 +527,16 @@ _PG_init(void)
GUC_UNIT_KB,
NULL, NULL, NULL);
DefineCustomEnumVariable(
"neon.debug_compare_local",
"Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk",
NULL,
&debug_compare_local,
DEBUG_COMPARE_LOCAL_NONE,
debug_compare_local_modes,
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
/*
* Important: This must happen after other parts of the extension are
* loaded, otherwise any settings to GUCs that were set before the

View File

@@ -177,6 +177,22 @@ extern StringInfoData nm_pack_request(NeonRequest *msg);
extern NeonResponse *nm_unpack_response(StringInfo s);
extern char *nm_to_string(NeonMessage *msg);
/*
* If debug_compare_local>DEBUG_COMPARE_LOCAL_NONE, we pass through all the SMGR API
* calls to md.c, and *also* do the calls to the Page Server. On every
* read, compare the versions we read from local disk and Page Server,
* and Assert that they are identical.
*/
typedef enum
{
DEBUG_COMPARE_LOCAL_NONE, /* normal mode - pages are storted locally only for unlogged relations */
DEBUG_COMPARE_LOCAL_PREFETCH, /* if page is found in prefetch ring, then compare it with local and return */
DEBUG_COMPARE_LOCAL_LFC, /* if page is found in LFC or prefetch ring, then compare it with local and return */
DEBUG_COMPARE_LOCAL_ALL /* always fetch page from PS and compare it with local */
} DebugCompareLocalMode;
extern int debug_compare_local;
/*
* API
*/

View File

@@ -76,21 +76,11 @@
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
/*
* If DEBUG_COMPARE_LOCAL is defined, we pass through all the SMGR API
* calls to md.c, and *also* do the calls to the Page Server. On every
* read, compare the versions we read from local disk and Page Server,
* and Assert that they are identical.
*/
/* #define DEBUG_COMPARE_LOCAL */
#ifdef DEBUG_COMPARE_LOCAL
#include "access/nbtree.h"
#include "storage/bufpage.h"
#include "access/xlog_internal.h"
static char *hexdump_page(char *page);
#endif
#define IS_LOCAL_REL(reln) (\
NInfoGetDbOid(InfoFromSMgrRel(reln)) != 0 && \
@@ -108,6 +98,8 @@ typedef enum
UNLOGGED_BUILD_NOT_PERMANENT
} UnloggedBuildPhase;
int debug_compare_local;
static NRelFileInfo unlogged_build_rel_info;
static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
@@ -478,9 +470,10 @@ neon_init(void)
old_redo_read_buffer_filter = redo_read_buffer_filter;
redo_read_buffer_filter = neon_redo_read_buffer_filter;
#ifdef DEBUG_COMPARE_LOCAL
mdinit();
#endif
if (debug_compare_local)
{
mdinit();
}
}
/*
@@ -803,13 +796,16 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
case RELPERSISTENCE_TEMP:
case RELPERSISTENCE_UNLOGGED:
#ifdef DEBUG_COMPARE_LOCAL
mdcreate(reln, forkNum, forkNum == INIT_FORKNUM || isRedo);
if (forkNum == MAIN_FORKNUM)
mdcreate(reln, INIT_FORKNUM, true);
#else
mdcreate(reln, forkNum, isRedo);
#endif
if (debug_compare_local)
{
mdcreate(reln, forkNum, forkNum == INIT_FORKNUM || isRedo);
if (forkNum == MAIN_FORKNUM)
mdcreate(reln, INIT_FORKNUM, true);
}
else
{
mdcreate(reln, forkNum, isRedo);
}
return;
default:
@@ -848,10 +844,11 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
else
set_cached_relsize(InfoFromSMgrRel(reln), forkNum, 0);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdcreate(reln, forkNum, isRedo);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdcreate(reln, forkNum, isRedo);
}
}
/*
@@ -877,7 +874,7 @@ neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo)
{
/*
* Might or might not exist locally, depending on whether it's an unlogged
* or permanent relation (or if DEBUG_COMPARE_LOCAL is set). Try to
* or permanent relation (or if debug_compare_local is set). Try to
* unlink, it won't do any harm if the file doesn't exist.
*/
mdunlink(rinfo, forkNum, isRedo);
@@ -973,10 +970,11 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
lfc_write(InfoFromSMgrRel(reln), forkNum, blkno, buffer);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
}
/*
* smgr_extend is often called with an all-zeroes page, so
@@ -1051,10 +1049,11 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
relpath(reln->smgr_rlocator, forkNum),
InvalidBlockNumber)));
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdzeroextend(reln, forkNum, blocknum, nblocks, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdzeroextend(reln, forkNum, blocknum, nblocks, skipFsync);
}
/* Don't log any pages if we're not allowed to do so. */
if (!XLogInsertAllowed())
@@ -1265,10 +1264,11 @@ neon_writeback(SMgrRelation reln, ForkNumber forknum,
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdwriteback(reln, forknum, blocknum, nblocks);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdwriteback(reln, forknum, blocknum, nblocks);
}
}
/*
@@ -1282,7 +1282,6 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
communicator_read_at_lsnv(rinfo, forkNum, blkno, &request_lsns, &buffer, 1, NULL);
}
#ifdef DEBUG_COMPARE_LOCAL
static void
compare_with_local(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void* buffer, XLogRecPtr request_lsn)
{
@@ -1364,7 +1363,6 @@ compare_with_local(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, voi
}
}
}
#endif
#if PG_MAJORVERSION_NUM < 17
@@ -1417,22 +1415,28 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
if (communicator_prefetch_lookupv(InfoFromSMgrRel(reln), forkNum, blkno, &request_lsns, 1, &bufferp, &present))
{
/* Prefetch hit */
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#else
return;
#endif
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_PREFETCH)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_PREFETCH)
{
return;
}
}
/* Try to read from local file cache */
if (lfc_read(InfoFromSMgrRel(reln), forkNum, blkno, buffer))
{
MyNeonCounters->file_cache_hits_total++;
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#else
return;
#endif
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_LFC)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_LFC)
{
return;
}
}
neon_read_at_lsn(InfoFromSMgrRel(reln), forkNum, blkno, request_lsns, buffer);
@@ -1442,15 +1446,15 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
*/
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#endif
if (debug_compare_local)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
}
#endif /* PG_MAJORVERSION_NUM <= 16 */
#if PG_MAJORVERSION_NUM >= 17
#ifdef DEBUG_COMPARE_LOCAL
static void
compare_with_localv(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void** buffers, BlockNumber nblocks, neon_request_lsns* request_lsns, bits8* read_pages)
{
@@ -1465,7 +1469,6 @@ compare_with_localv(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, vo
}
}
}
#endif
static void
@@ -1516,13 +1519,19 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
blocknum, request_lsns, nblocks,
buffers, read_pages);
#ifdef DEBUG_COMPARE_LOCAL
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
memset(read_pages, 0, sizeof(read_pages));
#else
if (prefetch_result == nblocks)
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_PREFETCH)
{
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_PREFETCH && prefetch_result == nblocks)
{
return;
#endif
}
if (debug_compare_local > DEBUG_COMPARE_LOCAL_PREFETCH)
{
memset(read_pages, 0, sizeof(read_pages));
}
/* Try to read from local file cache */
lfc_result = lfc_readv_select(InfoFromSMgrRel(reln), forknum, blocknum, buffers,
@@ -1531,14 +1540,19 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
if (lfc_result > 0)
MyNeonCounters->file_cache_hits_total += lfc_result;
#ifdef DEBUG_COMPARE_LOCAL
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
memset(read_pages, 0, sizeof(read_pages));
#else
/* Read all blocks from LFC, so we're done */
if (prefetch_result + lfc_result == nblocks)
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_LFC)
{
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_LFC && prefetch_result + lfc_result == nblocks)
{
/* Read all blocks from LFC, so we're done */
return;
#endif
}
if (debug_compare_local > DEBUG_COMPARE_LOCAL_LFC)
{
memset(read_pages, 0, sizeof(read_pages));
}
communicator_read_at_lsnv(InfoFromSMgrRel(reln), forknum, blocknum, request_lsns,
buffers, nblocks, read_pages);
@@ -1548,14 +1562,14 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
*/
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
memset(read_pages, 0xFF, sizeof(read_pages));
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
#endif
if (debug_compare_local)
{
memset(read_pages, 0xFF, sizeof(read_pages));
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
}
#endif
#ifdef DEBUG_COMPARE_LOCAL
static char *
hexdump_page(char *page)
{
@@ -1574,7 +1588,6 @@ hexdump_page(char *page)
return result.data;
}
#endif
#if PG_MAJORVERSION_NUM < 17
/*
@@ -1596,12 +1609,8 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
switch (reln->smgr_relpersistence)
{
case 0:
#ifndef DEBUG_COMPARE_LOCAL
/* This is a bit tricky. Check if the relation exists locally */
if (mdexists(reln, forknum))
#else
if (mdexists(reln, INIT_FORKNUM))
#endif
if (mdexists(reln, debug_compare_local ? INIT_FORKNUM : forknum))
{
/* It exists locally. Guess it's unlogged then. */
#if PG_MAJORVERSION_NUM >= 17
@@ -1656,14 +1665,17 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
{
#if PG_MAJORVERSION_NUM >= 17
mdwritev(reln, forknum, blocknum, &buffer, 1, skipFsync);
mdwritev(reln, forknum, blocknum, &buffer, 1, skipFsync);
#else
mdwrite(reln, forknum, blocknum, buffer, skipFsync);
mdwrite(reln, forknum, blocknum, buffer, skipFsync);
#endif
#endif
}
}
}
#endif
@@ -1677,12 +1689,8 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
switch (reln->smgr_relpersistence)
{
case 0:
#ifndef DEBUG_COMPARE_LOCAL
/* This is a bit tricky. Check if the relation exists locally */
if (mdexists(reln, forknum))
#else
if (mdexists(reln, INIT_FORKNUM))
#endif
if (mdexists(reln, debug_compare_local ? INIT_FORKNUM : forknum))
{
/* It exists locally. Guess it's unlogged then. */
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
@@ -1720,10 +1728,11 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
}
}
#endif
@@ -1862,10 +1871,11 @@ neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber old_blocks, Blo
*/
neon_set_lwlsn_relation(lsn, InfoFromSMgrRel(reln), forknum);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdtruncate(reln, forknum, old_blocks, nblocks);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdtruncate(reln, forknum, old_blocks, nblocks);
}
}
/*
@@ -1904,10 +1914,11 @@ neon_immedsync(SMgrRelation reln, ForkNumber forknum)
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
}
}
#if PG_MAJORVERSION_NUM >= 17
@@ -1934,10 +1945,11 @@ neon_registersync(SMgrRelation reln, ForkNumber forknum)
neon_log(SmgrTrace, "[NEON_SMGR] registersync noop");
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
}
}
#endif
@@ -1978,10 +1990,11 @@ neon_start_unlogged_build(SMgrRelation reln)
case RELPERSISTENCE_UNLOGGED:
unlogged_build_rel_info = InfoFromSMgrRel(reln);
unlogged_build_phase = UNLOGGED_BUILD_NOT_PERMANENT;
#ifdef DEBUG_COMPARE_LOCAL
if (!IsParallelWorker())
mdcreate(reln, INIT_FORKNUM, true);
#endif
if (debug_compare_local)
{
if (!IsParallelWorker())
mdcreate(reln, INIT_FORKNUM, true);
}
return;
default:
@@ -2009,11 +2022,7 @@ neon_start_unlogged_build(SMgrRelation reln)
*/
if (!IsParallelWorker())
{
#ifndef DEBUG_COMPARE_LOCAL
mdcreate(reln, MAIN_FORKNUM, false);
#else
mdcreate(reln, INIT_FORKNUM, true);
#endif
mdcreate(reln, debug_compare_local ? INIT_FORKNUM : MAIN_FORKNUM, false);
}
}
@@ -2107,14 +2116,14 @@ neon_end_unlogged_build(SMgrRelation reln)
lfc_invalidate(InfoFromNInfoB(rinfob), forknum, nblocks);
mdclose(reln, forknum);
#ifndef DEBUG_COMPARE_LOCAL
/* use isRedo == true, so that we drop it immediately */
mdunlink(rinfob, forknum, true);
#endif
if (!debug_compare_local)
{
/* use isRedo == true, so that we drop it immediately */
mdunlink(rinfob, forknum, true);
}
}
#ifdef DEBUG_COMPARE_LOCAL
mdunlink(rinfob, INIT_FORKNUM, true);
#endif
if (debug_compare_local)
mdunlink(rinfob, INIT_FORKNUM, true);
}
NRelFileInfoInvalidate(unlogged_build_rel_info);
unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;

View File

@@ -5,8 +5,9 @@ edition = "2024"
license.workspace = true
[features]
default = []
default = ["rest_broker"]
testing = ["dep:tokio-postgres"]
rest_broker = ["subzero-core", "jsonpath_lib", "ouroboros"]
[dependencies]
ahash.workspace = true
@@ -103,6 +104,9 @@ uuid.workspace = true
x509-cert.workspace = true
redis.workspace = true
zerocopy.workspace = true
subzero-core = { git = "https://github.com/neondatabase-labs/subzero", rev = "0b3d3278f5f9ac9311a7280cb1676de80e021f06", features = ["postgresql"], optional = true }
jsonpath_lib = { version = "0.3.0", optional = true }
ouroboros = { version = "0.18", optional = true }
# jwt stuff
jose-jwa = "0.1.2"

View File

@@ -138,3 +138,69 @@ Now from client you can start a new session:
```sh
PGSSLROOTCERT=./server.crt psql "postgresql://proxy:password@endpoint.local.neon.build:4432/postgres?sslmode=verify-full"
```
## auth broker setup:
Create a postgres instance:
```sh
docker run \
--detach \
--name proxy-postgres \
--env POSTGRES_HOST_AUTH_METHOD=trust \
--env POSTGRES_USER=authenticated \
--env POSTGRES_DB=database \
--publish 5432:5432 \
postgres:17-bookworm
```
Create a configuration file called `local_proxy.json` in the root of the repo (used also by the auth broker to validate JWTs)
```sh
{
"jwks": [
{
"id": "1",
"role_names": ["authenticator", "authenticated", "anon"],
"jwks_url": "https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json",
"provider_name": "foo",
"jwt_audience": null
}
]
}
```
Start the local proxy:
```sh
cargo run --bin local_proxy --features testing -- \
--disable-pg-session-jwt \
--http 0.0.0.0:7432
```
Start the auth broker:
```sh
LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \
-c server.crt -k server.key \
--is-auth-broker true \
--is-rest-broker true \
--wss 0.0.0.0:8080 \
--http 0.0.0.0:7002 \
--auth-backend local
```
Create a JWT in your auth provider (e.g. Clerk) and set it in the `NEON_JWT` environment variable.
```sh
export NEON_JWT="..."
```
Run a query against the auth broker:
```sh
curl -k "https://foo.local.neon.build:8080/sql" \
-H "Authorization: Bearer $NEON_JWT" \
-H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \
-d '{"query":"select 1","params":[]}'
```
Make a rest request against the auth broker (rest broker):
```sh
curl -k "https://foo.local.neon.build:8080/database/rest/v1/items?select=id,name&id=eq.1" \
-H "Authorization: Bearer $NEON_JWT"
```

View File

@@ -164,21 +164,20 @@ async fn authenticate(
})?
.map_err(ConsoleRedirectError::from)?;
if auth_config.ip_allowlist_check_enabled {
if let Some(allowed_ips) = &db_info.allowed_ips {
if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
if auth_config.ip_allowlist_check_enabled
&& let Some(allowed_ips) = &db_info.allowed_ips
&& !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
{
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
// Check if the access over the public internet is allowed, otherwise block. Note that
// the console redirect is not behind the VPC service endpoint, so we don't need to check
// the VPC endpoint ID.
if let Some(public_access_allowed) = db_info.public_access_allowed {
if !public_access_allowed {
return Err(auth::AuthError::NetworkNotAllowed);
}
if let Some(public_access_allowed) = db_info.public_access_allowed
&& !public_access_allowed
{
return Err(auth::AuthError::NetworkNotAllowed);
}
client.write_message(BeMessage::NoticeResponse("Connecting to database."));

View File

@@ -399,36 +399,36 @@ impl JwkCacheEntryLock {
tracing::debug!(?payload, "JWT signature valid with claims");
if let Some(aud) = expected_audience {
if payload.audience.0.iter().all(|s| s != aud) {
return Err(JwtError::InvalidClaims(
JwtClaimsError::InvalidJwtTokenAudience,
));
}
if let Some(aud) = expected_audience
&& payload.audience.0.iter().all(|s| s != aud)
{
return Err(JwtError::InvalidClaims(
JwtClaimsError::InvalidJwtTokenAudience,
));
}
let now = SystemTime::now();
if let Some(exp) = payload.expiration {
if now >= exp + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
}
if let Some(exp) = payload.expiration
&& now >= exp + CLOCK_SKEW_LEEWAY
{
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
}
if let Some(nbf) = payload.not_before {
if nbf >= now + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
));
}
if let Some(nbf) = payload.not_before
&& nbf >= now + CLOCK_SKEW_LEEWAY
{
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
));
}
Ok(ComputeCredentialKeys::JwtPayload(payloadb))

View File

@@ -171,7 +171,6 @@ impl ComputeUserInfo {
pub(crate) enum ComputeCredentialKeys {
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
}
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
@@ -346,15 +345,13 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
Err(e) => {
// The password could have been changed, so we invalidate the cache.
// We should only invalidate the cache if the TTL might have expired.
if e.is_password_failed() {
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(api) = &*api {
if let Some(ep) = &user_info.endpoint_id {
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
}
if e.is_password_failed()
&& let ControlPlaneClient::ProxyV1(api) = &*api
&& let Some(ep) = &user_info.endpoint_id
{
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
Err(e)

View File

@@ -1,43 +1,40 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail, ensure};
use anyhow::bail;
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use camino::Utf8PathBuf;
use clap::Parser;
use compute_api::spec::LocalProxySpec;
use futures::future::Either;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use tracing::{debug, error, info};
use utils::sentry_init::init_sentry;
use utils::{pid_file, project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
use crate::auth::backend::local::LocalBackend;
use crate::auth::{self};
use crate::cancellation::CancellationHandler;
#[cfg(feature = "rest_broker")]
use crate::config::RestConfig;
use crate::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
refresh_config_loop,
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::{self, GlobalConnPoolOptions};
use crate::tls::client_config::compute_client_config_with_root_certs;
use crate::types::RoleName;
use crate::url::ApiUrl;
project_git_version!(GIT_VERSION);
@@ -82,6 +79,11 @@ struct LocalProxyCliArgs {
/// Path of the local proxy PID file
#[clap(long, default_value = "./local_proxy.pid")]
pid_path: Utf8PathBuf,
/// Disable pg_session_jwt extension installation
/// This is useful for testing the local proxy with vanilla postgres.
#[clap(long, default_value = "false")]
#[cfg(feature = "testing")]
disable_pg_session_jwt: bool,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -277,11 +279,18 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
accept_jwts: true,
console_redirect_confirmation_timeout: Duration::ZERO,
},
#[cfg(feature = "rest_broker")]
rest_config: RestConfig {
is_rest_broker: false,
db_schema_cache: None,
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: args.disable_pg_session_jwt,
})))
}
@@ -293,132 +302,3 @@ fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'stati
Box::leak(Box::new(auth_backend))
}
#[derive(Error, Debug)]
enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
Ok(())
}

View File

@@ -4,6 +4,7 @@
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside. Similar to an ingress controller for HTTPS.
use std::io;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
@@ -229,7 +230,6 @@ pub(super) async fn task_main(
.set_nodelay(true)
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestContext::new(
session_id,
ConnectionInfo {
@@ -241,6 +241,14 @@ pub(super) async fn task_main(
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}
.unwrap_or_else(|e| {
if let Some(FirstMessage(io_error)) = e.downcast_ref() {
// this is noisy. if we get EOF on the very first message that's likely
// just NLB doing a healthcheck.
if io_error.kind() == io::ErrorKind::UnexpectedEof {
return;
}
}
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
@@ -257,12 +265,19 @@ pub(super) async fn task_main(
Ok(())
}
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
struct FirstMessage(io::Error);
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<TlsStream<S>> {
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream))
.await
.map_err(FirstMessage)?;
match msg {
FeStartupPacket::SslRequest { direct: None } => {
let raw = stream.accept_tls().await?;

View File

@@ -10,11 +10,15 @@ use std::time::Duration;
use anyhow::Context;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
#[cfg(any(test, feature = "testing"))]
use camino::Utf8PathBuf;
use futures::future::Either;
use itertools::{Itertools, Position};
use rand::{Rng, thread_rng};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
#[cfg(any(test, feature = "testing"))]
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
@@ -22,9 +26,15 @@ use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
#[cfg(any(test, feature = "testing"))]
use crate::auth::backend::local::LocalBackend;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::batch::BatchQueue;
use crate::cancellation::{CancellationHandler, CancellationProcessor};
#[cfg(feature = "rest_broker")]
use crate::config::RestConfig;
#[cfg(any(test, feature = "testing"))]
use crate::config::refresh_config_loop;
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
@@ -39,6 +49,8 @@ use crate::redis::{elasticache, notifications};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
use crate::serverless::rest::DbSchemaCache;
use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
@@ -60,6 +72,9 @@ enum AuthBackendType {
#[cfg(any(test, feature = "testing"))]
Postgres,
#[cfg(any(test, feature = "testing"))]
Local,
}
/// Neon proxy/router
@@ -74,6 +89,10 @@ struct ProxyCliArgs {
proxy: SocketAddr,
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
auth_backend: AuthBackendType,
/// Path of the local proxy config file (used for local-file auth backend)
#[clap(long, default_value = "./local_proxy.json")]
#[cfg(any(test, feature = "testing"))]
config_path: Utf8PathBuf,
/// listen for management callback connection on ip:port
#[clap(short, long, default_value = "127.0.0.1:7000")]
mgmt: SocketAddr,
@@ -226,6 +245,16 @@ struct ProxyCliArgs {
#[clap(flatten)]
pg_sni_router: PgSniRouterArgs,
/// if this is not local proxy, this toggles whether we accept Postgres REST requests
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
#[cfg(feature = "rest_broker")]
is_rest_broker: bool,
/// cache for `db_schema_cache` introspection (use `size=0` to disable)
#[clap(long, default_value = "size=1000,ttl=1h")]
#[cfg(feature = "rest_broker")]
db_schema_cache: String,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -386,6 +415,8 @@ pub async fn run() -> anyhow::Result<()> {
64,
));
#[cfg(any(test, feature = "testing"))]
let refresh_config_notify = Arc::new(Notify::new());
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -412,6 +443,17 @@ pub async fn run() -> anyhow::Result<()> {
endpoint_rate_limiter.clone(),
));
}
// if auth backend is local, we need to load the config file
#[cfg(any(test, feature = "testing"))]
if let auth::Backend::Local(_) = &auth_backend {
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(
config,
args.config_path,
refresh_config_notify.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
@@ -462,7 +504,13 @@ pub async fn run() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), {
move || {
#[cfg(any(test, feature = "testing"))]
refresh_config_notify.notify_one();
}
}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -473,59 +521,67 @@ pub async fn run() -> anyhow::Result<()> {
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
// add a task to flush the db_schema cache every 10 minutes
#[cfg(feature = "rest_broker")]
if let Some(db_schema_cache) = &config.rest_config.db_schema_cache {
maintenance_tasks.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(600)).await;
db_schema_cache.flush();
}
});
}
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardles of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
if let Some(client) = redis_client {
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend
&& let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api
&& let Some(client) = redis_client
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
}
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }.instrument(span),
);
}
let maintenance = loop {
@@ -643,6 +699,28 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
timeout: Duration::from_secs(2),
};
#[cfg(feature = "rest_broker")]
let rest_config = {
let db_schema_cache_config: CacheOptions = args.db_schema_cache.parse()?;
info!("Using DbSchemaCache with options={db_schema_cache_config:?}");
let db_schema_cache = if args.is_rest_broker {
Some(DbSchemaCache::new(
"db_schema_cache",
db_schema_cache_config.size,
db_schema_cache_config.ttl,
true,
))
} else {
None
};
RestConfig {
is_rest_broker: args.is_rest_broker,
db_schema_cache,
}
};
let config = ProxyConfig {
tls_config,
metric_collection,
@@ -653,6 +731,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: false,
#[cfg(feature = "rest_broker")]
rest_config,
};
let config = Box::leak(Box::new(config));
@@ -806,6 +888,19 @@ fn build_auth_backend(
Ok(Either::Right(config))
}
#[cfg(any(test, feature = "testing"))]
AuthBackendType::Local => {
let postgres: SocketAddr = "127.0.0.1:7432".parse()?;
let compute_ctl: ApiUrl = "http://127.0.0.1:3081/".parse()?;
let auth_backend = crate::auth::Backend::Local(
crate::auth::backend::MaybeOwned::Owned(LocalBackend::new(postgres, compute_ctl)),
);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
}
}

View File

@@ -204,6 +204,10 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
self.insert_raw_ttl(key, value, ttl, false);
}
pub(crate) fn insert(&self, key: K, value: V) {
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval);
}
pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
let (_, old) = self.insert_raw(key.clone(), value);
@@ -214,6 +218,28 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
(old, cached)
}
pub(crate) fn flush(&self) {
let now = Instant::now();
let mut cache = self.cache.lock();
// Collect keys of expired entries first
let expired_keys: Vec<_> = cache
.iter()
.filter_map(|(key, entry)| {
if entry.expires_at <= now {
Some(key.clone())
} else {
None
}
})
.collect();
// Remove expired entries
for key in expired_keys {
cache.remove(&key);
}
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {

View File

@@ -64,6 +64,13 @@ impl Pipeline {
let responses = self.replies;
let batch_size = self.inner.len();
if !client.credentials_refreshed() {
tracing::debug!(
"Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
);
tokio::time::sleep(Duration::from_secs(5)).await;
}
match client.query(&self.inner).await {
// for each reply, we expect that many values.
Ok(Value::Array(values)) if values.len() == responses => {
@@ -127,6 +134,14 @@ impl QueueProcessing for CancellationProcessor {
}
async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
if !self.client.credentials_refreshed() {
// this will cause a timeout for cancellation operations
tracing::debug!(
"Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
);
tokio::time::sleep(Duration::from_secs(5)).await;
}
let mut pipeline = Pipeline::with_capacity(batch.len());
let batch_size = batch.len();

View File

@@ -165,7 +165,7 @@ impl AuthInfo {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
ComputeCredentialKeys::JwtPayload(_) => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,

View File

@@ -4,28 +4,43 @@ use std::time::Duration;
use anyhow::{Context, Ok, bail, ensure};
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use clap::ValueEnum;
use compute_api::spec::LocalProxySpec;
use remote_storage::RemoteStorageConfig;
use thiserror::Error;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::JWKS_ROLE_MAP;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
use crate::serverless::rest::DbSchemaCache;
pub use crate::tls::server_config::{TlsConfig, configure_tls};
use crate::types::Host;
use crate::types::{Host, RoleName};
pub struct ProxyConfig {
pub tls_config: ArcSwapOption<TlsConfig>,
pub metric_collection: Option<MetricCollectionConfig>,
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
#[cfg(feature = "rest_broker")]
pub rest_config: RestConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute: ComputeConfig,
#[cfg(feature = "testing")]
pub disable_pg_session_jwt: bool,
}
pub struct ComputeConfig {
@@ -69,6 +84,12 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
#[cfg(feature = "rest_broker")]
pub struct RestConfig {
pub is_rest_broker: bool,
pub db_schema_cache: Option<DbSchemaCache>,
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
@@ -409,6 +430,135 @@ impl FromStr for ConcurrencyLockOptions {
}
}
#[derive(Error, Debug)]
pub(crate) enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
pub(crate) async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
std::result::Result::Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
pub(crate) async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
std::result::Result::Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -209,11 +209,9 @@ impl RequestContext {
if let Some(options_str) = options.get("options") {
// If not found directly, try to extract it from the options string
for option in options_str.split_whitespace() {
if option.starts_with("neon_query_id:") {
if let Some(value) = option.strip_prefix("neon_query_id:") {
this.set_testodrome_id(value.into());
break;
}
if let Some(value) = option.strip_prefix("neon_query_id:") {
this.set_testodrome_id(value.into());
break;
}
}
}

View File

@@ -250,10 +250,8 @@ impl NeonControlPlaneClient {
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
let Some((host, port)) = parse_host_port(&body.address) else {
return Err(WakeComputeError::BadComputeAddress(body.address));
};
let host_addr = IpAddr::from_str(host).ok();

View File

@@ -8,7 +8,7 @@ use std::time::Duration;
use clashmap::ClashMap;
use tokio::time::Instant;
use tracing::{debug, info};
use tracing::debug;
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
@@ -213,7 +213,12 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
self.metrics
.semaphore_acquire_seconds
.observe(now.elapsed().as_secs_f64());
debug!("acquired permit {:?}", now.elapsed().as_secs_f64());
if permit.is_ok() {
debug!(elapsed = ?now.elapsed(), "acquired permit");
} else {
debug!(elapsed = ?now.elapsed(), "timed out acquiring permit");
}
Ok(WakeComputePermit { permit: permit? })
}
@@ -229,7 +234,8 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
// therefore releasing it is safe from race conditions
info!(
debug!(
//FIXME: is anything depending on this being info?
name = self.name,
shard = i,
"performing epoch reclamation on api lock"

View File

@@ -52,7 +52,7 @@ pub async fn init() -> anyhow::Result<LoggingGuard> {
StderrWriter {
stderr: std::io::stderr(),
},
&["request_id", "session_id", "conn_id"],
&["conn_id", "ep", "query_id", "request_id", "session_id"],
))
} else {
None
@@ -271,18 +271,18 @@ where
});
// In case logging fails we generate a simpler JSON object.
if let Err(err) = res {
if let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
if let Err(err) = res
&& let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
"timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
"level": "ERROR",
"message": format_args!("cannot log event: {err:?}"),
"fields": {
"event": format_args!("{event:?}"),
},
})) {
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
}
}))
{
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
}
}
@@ -583,10 +583,11 @@ impl EventFormatter {
THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
// TODO: tls cache? name could change
if let Some(thread_name) = std::thread::current().name() {
if !thread_name.is_empty() && thread_name != "tokio-runtime-worker" {
serializer.serialize_entry("thread_name", thread_name)?;
}
if let Some(thread_name) = std::thread::current().name()
&& !thread_name.is_empty()
&& thread_name != "tokio-runtime-worker"
{
serializer.serialize_entry("thread_name", thread_name)?;
}
if let Some(task_id) = tokio::task::try_id() {
@@ -596,10 +597,10 @@ impl EventFormatter {
serializer.serialize_entry("target", meta.target())?;
// Skip adding module if it's the same as target.
if let Some(module) = meta.module_path() {
if module != meta.target() {
serializer.serialize_entry("module", module)?;
}
if let Some(module) = meta.module_path()
&& module != meta.target()
{
serializer.serialize_entry("module", module)?;
}
if let Some(file) = meta.file() {

View File

@@ -236,13 +236,6 @@ pub enum Bool {
False,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum Outcome {
Success,
Failed,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum CacheOutcome {

View File

@@ -90,27 +90,27 @@ where
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
if let TransferState::Done(_) = compute_to_client
&& let TransferState::Running(buf) = &client_to_compute
{
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
if let TransferState::Done(_) = client_to_compute
&& let TransferState::Running(buf) = &compute_to_client
{
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
// It is not a problem if ready! returns early ... (comment remains the same)

View File

@@ -39,7 +39,11 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
let config = config.map_or(self.default_config, Into::into);
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
if self
.access_count
.fetch_add(1, Ordering::AcqRel)
.is_multiple_of(2048)
{
self.do_gc(now);
}

View File

@@ -211,7 +211,11 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
// worst case memory usage is about:
// = 2 * 2048 * 64 * (48B + 72B)
// = 30MB
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
if self
.access_count
.fetch_add(1, Ordering::AcqRel)
.is_multiple_of(2048)
{
self.do_gc();
}

View File

@@ -1,79 +0,0 @@
use core::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::pqproto::CancelKeyData;
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
pub trait CancellationPublisher: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
impl CancellationPublisher for () {
async fn try_publish(
&self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
_peer_addr: IpAddr,
) -> anyhow::Result<()> {
Ok(())
}
}
impl<P: CancellationPublisher> CancellationPublisherMut for P {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id, peer_addr)
.await
}
}
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
if let Some(p) = self {
p.try_publish(cancel_key_data, session_id, peer_addr).await
} else {
Ok(())
}
}
}
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
self.lock()
.await
.try_publish(cancel_key_data, session_id, peer_addr)
.await
}
}

View File

@@ -1,11 +1,12 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use futures::FutureExt;
use redis::aio::{ConnectionLike, MultiplexedConnection};
use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};
use tokio::task::AbortHandle;
use tracing::{error, info, warn};
use super::elasticache::CredentialsProvider;
@@ -31,8 +32,9 @@ pub struct ConnectionWithCredentialsProvider {
credentials: Credentials,
// TODO: with more load on the connection, we should consider using a connection pool
con: Option<MultiplexedConnection>,
refresh_token_task: Option<JoinHandle<()>>,
refresh_token_task: Option<AbortHandle>,
mutex: tokio::sync::Mutex<()>,
credentials_refreshed: Arc<AtomicBool>,
}
impl Clone for ConnectionWithCredentialsProvider {
@@ -42,6 +44,7 @@ impl Clone for ConnectionWithCredentialsProvider {
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
credentials_refreshed: Arc::new(AtomicBool::new(false)),
}
}
}
@@ -65,6 +68,7 @@ impl ConnectionWithCredentialsProvider {
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
credentials_refreshed: Arc::new(AtomicBool::new(false)),
}
}
@@ -78,6 +82,7 @@ impl ConnectionWithCredentialsProvider {
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
credentials_refreshed: Arc::new(AtomicBool::new(true)),
}
}
@@ -85,6 +90,10 @@ impl ConnectionWithCredentialsProvider {
redis::cmd("PING").query_async(con).await
}
pub(crate) fn credentials_refreshed(&self) -> bool {
self.credentials_refreshed.load(Ordering::Relaxed)
}
pub(crate) async fn connect(&mut self) -> anyhow::Result<()> {
let _guard = self.mutex.lock().await;
if let Some(con) = self.con.as_mut() {
@@ -112,13 +121,13 @@ impl ConnectionWithCredentialsProvider {
if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
let credentials_provider = credentials_provider.clone();
let con2 = con.clone();
let f = tokio::spawn(async move {
Self::keep_connection(con2, credentials_provider)
.await
.inspect_err(|e| debug!("keep_connection failed: {e}"))
.ok();
});
self.refresh_token_task = Some(f);
let credentials_refreshed = self.credentials_refreshed.clone();
let f = tokio::spawn(Self::keep_connection(
con2,
credentials_provider,
credentials_refreshed,
));
self.refresh_token_task = Some(f.abort_handle());
}
match Self::ping(&mut con).await {
Ok(()) => {
@@ -153,6 +162,7 @@ impl ConnectionWithCredentialsProvider {
async fn get_client(&self) -> anyhow::Result<redis::Client> {
let client = redis::Client::open(self.get_connection_info().await?)?;
self.credentials_refreshed.store(true, Ordering::Relaxed);
Ok(client)
}
@@ -168,16 +178,19 @@ impl ConnectionWithCredentialsProvider {
async fn keep_connection(
mut con: MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
) -> anyhow::Result<()> {
credentials_refreshed: Arc<AtomicBool>,
) -> ! {
loop {
// The connection lives for 12h, for the sanity check we refresh it every hour.
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
match Self::refresh_token(&mut con, credentials_provider.clone()).await {
Ok(()) => {
info!("Token refreshed");
credentials_refreshed.store(true, Ordering::Relaxed);
}
Err(e) => {
error!("Error during token refresh: {e:?}");
credentials_refreshed.store(false, Ordering::Relaxed);
}
}
}
@@ -231,7 +244,7 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
&'a mut self,
cmd: &'a redis::Cmd,
) -> redis::RedisFuture<'a, redis::Value> {
(async move { self.send_packed_command(cmd).await }).boxed()
self.send_packed_command(cmd).boxed()
}
fn req_packed_commands<'a>(
@@ -240,10 +253,10 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
offset: usize,
count: usize,
) -> redis::RedisFuture<'a, Vec<redis::Value>> {
(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
self.send_packed_commands(cmd, offset, count).boxed()
}
fn get_db(&self) -> i64 {
0
self.con.as_ref().map_or(0, |c| c.get_db())
}
}

View File

@@ -40,6 +40,10 @@ impl RedisKVClient {
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
}
pub(crate) fn credentials_refreshed(&self) -> bool {
self.client.credentials_refreshed()
}
pub(crate) async fn query<T: FromRedisValue>(
&mut self,
q: &impl Queryable,
@@ -49,7 +53,7 @@ impl RedisKVClient {
Err(e) => e,
};
tracing::error!("failed to run query: {e}");
tracing::debug!("failed to run query: {e}");
match e.retry_method() {
redis::RetryMethod::Reconnect => {
tracing::info!("Redis client is disconnected. Reconnecting...");

View File

@@ -1,4 +1,3 @@
pub mod cancellation_publisher;
pub mod connection_with_credentials_provider;
pub mod elasticache;
pub mod keys;

View File

@@ -54,9 +54,7 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Self::Required(mode) => {
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
let mut cbind_input = format!("p={mode},,",).into_bytes();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
BASE64_STANDARD.encode(&cbind_input).into()
}

View File

@@ -107,7 +107,7 @@ pub(crate) async fn exchange(
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&secret.salt_base64)?;
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {

View File

@@ -87,13 +87,20 @@ impl<'a> ClientFirstMessage<'a> {
salt_base64: &str,
iterations: u32,
) -> OwnedServerFirstMessage {
use std::fmt::Write;
let mut message = String::with_capacity(128);
message.push_str("r=");
let mut message = String::new();
write!(&mut message, "r={}", self.nonce).unwrap();
// write combined nonce
let combined_nonce_start = message.len();
message.push_str(self.nonce);
BASE64_STANDARD.encode_string(nonce, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
let combined_nonce = combined_nonce_start..message.len();
// write salt and iterations
message.push_str(",s=");
message.push_str(salt_base64);
message.push_str(",i=");
message.push_str(itoa::Buffer::new().format(iterations));
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message

View File

@@ -14,7 +14,7 @@ pub(crate) struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
pub(crate) salt_base64: String,
pub(crate) salt_base64: Box<str>,
/// Hashed `ClientKey`.
pub(crate) stored_key: ScramKey,
/// Used by client to verify server's signature.
@@ -35,7 +35,7 @@ impl ServerSecret {
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
salt_base64: salt.to_owned(),
salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
@@ -58,7 +58,7 @@ impl ServerSecret {
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.
iterations: 1,
salt_base64: BASE64_STANDARD.encode(nonce),
salt_base64: BASE64_STANDARD.encode(nonce).into_boxed_str(),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
@@ -88,7 +88,7 @@ mod tests {
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(&*parsed.salt_base64, salt);
assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key);
assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key);

View File

@@ -137,7 +137,7 @@ impl Future for JobSpec {
let state = state.as_mut().expect("should be set on thread startup");
state.tick = state.tick.wrapping_add(1);
if state.tick % SKETCH_RESET_INTERVAL == 0 {
if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
state.countmin.reset();
}

View File

@@ -115,7 +115,8 @@ impl PoolingBackend {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
self.config
let keys = self
.config
.authentication_config
.jwks_cache
.check_jwt(
@@ -129,7 +130,7 @@ impl PoolingBackend {
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
keys,
})
}
crate::auth::Backend::Local(_) => {
@@ -256,6 +257,7 @@ impl PoolingBackend {
&self,
ctx: &RequestContext,
conn_info: ConnInfo,
disable_pg_session_jwt: bool,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
@@ -277,7 +279,7 @@ impl PoolingBackend {
.expect("semaphore should never be closed");
// check again for race
if !self.local_pool.initialized(&conn_info) {
if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
local_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
@@ -313,14 +315,16 @@ impl PoolingBackend {
.to_postgres_client_config();
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.set_param(
.dbname(&conn_info.dbname);
if !disable_pg_session_jwt {
config.set_param(
"options",
&format!(
"-c pg_session_jwt.jwk={}",
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
),
);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
@@ -345,7 +349,9 @@ impl PoolingBackend {
debug!("setting up backend session state");
// initiates the auth session
if let Err(e) = client.batch_execute("select auth.init();").await {
if !disable_pg_session_jwt
&& let Err(e) = client.batch_execute("select auth.init();").await
{
discard.discard();
return Err(e.into());
}

View File

@@ -148,11 +148,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_client(db_user.clone(), conn_id) {
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;

View File

@@ -1,5 +1,93 @@
use http::StatusCode;
use http::header::HeaderName;
use crate::auth::ComputeUserInfoParseError;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::ReadBodyError;
pub trait HttpCodeError {
fn get_http_status_code(&self) -> StatusCode;
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
impl HttpCodeError for ReadPayloadError {
fn get_http_status_code(&self) -> StatusCode {
match self {
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
}
}
}

View File

@@ -2,6 +2,8 @@ use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
@@ -21,8 +23,9 @@ use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type Connect =
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
#[derive(Clone)]
pub(crate) struct ClientDataHttp();
@@ -237,10 +240,10 @@ pub(crate) fn poll_http2_client(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_conn(conn_id)
{
info!("closed connection removed");
}
}
.instrument(span),

View File

@@ -3,11 +3,42 @@
use anyhow::Context;
use bytes::Bytes;
use http::{Response, StatusCode};
use http::header::AUTHORIZATION;
use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use serde::Serialize;
use url::Url;
use uuid::Uuid;
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::ConnInfo;
use super::error::{ConnInfoError, Credentials};
use crate::auth::backend::ComputeUserInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::types::{DbName, EndpointId, RoleName};
// Common header names used across serverless modules
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
HeaderName::from_static("neon-batch-isolation-level");
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
.expect("uuid hyphenated format should be all valid header characters")
}
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
@@ -107,3 +138,136 @@ pub(crate) fn json_response<T: Serialize>(
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
pub(crate) fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
connection_string: Option<&str>,
headers: &HeaderMap,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
let connection_url = match connection_string {
Some(connection_string) => Url::parse(connection_string)?,
None => {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
Url::parse(connection_string)?
}
};
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
// TODO: make sure this is right in the context of rest broker
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint: EndpointId = match connection_url.host() {
Some(url::Host::Domain(hostname)) => hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into(),
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}

View File

@@ -70,6 +70,34 @@ pub(crate) enum JsonConversionError {
ParseJsonError(#[from] serde_json::Error),
#[error("unbalanced array")]
UnbalancedArray,
#[error("unbalanced quoted string")]
UnbalancedString,
}
enum OutputMode {
Array(Vec<Value>),
Object(Map<String, Value>),
}
impl OutputMode {
fn key(&mut self, key: &str) -> &mut Value {
match self {
OutputMode::Array(values) => push_entry(values, Value::Null),
OutputMode::Object(map) => map.entry(key.to_string()).or_insert(Value::Null),
}
}
fn finish(self) -> Value {
match self {
OutputMode::Array(values) => Value::Array(values),
OutputMode::Object(map) => Value::Object(map),
}
}
}
fn push_entry<T>(arr: &mut Vec<T>, t: T) -> &mut T {
arr.push(t);
arr.last_mut().expect("a value was just inserted")
}
//
@@ -77,182 +105,277 @@ pub(crate) enum JsonConversionError {
//
pub(crate) fn pg_text_row_to_json(
row: &Row,
columns: &[Type],
raw_output: bool,
array_mode: bool,
) -> Result<Value, JsonConversionError> {
let iter = row
.columns()
.iter()
.zip(columns)
.enumerate()
.map(|(i, (column, typ))| {
let name = column.name();
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
let json_value = if raw_output {
match pg_value {
Some(v) => Value::String(v.to_string()),
None => Value::Null,
}
} else {
pg_text_to_json(pg_value, typ)?
};
Ok((name.to_string(), json_value))
});
if array_mode {
// drop keys and aggregate into array
let arr = iter
.map(|r| r.map(|(_key, val)| val))
.collect::<Result<Vec<Value>, JsonConversionError>>()?;
Ok(Value::Array(arr))
let mut entries = if array_mode {
OutputMode::Array(Vec::with_capacity(row.columns().len()))
} else {
let obj = iter.collect::<Result<Map<String, Value>, JsonConversionError>>()?;
Ok(Value::Object(obj))
OutputMode::Object(Map::with_capacity(row.columns().len()))
};
for (i, column) in row.columns().iter().enumerate() {
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
let value = entries.key(column.name());
match pg_value {
Some(v) if raw_output => *value = Value::String(v.to_string()),
Some(v) => pg_text_to_json(value, v, column.type_())?,
None => *value = Value::Null,
}
}
Ok(entries.finish())
}
//
// Convert postgres text-encoded value to JSON value
//
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, JsonConversionError> {
if let Some(val) = pg_value {
if let Kind::Array(elem_type) = pg_type.kind() {
return pg_array_parse(val, elem_type);
}
fn pg_text_to_json(
output: &mut Value,
val: &str,
pg_type: &Type,
) -> Result<(), JsonConversionError> {
if let Kind::Array(elem_type) = pg_type.kind() {
// todo: we should fetch this from postgres.
let delimiter = ',';
match *pg_type {
Type::BOOL => Ok(Value::Bool(val == "t")),
Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?;
Ok(Value::Number(serde_json::Number::from(val)))
}
Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
Ok(Value::Number(num))
} else {
// Pass Nan, Inf, -Inf as strings
// JS JSON.stringify() does converts them to null, but we
// want to preserve them, so we pass them as strings
Ok(Value::String(val.to_string()))
}
}
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
_ => Ok(Value::String(val.to_string())),
}
} else {
Ok(Value::Null)
}
}
//
// Parse postgres array into JSON array.
//
// This is a bit involved because we need to handle nested arrays and quoted
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, JsonConversionError> {
pg_array_parse_inner(pg_array, elem_type, false).map(|(v, _)| v)
}
fn pg_array_parse_inner(
pg_array: &str,
elem_type: &Type,
nested: bool,
) -> Result<(Value, usize), JsonConversionError> {
let mut pg_array_chr = pg_array.char_indices();
let mut level = 0;
let mut quote = false;
let mut entries: Vec<Value> = Vec::new();
let mut entry = String::new();
// skip bounds decoration
if let Some('[') = pg_array.chars().next() {
for (_, c) in pg_array_chr.by_ref() {
if c == '=' {
break;
}
}
let mut array = vec![];
pg_array_parse(&mut array, val, elem_type, delimiter)?;
*output = Value::Array(array);
return Ok(());
}
fn push_checked(
entry: &mut String,
entries: &mut Vec<Value>,
elem_type: &Type,
) -> Result<(), JsonConversionError> {
if !entry.is_empty() {
// While in usual postgres response we get nulls as None and everything else
// as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
// string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
// here while we have quotation info and convert them to None.
if entry == "NULL" {
entries.push(pg_text_to_json(None, elem_type)?);
match *pg_type {
Type::BOOL => *output = Value::Bool(val == "t"),
Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?;
*output = Value::Number(serde_json::Number::from(val));
}
Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
*output = Value::Number(num);
} else {
entries.push(pg_text_to_json(Some(entry), elem_type)?);
// Pass Nan, Inf, -Inf as strings
// JS JSON.stringify() does converts them to null, but we
// want to preserve them, so we pass them as strings
*output = Value::String(val.to_string());
}
entry.clear();
}
Ok(())
Type::JSON | Type::JSONB => *output = serde_json::from_str(val)?,
_ => *output = Value::String(val.to_string()),
}
while let Some((mut i, mut c)) = pg_array_chr.next() {
let mut escaped = false;
Ok(())
}
if c == '\\' {
escaped = true;
let Some(x) = pg_array_chr.next() else {
return Err(JsonConversionError::UnbalancedArray);
};
(i, c) = x;
}
match c {
'{' if !quote => {
level += 1;
if level > 1 {
let (res, off) = pg_array_parse_inner(&pg_array[i..], elem_type, true)?;
entries.push(res);
for _ in 0..off - 1 {
pg_array_chr.next();
}
}
}
'}' if !quote => {
level -= 1;
if level == 0 {
push_checked(&mut entry, &mut entries, elem_type)?;
if nested {
return Ok((Value::Array(entries), i));
}
}
}
'"' if !escaped => {
if quote {
// end of quoted string, so push it manually without any checks
// for emptiness or nulls
entries.push(pg_text_to_json(Some(&entry), elem_type)?);
entry.clear();
}
quote = !quote;
}
',' if !quote => {
push_checked(&mut entry, &mut entries, elem_type)?;
}
_ => {
entry.push(c);
}
}
/// Parse postgres array into JSON array.
///
/// This is a bit involved because we need to handle nested arrays and quoted
/// values. Unlike postgres we don't check that all nested arrays have the same
/// dimensions, we just return them as is.
///
/// <https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO>
///
/// The external text representation of an array value consists of items that are interpreted
/// according to the I/O conversion rules for the array's element type, plus decoration that
/// indicates the array structure. The decoration consists of curly braces (`{` and `}`) around
/// the array value plus delimiter characters between adjacent items. The delimiter character
/// is usually a comma (,) but can be something else: it is determined by the typdelim setting
/// for the array's element type. Among the standard data types provided in the PostgreSQL
/// distribution, all use a comma, except for type box, which uses a semicolon (;).
///
/// In a multidimensional array, each dimension (row, plane, cube, etc.)
/// gets its own level of curly braces, and delimiters must be written between adjacent
/// curly-braced entities of the same level.
fn pg_array_parse(
elements: &mut Vec<Value>,
mut pg_array: &str,
elem: &Type,
delim: char,
) -> Result<(), JsonConversionError> {
// skip bounds decoration, eg:
// `[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}`
// technically these are significant, but we have no way to represent them in json.
if let Some('[') = pg_array.chars().next() {
let Some((_bounds, array)) = pg_array.split_once('=') else {
return Err(JsonConversionError::UnbalancedArray);
};
pg_array = array;
}
if level != 0 {
// whitespace might preceed a `{`.
let pg_array = pg_array.trim_start();
let rest = pg_array_parse_inner(elements, pg_array, elem, delim)?;
if !rest.is_empty() {
return Err(JsonConversionError::UnbalancedArray);
}
Ok((Value::Array(entries), 0))
Ok(())
}
/// reads a single array from the `pg_array` string and pushes each values to `elements`.
/// returns the rest of the `pg_array` string that was not read.
fn pg_array_parse_inner<'a>(
elements: &mut Vec<Value>,
mut pg_array: &'a str,
elem: &Type,
delim: char,
) -> Result<&'a str, JsonConversionError> {
// array should have a `{` prefix.
pg_array = pg_array
.strip_prefix('{')
.ok_or(JsonConversionError::UnbalancedArray)?;
let mut q = String::new();
loop {
let value = push_entry(elements, Value::Null);
pg_array = pg_array_parse_item(value, &mut q, pg_array, elem, delim)?;
// check for separator.
if let Some(next) = pg_array.strip_prefix(delim) {
// next item.
pg_array = next;
} else {
break;
}
}
let Some(next) = pg_array.strip_prefix('}') else {
// missing `}` terminator.
return Err(JsonConversionError::UnbalancedArray);
};
// whitespace might follow a `}`.
Ok(next.trim_start())
}
/// reads a single item from the `pg_array` string.
/// returns the rest of the `pg_array` string that was not read.
///
/// `quoted` is a scratch allocation that has no defined output.
fn pg_array_parse_item<'a>(
output: &mut Value,
quoted: &mut String,
mut pg_array: &'a str,
elem: &Type,
delim: char,
) -> Result<&'a str, JsonConversionError> {
// We are trying to parse an array item.
// This could be a new array, if this is a multi-dimentional array.
// This could be a quoted string representing `elem`.
// This could be an unquoted string representing `elem`.
// whitespace might preceed an item.
pg_array = pg_array.trim_start();
if pg_array.starts_with('{') {
// nested array.
let mut nested = vec![];
pg_array = pg_array_parse_inner(&mut nested, pg_array, elem, delim)?;
*output = Value::Array(nested);
return Ok(pg_array);
}
if let Some(mut pg_array) = pg_array.strip_prefix('"') {
// the parsed string is un-escaped and written into quoted.
pg_array = pg_array_parse_quoted(quoted, pg_array)?;
// we have un-escaped the string, parse it as pgtext.
pg_text_to_json(output, quoted, elem)?;
return Ok(pg_array);
}
// we need to parse an item. read until we find a delimiter or `}`.
let index = pg_array
.find([delim, '}'])
.ok_or(JsonConversionError::UnbalancedArray)?;
let item;
(item, pg_array) = pg_array.split_at(index);
// item might have trailing whitespace that we need to ignore.
let item = item.trim_end();
// we might have an item string:
// check for null
if item == "NULL" {
*output = Value::Null;
} else {
pg_text_to_json(output, item, elem)?;
}
Ok(pg_array)
}
/// reads a single quoted item from the `pg_array` string.
///
/// Returns the rest of the `pg_array` string that was not read.
/// The output is written into `quoted`.
///
/// The pg_array string must have a `"` terminator, but the `"` initial value
/// must have already been removed from the input. The terminator is removed.
fn pg_array_parse_quoted<'a>(
quoted: &mut String,
mut pg_array: &'a str,
) -> Result<&'a str, JsonConversionError> {
// The array output routine will put double quotes around element values if they are empty strings,
// contain curly braces, delimiter characters, double quotes, backslashes, or white space,
// or match the word `NULL`. Double quotes and backslashes embedded in element values will be backslash-escaped.
// For numeric data types it is safe to assume that double quotes will never appear,
// but for textual data types one should be prepared to cope with either the presence or absence of quotes.
quoted.clear();
// We write to quoted in chunks terminated by an escape character.
// Eg if we have the input `foo\"bar"`, then we write `foo`, then `"`, then finally `bar`.
loop {
// we need to parse an chunk. read until we find a '\\' or `"`.
let i = pg_array
.find(['\\', '"'])
.ok_or(JsonConversionError::UnbalancedString)?;
let chunk: &str;
(chunk, pg_array) = pg_array
.split_at_checked(i)
.expect("i is guaranteed to be in-bounds of pg_array");
// push the chunk.
quoted.push_str(chunk);
// consume the chunk_end character.
let chunk_end: char;
(chunk_end, pg_array) =
split_first_char(pg_array).expect("pg_array should start with either '\\\\' or '\"'");
// finished.
if chunk_end == '"' {
// whitespace might follow the '"'.
pg_array = pg_array.trim_start();
break Ok(pg_array);
}
// consume the escaped character.
let escaped: char;
(escaped, pg_array) =
split_first_char(pg_array).ok_or(JsonConversionError::UnbalancedString)?;
quoted.push(escaped);
}
}
fn split_first_char(s: &str) -> Option<(char, &str)> {
let mut chars = s.chars();
let c = chars.next()?;
Some((c, chars.as_str()))
}
#[cfg(test)]
@@ -316,37 +439,33 @@ mod tests {
);
}
fn pg_text_to_json(val: &str, pg_type: &Type) -> Value {
let mut v = Value::Null;
super::pg_text_to_json(&mut v, val, pg_type).unwrap();
v
}
fn pg_array_parse(pg_array: &str, pg_type: &Type) -> Value {
let mut array = vec![];
super::pg_array_parse(&mut array, pg_array, pg_type, ',').unwrap();
Value::Array(array)
}
#[test]
fn test_atomic_types_parse() {
assert_eq!(pg_text_to_json("foo", &Type::TEXT), json!("foo"));
assert_eq!(pg_text_to_json("42", &Type::INT4), json!(42));
assert_eq!(pg_text_to_json("42", &Type::INT2), json!(42));
assert_eq!(pg_text_to_json("42", &Type::INT8), json!("42"));
assert_eq!(pg_text_to_json("42.42", &Type::FLOAT8), json!(42.42));
assert_eq!(pg_text_to_json("42.42", &Type::FLOAT4), json!(42.42));
assert_eq!(pg_text_to_json("NaN", &Type::FLOAT4), json!("NaN"));
assert_eq!(
pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
json!("foo")
);
assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
assert_eq!(
pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
json!("42")
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
json!("NaN")
);
assert_eq!(
pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
pg_text_to_json("Infinity", &Type::FLOAT4),
json!("Infinity")
);
assert_eq!(
pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
pg_text_to_json("-Infinity", &Type::FLOAT4),
json!("-Infinity")
);
@@ -355,10 +474,9 @@ mod tests {
.unwrap();
assert_eq!(
pg_text_to_json(
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#,
&Type::JSONB
)
.unwrap(),
),
json
);
}
@@ -366,7 +484,7 @@ mod tests {
#[test]
fn test_pg_array_parse_text() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::TEXT).unwrap()
pg_array_parse(pg_arr, &Type::TEXT)
}
assert_eq!(
pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
@@ -389,7 +507,7 @@ mod tests {
#[test]
fn test_pg_array_parse_bool() {
fn pb(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::BOOL).unwrap()
pg_array_parse(pg_arr, &Type::BOOL)
}
assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
@@ -406,7 +524,7 @@ mod tests {
#[test]
fn test_pg_array_parse_numbers() {
fn pn(pg_arr: &str, ty: &Type) -> Value {
pg_array_parse(pg_arr, ty).unwrap()
pg_array_parse(pg_arr, ty)
}
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
@@ -434,7 +552,7 @@ mod tests {
#[test]
fn test_pg_array_with_decoration() {
fn p(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::INT2).unwrap()
pg_array_parse(pg_arr, &Type::INT2)
}
assert_eq!(
p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
@@ -445,7 +563,7 @@ mod tests {
#[test]
fn test_pg_array_parse_json() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::JSONB).unwrap()
pg_array_parse(pg_arr, &Type::JSONB)
}
assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
assert_eq!(

View File

@@ -249,11 +249,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
if let Some(pool) = pool.clone().upgrade()
&& pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;

View File

@@ -11,6 +11,8 @@ mod http_conn_pool;
mod http_util;
mod json;
mod local_conn_pool;
#[cfg(feature = "rest_broker")]
pub mod rest;
mod sql_over_http;
mod websocket;
@@ -29,13 +31,13 @@ use futures::future::{Either, select};
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::SeedableRng;
use rand::rngs::StdRng;
use sql_over_http::{NEON_REQUEST_ID, uuid_to_header_value};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
@@ -487,6 +489,37 @@ async fn request_handler(
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
#[cfg(feature = "rest_broker")]
{
if config.rest_config.is_rest_broker && {
let path_parts: Vec<&str> = request.uri().path().split('/').collect();
path_parts.len() >= 3 && path_parts[2].starts_with("rest")
} {
let ctx =
RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();
let testodrome_id = request
.headers()
.get("X-Neon-Query-ID")
.and_then(|value| value.to_str().ok())
.map(|s| s.to_string());
if let Some(query_id) = testodrome_id {
info!(parent: &ctx.span(), "testodrome query ID: {query_id}");
ctx.set_testodrome_id(query_id.into());
}
rest::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
#[cfg(not(feature = "rest_broker"))]
{
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
}

1192
proxy/src/serverless/rest.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::{HeaderMap, Request, Response, StatusCode, header};
use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
@@ -24,26 +24,23 @@ use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Level, debug, error, info};
use typed_json::json;
use url::Url;
use uuid::Uuid;
use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool::AuthData;
use super::conn_pool_lib::{self, ConnInfo};
use super::error::HttpCodeError;
use super::http_util::json_response;
use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
use super::http_util::{
ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
};
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{ComputeUserInfoParseError, endpoint_sni};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
use crate::auth::backend::ComputeCredentialKeys;
use crate::config::{HttpConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::http::read_body_with_limit;
use crate::metrics::{HttpDirection, Metrics};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
@@ -70,16 +67,6 @@ enum Payload {
Batch(BatchQueryData),
}
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
@@ -91,188 +78,6 @@ where
Ok(json_to_pg_text(json))
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
#[error("malformed endpoint")]
MalformedEndpoint,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
} else {
hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into()
}
}
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}
pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestContext,
@@ -541,45 +346,6 @@ impl HttpCodeError for SqlOverHttpError {
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
impl HttpCodeError for ReadPayloadError {
fn get_http_status_code(&self) -> StatusCode {
match self {
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
}
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
@@ -670,14 +436,7 @@ async fn handle_inner(
"handling interactive connection from client"
);
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
// todo: race condition?
// we're unlikely to change the common names.
config.tls_config.load().as_deref(),
)?;
let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
@@ -763,9 +522,17 @@ async fn handle_db_inner(
ComputeCredentialKeys::JwtPayload(payload)
if backend.auth_backend.is_local_proxy() =>
{
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
#[cfg(feature = "testing")]
let disable_pg_session_jwt = config.disable_pg_session_jwt;
#[cfg(not(feature = "testing"))]
let disable_pg_session_jwt = false;
let mut client = backend
.connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt)
.await?;
if !disable_pg_session_jwt {
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
}
Client::Local(client)
}
_ => {
@@ -864,12 +631,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&TXN_DEFERRABLE,
];
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
.expect("uuid hyphenated format should be all valid header characters")
}
async fn handle_auth_broker_inner(
ctx: &RequestContext,
request: Request<Incoming>,
@@ -899,7 +660,7 @@ async fn handle_auth_broker_inner(
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
let req = req
.body(body)
.body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here?
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress
@@ -1135,7 +896,6 @@ async fn query_to_json<T: GenericClient>(
let columns_len = row_stream.statement.columns().len();
let mut fields = Vec::with_capacity(columns_len);
let mut types = Vec::with_capacity(columns_len);
for c in row_stream.statement.columns() {
fields.push(json!({
@@ -1147,8 +907,6 @@ async fn query_to_json<T: GenericClient>(
"dataTypeModifier": c.type_modifier(),
"format": "text",
}));
types.push(c.type_().clone());
}
let raw_output = parsed_headers.raw_output;
@@ -1170,7 +928,7 @@ async fn query_to_json<T: GenericClient>(
));
}
let row = pg_text_row_to_json(&row, &types, raw_output, array_mode)?;
let row = pg_text_row_to_json(&row, raw_output, array_mode)?;
rows.push(row);
// assumption: parsing pg text and converting to json takes CPU time.

View File

@@ -199,27 +199,27 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
let probe_msg;
let mut msg = &*msg;
if let Some(ctx) = ctx {
if ctx.get_testodrome_id().is_some() {
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
if let Some(ctx) = ctx
&& ctx.get_testodrome_id().is_some()
{
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.

View File

@@ -18,9 +18,10 @@ use metrics::set_build_info_metric;
use remote_storage::RemoteStorageConfig;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT, DEFAULT_HEARTBEAT_TIMEOUT,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES, DEFAULT_PARTIAL_BACKUP_CONCURRENCY,
DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR, DEFAULT_SSL_CERT_FILE,
DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES,
DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES, DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES,
DEFAULT_PARTIAL_BACKUP_CONCURRENCY, DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR,
DEFAULT_SSL_CERT_FILE, DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE,
};
use safekeeper::wal_backup::WalBackup;
use safekeeper::{
@@ -138,6 +139,15 @@ struct Args {
/// Safekeeper won't be elected for WAL offloading if it is lagging for more than this value in bytes
#[arg(long, default_value_t = DEFAULT_MAX_OFFLOADER_LAG_BYTES)]
max_offloader_lag: u64,
/* BEGIN_HADRON */
/// Safekeeper will re-elect a new offloader if the current backup lagging for more than this value in bytes
#[arg(long, default_value_t = DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES)]
max_reelect_offloader_lag_bytes: u64,
/// Safekeeper will stop accepting new WALs if the timeline disk usage exceeds this value in bytes.
/// Setting this value to 0 disables the limit.
#[arg(long, default_value_t = DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES)]
max_timeline_disk_usage_bytes: u64,
/* END_HADRON */
/// Number of max parallel WAL segments to be offloaded to remote storage.
#[arg(long, default_value = "5")]
wal_backup_parallel_jobs: usize,
@@ -391,6 +401,10 @@ async fn main() -> anyhow::Result<()> {
peer_recovery_enabled: args.peer_recovery,
remote_storage: args.remote_storage,
max_offloader_lag_bytes: args.max_offloader_lag,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: args.max_reelect_offloader_lag_bytes,
max_timeline_disk_usage_bytes: args.max_timeline_disk_usage_bytes,
/* END_HADRON */
wal_backup_enabled: !args.disable_wal_backup,
backup_parallel_jobs: args.wal_backup_parallel_jobs,
pg_auth,

View File

@@ -17,6 +17,7 @@ use utils::crashsafe::durable_rename;
use crate::control_file_upgrade::{downgrade_v10_to_v9, upgrade_control_file};
use crate::metrics::PERSIST_CONTROL_FILE_SECONDS;
use crate::metrics::WAL_DISK_IO_ERRORS;
use crate::state::{EvictionState, TimelinePersistentState};
pub const SK_MAGIC: u32 = 0xcafeceefu32;
@@ -192,11 +193,14 @@ impl TimelinePersistentState {
impl Storage for FileStorage {
/// Persists state durably to the underlying storage.
async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> {
// start timer for metrics
let _timer = PERSIST_CONTROL_FILE_SECONDS.start_timer();
// write data to safekeeper.control.partial
let control_partial_path = self.timeline_dir.join(CONTROL_FILE_NAME_PARTIAL);
let mut control_partial = File::create(&control_partial_path).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!(
"failed to create partial control file at: {}",
&control_partial_path
@@ -206,14 +210,24 @@ impl Storage for FileStorage {
let buf: Vec<u8> = s.write_to_buf()?;
control_partial.write_all(&buf).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!("failed to write safekeeper state into control file at: {control_partial_path}")
})?;
control_partial.flush().await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!("failed to flush safekeeper state into control file at: {control_partial_path}")
})?;
let control_path = self.timeline_dir.join(CONTROL_FILE_NAME);
durable_rename(&control_partial_path, &control_path, !self.no_sync).await?;
durable_rename(&control_partial_path, &control_path, !self.no_sync)
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
// update internal state
self.state = s.clone();

View File

@@ -61,6 +61,13 @@ pub mod defaults {
pub const DEFAULT_HEARTBEAT_TIMEOUT: &str = "5000ms";
pub const DEFAULT_MAX_OFFLOADER_LAG_BYTES: u64 = 128 * (1 << 20);
/* BEGIN_HADRON */
// Default leader re-elect is 0(disabled). SK will re-elect leader if the current leader is lagging this many bytes.
pub const DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES: u64 = 0;
// Default disk usage limit is 0 (disabled). It means each timeline by default can use up to this many WAL
// disk space on this SK until SK begins to reject WALs.
pub const DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES: u64 = 0;
/* END_HADRON */
pub const DEFAULT_PARTIAL_BACKUP_TIMEOUT: &str = "15m";
pub const DEFAULT_CONTROL_FILE_SAVE_INTERVAL: &str = "300s";
pub const DEFAULT_PARTIAL_BACKUP_CONCURRENCY: &str = "5";
@@ -99,6 +106,10 @@ pub struct SafeKeeperConf {
pub peer_recovery_enabled: bool,
pub remote_storage: Option<RemoteStorageConfig>,
pub max_offloader_lag_bytes: u64,
/* BEGIN_HADRON */
pub max_reelect_offloader_lag_bytes: u64,
pub max_timeline_disk_usage_bytes: u64,
/* END_HADRON */
pub backup_parallel_jobs: usize,
pub wal_backup_enabled: bool,
pub pg_auth: Option<Arc<JwtAuth>>,
@@ -151,6 +162,10 @@ impl SafeKeeperConf {
sk_auth_token: None,
heartbeat_timeout: Duration::new(5, 0),
max_offloader_lag_bytes: defaults::DEFAULT_MAX_OFFLOADER_LAG_BYTES,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: defaults::DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES,
max_timeline_disk_usage_bytes: defaults::DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES,
/* END_HADRON */
current_thread_runtime: false,
walsenders_keep_horizon: false,
partial_backup_timeout: Duration::from_secs(0),

View File

@@ -58,6 +58,25 @@ pub static FLUSH_WAL_SECONDS: Lazy<Histogram> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_flush_wal_seconds histogram")
});
/* BEGIN_HADRON */
pub static WAL_DISK_IO_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_wal_disk_io_errors",
"Number of disk I/O errors when creating and flushing WALs and control files"
)
.expect("Failed to register safekeeper_wal_disk_io_errors counter")
});
pub static WAL_STORAGE_LIMIT_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_wal_storage_limit_errors",
concat!(
"Number of errors due to timeline WAL storage utilization exceeding configured limit. ",
"An increase in this metric indicates issues backing up or removing WALs."
)
)
.expect("Failed to register safekeeper_wal_storage_limit_errors counter")
});
/* END_HADRON */
pub static PERSIST_CONTROL_FILE_SECONDS: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"safekeeper_persist_control_file_seconds",
@@ -138,6 +157,15 @@ pub static BACKUP_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_backup_errors_total counter")
});
/* BEGIN_HADRON */
pub static BACKUP_REELECT_LEADER_COUNT: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_backup_reelect_leader_total",
"Number of times the backup leader was reelected"
)
.expect("Failed to register safekeeper_backup_reelect_leader_total counter")
});
/* END_HADRON */
pub static BROKER_PUSH_ALL_UPDATES_SECONDS: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"safekeeper_broker_push_update_seconds",

View File

@@ -16,7 +16,7 @@ use tokio::sync::mpsc::error::SendError;
use tokio::task::JoinHandle;
use tokio::time::MissedTickBehavior;
use tracing::{Instrument, error, info, info_span};
use utils::critical;
use utils::critical_timeline;
use utils::lsn::Lsn;
use utils::postgres_client::{Compression, InterpretedFormat};
use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
@@ -268,6 +268,8 @@ impl InterpretedWalReader {
let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
let ttid = wal_stream.ttid;
let reader = InterpretedWalReader {
wal_stream,
shard_senders: HashMap::from([(
@@ -300,7 +302,11 @@ impl InterpretedWalReader {
.inspect_err(|err| match err {
// TODO: we may want to differentiate these errors further.
InterpretedWalReaderError::Decode(_) => {
critical!("failed to decode WAL record: {err:?}");
critical_timeline!(
ttid.tenant_id,
ttid.timeline_id,
"failed to read WAL record: {err:?}"
);
}
err => error!("failed to read WAL record: {err}"),
})
@@ -363,9 +369,14 @@ impl InterpretedWalReader {
metric.dec();
}
let ttid = self.wal_stream.ttid;
match self.run_impl(start_pos).await {
Err(err @ InterpretedWalReaderError::Decode(_)) => {
critical!("failed to decode WAL record: {err:?}");
critical_timeline!(
ttid.tenant_id,
ttid.timeline_id,
"failed to decode WAL record: {err:?}"
);
}
Err(err) => error!("failed to read WAL record: {err}"),
Ok(()) => info!("interpreted wal reader exiting"),

View File

@@ -26,7 +26,9 @@ use utils::id::{NodeId, TenantId, TenantTimelineId};
use utils::lsn::Lsn;
use utils::sync::gate::Gate;
use crate::metrics::{FullTimelineInfo, MISC_OPERATION_SECONDS, WalStorageMetrics};
use crate::metrics::{
FullTimelineInfo, MISC_OPERATION_SECONDS, WAL_STORAGE_LIMIT_ERRORS, WalStorageMetrics,
};
use crate::rate_limit::RateLimiter;
use crate::receive_wal::WalReceivers;
use crate::safekeeper::{AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, TermLsn};
@@ -1050,6 +1052,39 @@ impl WalResidentTimeline {
Ok(ss)
}
// BEGIN HADRON
// Check if disk usage by WAL segment files for this timeline exceeds the configured limit.
fn hadron_check_disk_usage(
&self,
shared_state_locked: &mut WriteGuardSharedState<'_>,
) -> Result<()> {
// The disk usage is calculated based on the number of segments between `last_removed_segno`
// and the current flush LSN segment number. `last_removed_segno` is advanced after
// unneeded WAL files are physically removed from disk (see `update_wal_removal_end()`
// in `timeline_manager.rs`).
let max_timeline_disk_usage_bytes = self.conf.max_timeline_disk_usage_bytes;
if max_timeline_disk_usage_bytes > 0 {
let last_removed_segno = self.last_removed_segno.load(Ordering::Relaxed);
let flush_lsn = shared_state_locked.sk.flush_lsn();
let wal_seg_size = shared_state_locked.sk.state().server.wal_seg_size as u64;
let current_segno = flush_lsn.segment_number(wal_seg_size as usize);
let segno_count = current_segno - last_removed_segno;
let disk_usage_bytes = segno_count * wal_seg_size;
if disk_usage_bytes > max_timeline_disk_usage_bytes {
WAL_STORAGE_LIMIT_ERRORS.inc();
bail!(
"WAL storage utilization exceeds configured limit of {} bytes: current disk usage: {} bytes",
max_timeline_disk_usage_bytes,
disk_usage_bytes
);
}
}
Ok(())
}
// END HADRON
/// Pass arrived message to the safekeeper.
pub async fn process_msg(
&self,
@@ -1062,6 +1097,13 @@ impl WalResidentTimeline {
let mut rmsg: Option<AcceptorProposerMessage>;
{
let mut shared_state = self.write_shared_state().await;
// BEGIN HADRON
// Errors from the `hadron_check_disk_usage()` function fail the process_msg() function, which
// gets propagated upward and terminates the entire WalAcceptor. This will cause postgres to
// disconnect from the safekeeper and reestablish another connection. Postgres will keep retrying
// safekeeper connections every second until it can successfully propose WAL to the SK again.
self.hadron_check_disk_usage(&mut shared_state)?;
// END HADRON
rmsg = shared_state.sk.safekeeper().process_msg(msg).await?;
// if this is AppendResponse, fill in proper hot standby feedback.

View File

@@ -26,7 +26,9 @@ use utils::id::{NodeId, TenantTimelineId};
use utils::lsn::Lsn;
use utils::{backoff, pausable_failpoint};
use crate::metrics::{BACKED_UP_SEGMENTS, BACKUP_ERRORS, WAL_BACKUP_TASKS};
use crate::metrics::{
BACKED_UP_SEGMENTS, BACKUP_ERRORS, BACKUP_REELECT_LEADER_COUNT, WAL_BACKUP_TASKS,
};
use crate::timeline::WalResidentTimeline;
use crate::timeline_manager::{Manager, StateSnapshot};
use crate::{SafeKeeperConf, WAL_BACKUP_RUNTIME};
@@ -70,8 +72,9 @@ pub(crate) async fn update_task(
need_backup: bool,
state: &StateSnapshot,
) {
let (offloader, election_dbg_str) =
determine_offloader(&state.peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
/* BEGIN_HADRON */
let (offloader, election_dbg_str) = hadron_determine_offloader(mgr, state);
/* END_HADRON */
let elected_me = Some(mgr.conf.my_id) == offloader;
let should_task_run = need_backup && elected_me;
@@ -127,6 +130,70 @@ async fn shut_down_task(entry: &mut Option<WalBackupTaskHandle>) {
}
}
/* BEGIN_HADRON */
// On top of the neon determine_offloader, we also check if the current offloader is lagging behind too much.
// If it is, we re-elect a new offloader. This mitigates the below issue. It also helps distribute the load across SKs.
//
// We observe that the offloader fails to upload a segment due to race conditions on XLOG SWITCH and PG start streaming WALs.
// wal_backup task continously failing to upload a full segment while the segment remains partial on the disk.
// The consequence is that commit_lsn for all SKs move forward but backup_lsn stays the same. Then, all SKs run out of disk space.
// See go/sk-ood-xlog-switch for more details.
//
// To mitigate this issue, we will re-elect a new offloader if the current offloader is lagging behind too much.
// Each SK makes the decision locally but they are aware of each other's commit and backup lsns.
//
// determine_offloader will pick a SK. say SK-1.
// Each SK checks
// -- if commit_lsn - back_lsn > threshold,
// -- -- remove SK-1 from the candidate and call determine_offloader again.
// SK-1 will step down and all SKs will elect the same leader again.
// After the backup is caught up, the leader will become SK-1 again.
fn hadron_determine_offloader(mgr: &Manager, state: &StateSnapshot) -> (Option<NodeId>, String) {
let mut offloader: Option<NodeId>;
let mut election_dbg_str: String;
let caughtup_peers_count: usize;
(offloader, election_dbg_str, caughtup_peers_count) =
determine_offloader(&state.peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
if offloader.is_none()
|| caughtup_peers_count <= 1
|| mgr.conf.max_reelect_offloader_lag_bytes == 0
{
return (offloader, election_dbg_str);
}
let offloader_sk_id = offloader.unwrap();
let backup_lag = state.commit_lsn.checked_sub(state.backup_lsn);
if backup_lag.is_none() {
info!("Backup lag is None. Skipping re-election.");
return (offloader, election_dbg_str);
}
let backup_lag = backup_lag.unwrap().0;
if backup_lag < mgr.conf.max_reelect_offloader_lag_bytes {
return (offloader, election_dbg_str);
}
info!(
"Electing a new leader: Backup lag is too high backup lsn lag {} threshold {}: {}",
backup_lag, mgr.conf.max_reelect_offloader_lag_bytes, election_dbg_str
);
BACKUP_REELECT_LEADER_COUNT.inc();
// Remove the current offloader if lag is too high.
let new_peers: Vec<_> = state
.peers
.iter()
.filter(|p| p.sk_id != offloader_sk_id)
.cloned()
.collect();
(offloader, election_dbg_str, _) =
determine_offloader(&new_peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
(offloader, election_dbg_str)
}
/* END_HADRON */
/// The goal is to ensure that normally only one safekeepers offloads. However,
/// it is fine (and inevitable, as s3 doesn't provide CAS) that for some short
/// time we have several ones as they PUT the same files. Also,
@@ -141,13 +208,13 @@ fn determine_offloader(
wal_backup_lsn: Lsn,
ttid: TenantTimelineId,
conf: &SafeKeeperConf,
) -> (Option<NodeId>, String) {
) -> (Option<NodeId>, String, usize) {
// TODO: remove this once we fill newly joined safekeepers since backup_lsn.
let capable_peers = alive_peers
.iter()
.filter(|p| p.local_start_lsn <= wal_backup_lsn);
match capable_peers.clone().map(|p| p.commit_lsn).max() {
None => (None, "no connected peers to elect from".to_string()),
None => (None, "no connected peers to elect from".to_string(), 0),
Some(max_commit_lsn) => {
let threshold = max_commit_lsn
.checked_sub(conf.max_offloader_lag_bytes)
@@ -175,6 +242,7 @@ fn determine_offloader(
capable_peers_dbg,
caughtup_peers.len()
),
caughtup_peers.len(),
)
}
}
@@ -346,6 +414,8 @@ async fn backup_lsn_range(
anyhow::bail!("parallel_jobs must be >= 1");
}
pausable_failpoint!("backup-lsn-range-pausable");
let remote_timeline_path = &timeline.remote_path;
let start_lsn = *backup_lsn;
let segments = get_segments(start_lsn, end_lsn, wal_seg_size);

View File

@@ -1,15 +1,15 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use safekeeper_api::Term;
use utils::lsn::Lsn;
use crate::send_wal::EndWatch;
use crate::timeline::WalResidentTimeline;
use crate::wal_storage::WalReader;
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use safekeeper_api::Term;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
#[derive(PartialEq, Eq, Debug)]
pub(crate) struct WalBytes {
@@ -37,6 +37,8 @@ struct PositionedWalReader {
pub(crate) struct StreamingWalReader {
stream: BoxStream<'static, WalOrReset>,
start_changed_tx: tokio::sync::watch::Sender<Lsn>,
// HADRON: Added TenantTimelineId for instrumentation purposes.
pub(crate) ttid: TenantTimelineId,
}
pub(crate) enum WalOrReset {
@@ -63,6 +65,7 @@ impl StreamingWalReader {
buffer_size: usize,
) -> Self {
let (start_changed_tx, start_changed_rx) = tokio::sync::watch::channel(start);
let ttid = tli.ttid;
let state = WalReaderStreamState {
tli,
@@ -107,6 +110,7 @@ impl StreamingWalReader {
Self {
stream,
start_changed_tx,
ttid,
}
}

View File

@@ -31,7 +31,8 @@ use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use crate::metrics::{
REMOVED_WAL_SEGMENTS, WAL_STORAGE_OPERATION_SECONDS, WalStorageMetrics, time_io_closure,
REMOVED_WAL_SEGMENTS, WAL_DISK_IO_ERRORS, WAL_STORAGE_OPERATION_SECONDS, WalStorageMetrics,
time_io_closure,
};
use crate::state::TimelinePersistentState;
use crate::wal_backup::{WalBackup, read_object, remote_timeline_path};
@@ -293,9 +294,12 @@ impl PhysicalStorage {
// half initialized segment, first bake it under tmp filename and
// then rename.
let tmp_path = self.timeline_dir.join("waltmp");
let file = File::create(&tmp_path)
.await
.with_context(|| format!("Failed to open tmp wal file {:?}", &tmp_path))?;
let file: File = File::create(&tmp_path).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/* END_HADRON */
format!("Failed to open tmp wal file {:?}", &tmp_path)
})?;
fail::fail_point!("sk-zero-segment", |_| {
info!("sk-zero-segment failpoint hit");
@@ -382,7 +386,11 @@ impl PhysicalStorage {
let flushed = self
.write_in_segment(segno, xlogoff, &buf[..bytes_write])
.await?;
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
self.write_lsn += bytes_write as u64;
if flushed {
self.flush_lsn = self.write_lsn;
@@ -491,7 +499,11 @@ impl Storage for PhysicalStorage {
}
if let Some(unflushed_file) = self.file.take() {
self.fdatasync_file(&unflushed_file).await?;
self.fdatasync_file(&unflushed_file)
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
self.file = Some(unflushed_file);
} else {
// We have unflushed data (write_lsn != flush_lsn), but no file. This

View File

@@ -159,6 +159,10 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
heartbeat_timeout: Duration::from_secs(0),
remote_storage: None,
max_offloader_lag_bytes: 0,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: 0,
max_timeline_disk_usage_bytes: 0,
/* END_HADRON */
wal_backup_enabled: false,
listen_pg_addr_tenant_only: None,
advertise_pg_addr: None,

View File

@@ -0,0 +1 @@
ALTER TABLE safekeepers ALTER COLUMN scheduling_policy SET DEFAULT 'pause';

View File

@@ -0,0 +1 @@
ALTER TABLE safekeepers ALTER COLUMN scheduling_policy SET DEFAULT 'activating';

View File

@@ -76,6 +76,9 @@ pub(crate) struct StorageControllerMetricGroup {
/// How many shards would like to reconcile but were blocked by concurrency limits
pub(crate) storage_controller_pending_reconciles: measured::Gauge,
/// How many shards are keep-failing and will be ignored when considering to run optimizations
pub(crate) storage_controller_keep_failing_reconciles: measured::Gauge,
/// HTTP request status counters for handled requests
pub(crate) storage_controller_http_request_status:
measured::CounterVec<HttpRequestStatusLabelGroupSet>,

View File

@@ -1388,6 +1388,48 @@ impl Persistence {
.await
}
/// Activate the given safekeeper, ensuring that there is no TOCTOU.
/// Returns `Some` if the safekeeper has indeed been activating (or already active). Other states return `None`.
pub(crate) async fn activate_safekeeper(&self, id_: i64) -> Result<Option<()>, DatabaseError> {
use crate::schema::safekeepers::dsl::*;
self.with_conn(move |conn| {
Box::pin(async move {
#[derive(Insertable, AsChangeset)]
#[diesel(table_name = crate::schema::safekeepers)]
struct UpdateSkSchedulingPolicy<'a> {
id: i64,
scheduling_policy: &'a str,
}
let scheduling_policy_active = String::from(SkSchedulingPolicy::Active);
let scheduling_policy_activating = String::from(SkSchedulingPolicy::Activating);
let rows_affected = diesel::update(
safekeepers.filter(id.eq(id_)).filter(
scheduling_policy
.eq(scheduling_policy_activating)
.or(scheduling_policy.eq(&scheduling_policy_active)),
),
)
.set(scheduling_policy.eq(&scheduling_policy_active))
.execute(conn)
.await?;
if rows_affected == 0 {
return Ok(Some(()));
}
if rows_affected != 1 {
return Err(DatabaseError::Logical(format!(
"unexpected number of rows ({rows_affected})",
)));
}
Ok(Some(()))
})
})
.await
}
/// Persist timeline. Returns if the timeline was newly inserted. If it wasn't, we haven't done any writes.
pub(crate) async fn insert_timeline(&self, entry: TimelinePersistence) -> DatabaseResult<bool> {
use crate::schema::timelines;

View File

@@ -31,8 +31,8 @@ use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, MetadataHealthUpdateRequest, NodeAvailability,
NodeRegisterRequest, NodeSchedulingPolicy, NodeShard, NodeShardResponse, PlacementPolicy,
ShardSchedulingPolicy, ShardsPreferredAzsRequest, ShardsPreferredAzsResponse,
TenantCreateRequest, TenantCreateResponse, TenantCreateResponseShard, TenantDescribeResponse,
TenantDescribeResponseShard, TenantLocateResponse, TenantPolicyRequest,
SkSchedulingPolicy, TenantCreateRequest, TenantCreateResponse, TenantCreateResponseShard,
TenantDescribeResponse, TenantDescribeResponseShard, TenantLocateResponse, TenantPolicyRequest,
TenantShardMigrateRequest, TenantShardMigrateResponse,
};
use pageserver_api::models::{
@@ -210,6 +210,10 @@ pub const RECONCILER_CONCURRENCY_DEFAULT: usize = 128;
pub const PRIORITY_RECONCILER_CONCURRENCY_DEFAULT: usize = 256;
pub const SAFEKEEPER_RECONCILER_CONCURRENCY_DEFAULT: usize = 32;
// Number of consecutive reconciliation errors, occured for one shard,
// after which the shard is ignored when considering to run optimizations.
const MAX_CONSECUTIVE_RECONCILIATION_ERRORS: usize = 5;
// Depth of the channel used to enqueue shards for reconciliation when they can't do it immediately.
// This channel is finite-size to avoid using excessive memory if we get into a state where reconciles are finishing more slowly
// than they're being pushed onto the queue.
@@ -702,6 +706,36 @@ struct ShardMutationLocations {
#[derive(Default, Clone)]
struct TenantMutationLocations(BTreeMap<TenantShardId, ShardMutationLocations>);
struct ReconcileAllResult {
spawned_reconciles: usize,
keep_failing_reconciles: usize,
has_delayed_reconciles: bool,
}
impl ReconcileAllResult {
fn new(
spawned_reconciles: usize,
keep_failing_reconciles: usize,
has_delayed_reconciles: bool,
) -> Self {
assert!(
spawned_reconciles >= keep_failing_reconciles,
"It is impossible to have more keep-failing reconciles than spawned reconciles"
);
Self {
spawned_reconciles,
keep_failing_reconciles,
has_delayed_reconciles,
}
}
/// We can run optimizations only if we don't have any delayed reconciles and
/// all spawned reconciles are also keep-failing reconciles.
fn can_run_optimizations(&self) -> bool {
!self.has_delayed_reconciles && self.spawned_reconciles == self.keep_failing_reconciles
}
}
impl Service {
pub fn get_config(&self) -> &Config {
&self.config
@@ -899,7 +933,7 @@ impl Service {
// which require it: under normal circumstances this should only include tenants that were in some
// transient state before we restarted, or any tenants whose compute hooks failed above.
tracing::info!("Checking for shards in need of reconciliation...");
let reconcile_tasks = self.reconcile_all();
let reconcile_all_result = self.reconcile_all();
// We will not wait for these reconciliation tasks to run here: we're now done with startup and
// normal operations may proceed.
@@ -947,8 +981,9 @@ impl Service {
}
}
let spawned_reconciles = reconcile_all_result.spawned_reconciles;
tracing::info!(
"Startup complete, spawned {reconcile_tasks} reconciliation tasks ({shard_count} shards total)"
"Startup complete, spawned {spawned_reconciles} reconciliation tasks ({shard_count} shards total)"
);
}
@@ -1199,8 +1234,8 @@ impl Service {
while !self.reconcilers_cancel.is_cancelled() {
tokio::select! {
_ = interval.tick() => {
let reconciles_spawned = self.reconcile_all();
if reconciles_spawned == 0 {
let reconcile_all_result = self.reconcile_all();
if reconcile_all_result.can_run_optimizations() {
// Run optimizer only when we didn't find any other work to do
self.optimize_all().await;
}
@@ -1214,7 +1249,7 @@ impl Service {
}
/// Heartbeat all storage nodes once in a while.
#[instrument(skip_all)]
async fn spawn_heartbeat_driver(&self) {
async fn spawn_heartbeat_driver(self: &Arc<Self>) {
self.startup_complete.clone().wait().await;
let mut interval = tokio::time::interval(self.config.heartbeat_interval);
@@ -1341,18 +1376,51 @@ impl Service {
}
}
if let Ok(deltas) = res_sk {
let mut locked = self.inner.write().unwrap();
let mut safekeepers = (*locked.safekeepers).clone();
for (id, state) in deltas.0 {
let Some(sk) = safekeepers.get_mut(&id) else {
tracing::info!(
"Couldn't update safekeeper safekeeper state for id {id} from heartbeat={state:?}"
);
continue;
};
sk.set_availability(state);
let mut to_activate = Vec::new();
{
let mut locked = self.inner.write().unwrap();
let mut safekeepers = (*locked.safekeepers).clone();
for (id, state) in deltas.0 {
let Some(sk) = safekeepers.get_mut(&id) else {
tracing::info!(
"Couldn't update safekeeper safekeeper state for id {id} from heartbeat={state:?}"
);
continue;
};
if sk.scheduling_policy() == SkSchedulingPolicy::Activating
&& let SafekeeperState::Available { .. } = state
{
to_activate.push(id);
}
sk.set_availability(state);
}
locked.safekeepers = Arc::new(safekeepers);
}
for sk_id in to_activate {
// TODO this can race with set_scheduling_policy (can create disjoint DB <-> in-memory state)
tracing::info!("Activating safekeeper {sk_id}");
match self.persistence.activate_safekeeper(sk_id.0 as i64).await {
Ok(Some(())) => {}
Ok(None) => {
tracing::info!(
"safekeeper {sk_id} has been removed from db or has different scheduling policy than active or activating"
);
}
Err(e) => {
tracing::warn!("couldn't apply activation of {sk_id} to db: {e}");
continue;
}
}
if let Err(e) = self
.set_safekeeper_scheduling_policy_in_mem(sk_id, SkSchedulingPolicy::Active)
.await
{
tracing::info!("couldn't activate safekeeper {sk_id} in memory: {e}");
continue;
}
tracing::info!("Activation of safekeeper {sk_id} done");
}
locked.safekeepers = Arc::new(safekeepers);
}
}
}
@@ -1408,6 +1476,7 @@ impl Service {
match result.result {
Ok(()) => {
tenant.consecutive_errors_count = 0;
tenant.apply_observed_deltas(deltas);
tenant.waiter.advance(result.sequence);
}
@@ -1426,6 +1495,8 @@ impl Service {
}
}
tenant.consecutive_errors_count = tenant.consecutive_errors_count.saturating_add(1);
// Ordering: populate last_error before advancing error_seq,
// so that waiters will see the correct error after waiting.
tenant.set_last_error(result.sequence, e);
@@ -8026,7 +8097,7 @@ impl Service {
/// Returns how many reconciliation tasks were started, or `1` if no reconciles were
/// spawned but some _would_ have been spawned if `reconciler_concurrency` units where
/// available. A return value of 0 indicates that everything is fully reconciled already.
fn reconcile_all(&self) -> usize {
fn reconcile_all(&self) -> ReconcileAllResult {
let mut locked = self.inner.write().unwrap();
let (nodes, tenants, scheduler) = locked.parts_mut();
let pageservers = nodes.clone();
@@ -8034,13 +8105,16 @@ impl Service {
// This function is an efficient place to update lazy statistics, since we are walking
// all tenants.
let mut pending_reconciles = 0;
let mut keep_failing_reconciles = 0;
let mut az_violations = 0;
// If we find any tenants to drop from memory, stash them to offload after
// we're done traversing the map of tenants.
let mut drop_detached_tenants = Vec::new();
let mut reconciles_spawned = 0;
let mut spawned_reconciles = 0;
let mut has_delayed_reconciles = false;
for shard in tenants.values_mut() {
// Accumulate scheduling statistics
if let (Some(attached), Some(preferred)) =
@@ -8060,18 +8134,32 @@ impl Service {
// If there is something delayed, then return a nonzero count so that
// callers like reconcile_all_now do not incorrectly get the impression
// that the system is in a quiescent state.
reconciles_spawned = std::cmp::max(1, reconciles_spawned);
has_delayed_reconciles = true;
pending_reconciles += 1;
continue;
}
// Eventual consistency: if an earlier reconcile job failed, and the shard is still
// dirty, spawn another one
let consecutive_errors_count = shard.consecutive_errors_count;
if self
.maybe_reconcile_shard(shard, &pageservers, ReconcilerPriority::Normal)
.is_some()
{
reconciles_spawned += 1;
spawned_reconciles += 1;
// Count shards that are keep-failing. We still want to reconcile them
// to avoid a situation where a shard is stuck.
// But we don't want to consider them when deciding to run optimizations.
if consecutive_errors_count >= MAX_CONSECUTIVE_RECONCILIATION_ERRORS {
tracing::warn!(
tenant_id=%shard.tenant_shard_id.tenant_id,
shard_id=%shard.tenant_shard_id.shard_slug(),
"Shard reconciliation is keep-failing: {} errors",
consecutive_errors_count
);
keep_failing_reconciles += 1;
}
} else if shard.delayed_reconcile {
// Shard wanted to reconcile but for some reason couldn't.
pending_reconciles += 1;
@@ -8110,7 +8198,16 @@ impl Service {
.storage_controller_pending_reconciles
.set(pending_reconciles as i64);
reconciles_spawned
metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_keep_failing_reconciles
.set(keep_failing_reconciles as i64);
ReconcileAllResult::new(
spawned_reconciles,
keep_failing_reconciles,
has_delayed_reconciles,
)
}
/// `optimize` in this context means identifying shards which have valid scheduled locations, but
@@ -8783,13 +8880,13 @@ impl Service {
/// also wait for any generated Reconcilers to complete. Calling this until it returns zero should
/// put the system into a quiescent state where future background reconciliations won't do anything.
pub(crate) async fn reconcile_all_now(&self) -> Result<usize, ReconcileWaitError> {
let reconciles_spawned = self.reconcile_all();
let reconciles_spawned = if reconciles_spawned == 0 {
let reconcile_all_result = self.reconcile_all();
let mut spawned_reconciles = reconcile_all_result.spawned_reconciles;
if reconcile_all_result.can_run_optimizations() {
// Only optimize when we are otherwise idle
self.optimize_all().await
} else {
reconciles_spawned
};
let optimization_reconciles = self.optimize_all().await;
spawned_reconciles += optimization_reconciles;
}
let waiters = {
let mut waiters = Vec::new();
@@ -8826,11 +8923,11 @@ impl Service {
tracing::info!(
"{} reconciles in reconcile_all, {} waiters",
reconciles_spawned,
spawned_reconciles,
waiter_count
);
Ok(std::cmp::max(waiter_count, reconciles_spawned))
Ok(std::cmp::max(waiter_count, spawned_reconciles))
}
async fn stop_reconciliations(&self, reason: StopReconciliationsReason) {

View File

@@ -236,40 +236,30 @@ impl Service {
F: std::future::Future<Output = mgmt_api::Result<T>> + Send + 'static,
T: Sync + Send + 'static,
{
let target_sk_count = safekeepers.len();
if target_sk_count == 0 {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"timeline configured without any safekeepers"
)));
}
if target_sk_count < self.config.timeline_safekeeper_count {
tracing::warn!(
"running a quorum operation with {} safekeepers, which is less than configured {} safekeepers per timeline",
target_sk_count,
self.config.timeline_safekeeper_count
);
}
let results = self
.tenant_timeline_safekeeper_op(safekeepers, op, timeout)
.await?;
// Now check if quorum was reached in results.
let target_sk_count = safekeepers.len();
let quorum_size = match target_sk_count {
0 => {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"timeline configured without any safekeepers",
)));
}
1 | 2 => {
#[cfg(feature = "testing")]
{
// In test settings, it is allowed to have one or two safekeepers
target_sk_count
}
#[cfg(not(feature = "testing"))]
{
// The region is misconfigured: we need at least three safekeepers to be configured
// in order to schedule work to them
tracing::warn!(
"couldn't find at least 3 safekeepers for timeline, found: {:?}",
target_sk_count
);
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"couldn't find at least 3 safekeepers to put timeline to"
)));
}
}
_ => target_sk_count / 2 + 1,
};
let quorum_size = target_sk_count / 2 + 1;
let success_count = results.iter().filter(|res| res.is_ok()).count();
if success_count < quorum_size {
// Failure
@@ -815,7 +805,7 @@ impl Service {
Safekeeper::from_persistence(
crate::persistence::SafekeeperPersistence::from_upsert(
record,
SkSchedulingPolicy::Pause,
SkSchedulingPolicy::Activating,
),
CancellationToken::new(),
use_https,
@@ -856,27 +846,36 @@ impl Service {
.await?;
let node_id = NodeId(id as u64);
// After the change has been persisted successfully, update the in-memory state
{
let mut locked = self.inner.write().unwrap();
let mut safekeepers = (*locked.safekeepers).clone();
let sk = safekeepers
.get_mut(&node_id)
.ok_or(DatabaseError::Logical("Not found".to_string()))?;
sk.set_scheduling_policy(scheduling_policy);
self.set_safekeeper_scheduling_policy_in_mem(node_id, scheduling_policy)
.await
}
match scheduling_policy {
SkSchedulingPolicy::Active => {
locked
.safekeeper_reconcilers
.start_reconciler(node_id, self);
}
SkSchedulingPolicy::Decomissioned | SkSchedulingPolicy::Pause => {
locked.safekeeper_reconcilers.stop_reconciler(node_id);
}
pub(crate) async fn set_safekeeper_scheduling_policy_in_mem(
self: &Arc<Service>,
node_id: NodeId,
scheduling_policy: SkSchedulingPolicy,
) -> Result<(), DatabaseError> {
let mut locked = self.inner.write().unwrap();
let mut safekeepers = (*locked.safekeepers).clone();
let sk = safekeepers
.get_mut(&node_id)
.ok_or(DatabaseError::Logical("Not found".to_string()))?;
sk.set_scheduling_policy(scheduling_policy);
match scheduling_policy {
SkSchedulingPolicy::Active => {
locked
.safekeeper_reconcilers
.start_reconciler(node_id, self);
}
SkSchedulingPolicy::Decomissioned
| SkSchedulingPolicy::Pause
| SkSchedulingPolicy::Activating => {
locked.safekeeper_reconcilers.stop_reconciler(node_id);
}
locked.safekeepers = Arc::new(safekeepers);
}
locked.safekeepers = Arc::new(safekeepers);
Ok(())
}

View File

@@ -131,6 +131,15 @@ pub(crate) struct TenantShard {
#[serde(serialize_with = "read_last_error")]
pub(crate) last_error: std::sync::Arc<std::sync::Mutex<Option<Arc<ReconcileError>>>>,
/// Number of consecutive reconciliation errors that have occurred for this shard.
///
/// When this count reaches MAX_CONSECUTIVE_RECONCILIATION_ERRORS, the tenant shard
/// will be countered as keep-failing in `reconcile_all` calculations. This will lead to
/// allowing optimizations to run even with some failing shards.
///
/// The counter is reset to 0 after a successful reconciliation.
pub(crate) consecutive_errors_count: usize,
/// If we have a pending compute notification that for some reason we weren't able to send,
/// set this to true. If this is set, calls to [`Self::get_reconcile_needed`] will return Yes
/// and trigger a Reconciler run. This is the mechanism by which compute notifications are included in the scope
@@ -594,6 +603,7 @@ impl TenantShard {
waiter: Arc::new(SeqWait::new(Sequence(0))),
error_waiter: Arc::new(SeqWait::new(Sequence(0))),
last_error: Arc::default(),
consecutive_errors_count: 0,
pending_compute_notification: false,
scheduling_policy: ShardSchedulingPolicy::default(),
preferred_node: None,
@@ -1859,6 +1869,7 @@ impl TenantShard {
waiter: Arc::new(SeqWait::new(Sequence::initial())),
error_waiter: Arc::new(SeqWait::new(Sequence::initial())),
last_error: Arc::default(),
consecutive_errors_count: 0,
pending_compute_notification: false,
delayed_reconcile: false,
scheduling_policy: serde_json::from_str(&tsp.scheduling_policy).unwrap(),

View File

@@ -180,7 +180,7 @@ def test_metric_collection(
httpserver.check()
# Check that at least one bucket output object is present, and that all
# can be decompressed and decoded.
# can be decompressed and decoded as NDJSON.
bucket_dumps = {}
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)
for dirpath, _dirs, files in os.walk(env.pageserver_remote_storage.root):
@@ -188,7 +188,13 @@ def test_metric_collection(
file_path = os.path.join(dirpath, file)
log.info(file_path)
if file.endswith(".gz"):
bucket_dumps[file_path] = json.load(gzip.open(file_path))
events = []
with gzip.open(file_path, "rt") as f:
for line in f:
line = line.strip()
if line:
events.append(json.loads(line))
bucket_dumps[file_path] = {"events": events}
assert len(bucket_dumps) >= 1
assert all("events" in data for data in bucket_dumps.values())

View File

@@ -399,7 +399,7 @@ def test_tx_abort_with_many_relations(
# How many relations: this number is tuned to be long enough to take tens of seconds
# if the rollback code path is buggy, tripping the test's timeout.
n = 5000
step = 500
step = 2500
def create():
# Create many relations

View File

@@ -989,6 +989,102 @@ def test_storage_controller_compute_hook_retry(
)
@run_only_on_default_postgres("postgres behavior is not relevant")
def test_storage_controller_compute_hook_keep_failing(
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address: ListenAddress,
):
neon_env_builder.num_pageservers = 4
neon_env_builder.storage_controller_config = {"use_local_compute_notifications": False}
(host, port) = httpserver_listen_address
neon_env_builder.control_plane_hooks_api = f"http://{host}:{port}"
# Set up CP handler for compute notifications
status_by_tenant: dict[TenantId, int] = {}
def handler(request: Request):
notify_request = request.json
assert notify_request is not None
status = status_by_tenant[TenantId(notify_request["tenant_id"])]
log.info(f"Notify request[{status}]: {notify_request}")
return Response(status=status)
httpserver.expect_request("/notify-attach", method="PUT").respond_with_handler(handler)
# Run neon environment
env = neon_env_builder.init_configs()
env.start()
# Create two tenants:
# - The first tenant is banned by CP and contains only one shard
# - The second tenant is allowed by CP and contains four shards
banned_tenant = TenantId.generate()
status_by_tenant[banned_tenant] = 200 # we will ban this tenant later
env.create_tenant(banned_tenant, placement_policy='{"Attached": 1}')
shard_count = 4
allowed_tenant = TenantId.generate()
status_by_tenant[allowed_tenant] = 200
env.create_tenant(allowed_tenant, shard_count=shard_count, placement_policy='{"Attached": 1}')
# Find the pageserver of the banned tenant
banned_tenant_ps = env.get_tenant_pageserver(banned_tenant)
assert banned_tenant_ps is not None
alive_pageservers = [p for p in env.pageservers if p.id != banned_tenant_ps.id]
# Stop pageserver and ban tenant to trigger failed reconciliation
status_by_tenant[banned_tenant] = 423
banned_tenant_ps.stop()
env.storage_controller.allowed_errors.append(NOTIFY_BLOCKED_LOG)
env.storage_controller.allowed_errors.extend(NOTIFY_FAILURE_LOGS)
env.storage_controller.allowed_errors.append(".*Shard reconciliation is keep-failing.*")
env.storage_controller.node_configure(banned_tenant_ps.id, {"availability": "Offline"})
# Migrate all allowed tenant shards to the first alive pageserver
# to trigger storage controller optimizations due to affinity rules
for shard_number in range(shard_count):
env.storage_controller.tenant_shard_migrate(
TenantShardId(allowed_tenant, shard_number, shard_count),
alive_pageservers[0].id,
config=StorageControllerMigrationConfig(prewarm=False, override_scheduler=True),
)
# Make some reconcile_all calls to trigger optimizations
# RECONCILE_COUNT must be greater than storcon's MAX_CONSECUTIVE_RECONCILIATION_ERRORS
RECONCILE_COUNT = 12
for i in range(RECONCILE_COUNT):
try:
n = env.storage_controller.reconcile_all()
log.info(f"Reconciliation attempt {i} finished with success: {n}")
except StorageControllerApiException as e:
assert "Control plane tenant busy" in str(e)
log.info(f"Reconciliation attempt {i} finished with failure")
banned_descr = env.storage_controller.tenant_describe(banned_tenant)
assert banned_descr["shards"][0]["is_pending_compute_notification"] is True
time.sleep(2)
# Check that the allowed tenant shards are optimized due to affinity rules
locations = alive_pageservers[0].http_client().tenant_list_locations()["tenant_shards"]
not_optimized_shard_count = 0
for loc in locations:
tsi = TenantShardId.parse(loc[0])
if tsi.tenant_id != allowed_tenant:
continue
if loc[1]["mode"] == "AttachedSingle":
not_optimized_shard_count += 1
log.info(f"Shard {tsi} seen in mode {loc[1]['mode']}")
assert not_optimized_shard_count < shard_count, "At least one shard should be optimized"
# Unban the tenant and run reconciliations
status_by_tenant[banned_tenant] = 200
env.storage_controller.reconcile_all()
banned_descr = env.storage_controller.tenant_describe(banned_tenant)
assert banned_descr["shards"][0]["is_pending_compute_notification"] is False
@run_only_on_default_postgres("this test doesn't start an endpoint")
def test_storage_controller_compute_hook_revert(
httpserver: HTTPServer,
@@ -3530,18 +3626,21 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
# some small tests for the scheduling policy querying and returning APIs
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Pause"
target.safekeeper_scheduling_policy(inserted["id"], "Active")
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Active"
# Ensure idempotency
target.safekeeper_scheduling_policy(inserted["id"], "Active")
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Active"
# change back to paused again
assert (
newest_info["scheduling_policy"] == "Activating"
or newest_info["scheduling_policy"] == "Active"
)
target.safekeeper_scheduling_policy(inserted["id"], "Pause")
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Pause"
# Ensure idempotency
target.safekeeper_scheduling_policy(inserted["id"], "Pause")
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Pause"
# change back to active again
target.safekeeper_scheduling_policy(inserted["id"], "Active")
def storcon_heartbeat():
assert env.storage_controller.log_contains(
@@ -3554,6 +3653,57 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
target.safekeeper_scheduling_policy(inserted["id"], "Decomissioned")
@run_only_on_default_postgres("this is like a 'unit test' against storcon db")
def test_safekeeper_activating_to_active(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_configs()
env.start()
fake_id = 5
target = env.storage_controller
assert target.get_safekeeper(fake_id) is None
start_sks = target.get_safekeepers()
sk_0 = env.safekeepers[0]
body = {
"active": True,
"id": fake_id,
"created_at": "2023-10-25T09:11:25Z",
"updated_at": "2024-08-28T11:32:43Z",
"region_id": "aws-eu-central-1",
"host": "localhost",
"port": sk_0.port.pg,
"http_port": sk_0.port.http,
"https_port": None,
"version": 5957,
"availability_zone_id": "eu-central-1a",
}
target.on_safekeeper_deploy(fake_id, body)
inserted = target.get_safekeeper(fake_id)
assert inserted is not None
assert target.get_safekeepers() == start_sks + [inserted]
assert eq_safekeeper_records(body, inserted)
def safekeeper_is_active():
newest_info = target.get_safekeeper(inserted["id"])
assert newest_info
assert newest_info["scheduling_policy"] == "Active"
wait_until(safekeeper_is_active)
target.safekeeper_scheduling_policy(inserted["id"], "Activating")
wait_until(safekeeper_is_active)
# Now decomission it
target.safekeeper_scheduling_policy(inserted["id"], "Decomissioned")
def eq_safekeeper_records(a: dict[str, Any], b: dict[str, Any]) -> bool:
compared = [dict(a), dict(b)]

View File

@@ -324,7 +324,7 @@ def test_ancestor_detach_reparents_earlier(neon_env_builder: NeonEnvBuilder):
# it is to be in line with the deletion timestamp.. well, almost.
when = original_ancestor[2][:26]
when_ts = datetime.datetime.fromisoformat(when).replace(tzinfo=datetime.UTC)
now = datetime.datetime.utcnow().replace(tzinfo=datetime.UTC)
now = datetime.datetime.now(datetime.UTC)
assert when_ts < now
assert len(lineage.get("reparenting_history", [])) == 0
elif expected_ancestor == timeline_id:
@@ -458,19 +458,20 @@ def test_ancestor_detach_behavior_v2(neon_env_builder: NeonEnvBuilder, snapshots
env.pageserver.quiesce_tenants()
# checking the ancestor after is much faster than waiting for the endpoint not start
# checking the ancestor after is much faster than waiting for the endpoint to start
expected_result = [
("main", env.initial_timeline, None, 24576, 1),
("after", after, env.initial_timeline, 24576, 1),
("snapshot_branchpoint_old", snapshot_branchpoint_old, env.initial_timeline, 8192, 1),
("snapshot_branchpoint", snapshot_branchpoint, env.initial_timeline, 16384, 1),
("branch_to_detach", branch_to_detach, None, 16384, 1),
("earlier", earlier, env.initial_timeline, 0, 1),
# (branch_name, queried_timeline, expected_ancestor, rows, starts, read_only)
("main", env.initial_timeline, None, 24576, 1, False),
("after", after, env.initial_timeline, 24576, 1, False),
("snapshot_branchpoint_old", snapshot_branchpoint_old, env.initial_timeline, 8192, 1, True),
("snapshot_branchpoint", snapshot_branchpoint, env.initial_timeline, 16384, 1, False),
("branch_to_detach", branch_to_detach, None, 16384, 1, False),
("earlier", earlier, env.initial_timeline, 0, 1, False),
]
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)
for branch_name, queried_timeline, expected_ancestor, _, _ in expected_result:
for branch_name, queried_timeline, expected_ancestor, _, _, _ in expected_result:
details = client.timeline_detail(env.initial_tenant, queried_timeline)
ancestor_timeline_id = details["ancestor_timeline_id"]
if expected_ancestor is None:
@@ -508,13 +509,17 @@ def test_ancestor_detach_behavior_v2(neon_env_builder: NeonEnvBuilder, snapshots
assert len(lineage.get("original_ancestor", [])) == 0
assert len(lineage.get("reparenting_history", [])) == 0
for branch_name, queried_timeline, _, rows, starts in expected_result:
details = client.timeline_detail(env.initial_tenant, queried_timeline)
log.info(f"reading data from branch {branch_name}")
# specifying the lsn makes the endpoint read-only and not connect to safekeepers
for branch_name, queried_timeline, _, rows, starts, read_only in expected_result:
last_record_lsn = None
if read_only:
# specifying the lsn makes the endpoint read-only and not connect to safekeepers
details = client.timeline_detail(env.initial_tenant, queried_timeline)
last_record_lsn = Lsn(details["last_record_lsn"])
log.info(f"reading data from branch {branch_name} at {last_record_lsn}")
with env.endpoints.create(
branch_name,
lsn=Lsn(details["last_record_lsn"]),
lsn=last_record_lsn,
) as ep:
ep.start(safekeeper_generation=1)
assert ep.safe_psql("SELECT count(*) FROM foo;")[0][0] == rows
@@ -1884,6 +1889,31 @@ def test_timeline_detach_with_aux_files_with_detach_v1(
assert set(http.list_aux_files(env.initial_tenant, branch_timeline_id, lsn1).keys()) == set([])
def test_detach_ancestors_with_no_writes(
neon_env_builder: NeonEnvBuilder,
):
env = neon_env_builder.init_start()
endpoint = env.endpoints.create_start("main", tenant_id=env.initial_tenant)
wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline)
endpoint.safe_psql(
"SELECT pg_create_logical_replication_slot('test_slot_parent_1', 'pgoutput')"
)
wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline)
endpoint.stop()
for i in range(0, 5):
if i == 0:
ancestor_name = "main"
else:
ancestor_name = f"b{i}"
tlid = env.create_branch(f"b{i + 1}", ancestor_branch_name=ancestor_name)
client = env.pageserver.http_client()
client.detach_ancestor(tenant_id=env.initial_tenant, timeline_id=tlid)
# TODO:
# - branch near existing L1 boundary, image layers?
# - investigate: why are layers started at uneven lsn? not just after branching, but in general.

View File

@@ -2740,3 +2740,85 @@ def test_pull_timeline_partial_segment_integrity(neon_env_builder: NeonEnvBuilde
raise Exception("Uneviction did not happen on source safekeeper yet")
wait_until(unevicted)
def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
"""
Test that the timeline disk usage circuit breaker works as expected. We test that:
1. The circuit breaker kicks in when the timeline's disk usage exceeds the configured limit,
and it causes writes to hang.
2. The hanging writes unblock when the issue resolves (by restarting the safekeeper in the
test to simulate a more realistic production troubleshooting scenario).
3. We can continue to write as normal after the issue resolves.
4. There is no data corruption throughout the test.
"""
# Set up environment with a very small disk usage limit (1KB)
neon_env_builder.num_safekeepers = 1
remote_storage_kind = s3_storage()
neon_env_builder.enable_safekeeper_remote_storage(remote_storage_kind)
# Set a very small disk usage limit (1KB)
neon_env_builder.safekeeper_extra_opts = ["--max-timeline-disk-usage-bytes=1024"]
env = neon_env_builder.init_start()
# Create a timeline and endpoint
env.create_branch("test_timeline_disk_usage_limit")
endpoint = env.endpoints.create_start("test_timeline_disk_usage_limit")
# Get the safekeeper
sk = env.safekeepers[0]
# Inject a failpoint to stop WAL backup
with sk.http_client() as http_cli:
http_cli.configure_failpoints([("backup-lsn-range-pausable", "pause")])
# Write some data that will exceed the 1KB limit. While the failpoint is active, this operation
# will hang as Postgres encounters safekeeper-returned errors and retries.
def run_hanging_insert():
with closing(endpoint.connect()) as bg_conn:
with bg_conn.cursor() as bg_cur:
# This should generate more than 1KB of WAL
bg_cur.execute("create table t(key int, value text)")
bg_cur.execute("insert into t select generate_series(1,2000), 'payload'")
# Start the inserts in a background thread
bg_thread = threading.Thread(target=run_hanging_insert)
bg_thread.start()
# Wait for the error message to appear in the compute log
def error_logged():
return endpoint.log_contains("WAL storage utilization exceeds configured limit") is not None
wait_until(error_logged)
log.info("Found expected error message in compute log, resuming.")
# Sanity check that the hanging insert is indeed still hanging. Otherwise means the circuit breaker we
# implemented didn't work as expected.
time.sleep(2)
assert bg_thread.is_alive(), (
"The hanging insert somehow unblocked without resolving the disk usage issue!"
)
log.info("Restarting the safekeeper to resume WAL backup.")
# Restart the safekeeper with defaults to both clear the failpoint and resume the larger disk usage limit.
for sk in env.safekeepers:
sk.stop().start(extra_opts=[])
# The hanging insert will now complete. Join the background thread so that we can
# verify that the insert completed successfully.
bg_thread.join(timeout=120)
assert not bg_thread.is_alive(), "Hanging insert did not complete after safekeeper restart"
log.info("Hanging insert unblocked.")
# Verify we can continue to write as normal
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("insert into t select generate_series(2001,3000), 'payload'")
# Sanity check data correctness
with closing(endpoint.connect()) as conn:
with conn.cursor() as cur:
cur.execute("select count(*) from t")
# 2000 rows from first insert + 1000 from last insert
assert cur.fetchone() == (3000,)