Compare commits

..

23 Commits

Author SHA1 Message Date
Joonas Koivunen
78ca8679e8 doc: comment why this is not so easy 2024-03-25 15:32:47 +00:00
Joonas Koivunen
fc82512dde fix: stop holding layermap read lock while downloading
this is a cleanup remnant from `struct Layer` work.
2024-03-25 15:30:00 +00:00
John Spray
2713142308 tests: stabilize compat tests (#7227)
This test had two flaky failure modes:
- pageserver log error for timeline not found: this resulted from
changes for DR when timeline destroy/create was added, but endpoint was
left running during that operation.
- storage controller log error because the test was running for long
enough that a background reconcile happened at almost the exact moment
of test teardown, and our test fixtures tear down the pageservers before
the controller.

Closes: #7224
2024-03-25 14:35:24 +00:00
Arseny Sher
a6c1fdcaf6 Try to fix test_crafted_wal_end flakiness.
Postgres can always write some more WAL, so previous checks that WAL doesn't
change after something had been crafted were wrong; remove them. Add comments
here and there.

should fix https://github.com/neondatabase/neon/issues/4691
2024-03-25 14:53:06 +03:00
John Spray
adb0526262 pageserver: track total ephemeral layer bytes (#7182)
## Problem

Large quantities of ephemeral layer data can lead to excessive memory
consumption (https://github.com/neondatabase/neon/issues/6939). We
currently don't have a way to know how much ephemeral layer data is
present on a pageserver.

Before we can add new behaviors to proactively roll layers in response
to too much ephemeral data, we must calculate that total.

Related: https://github.com/neondatabase/neon/issues/6916

## Summary of changes

- Create GlobalResources and GlobalResourceUnits types, where timelines
carry a GlobalResourceUnits in their TimelineWriterState.
- Periodically update the size in GlobalResourceUnits:
  - During tick()
  - During layer roll
- During put() if the latest value has drifted more than 10MB since our
last update
- Expose the value of the global ephemeral layer bytes counter as a
prometheus metric.
- Extend the lifetime of TimelineWriterState:
  - Instead of dropping it in TimelineWriter::drop, let it remain.
- Drop TimelineWriterState in roll_layer: this drops our guard on the
global byte count to reflect the fact that we're freezing the layer.
- Ensure the validity of the later in the writer state by clearing the
state in the same place we freeze layers, and asserting on the
write-ability of the layer in `writer()`
- Add a 'context' parameter to `get_open_layer_action` so that it can
skip the prev_lsn==lsn check when called in tick() -- this is needed
because now tick is called with a populated state, where
prev_lsn==Some(lsn) is true for an idle timeline.
- Extend layer rolling test to use this metric
2024-03-25 11:52:50 +00:00
John Spray
0099dfa56b storage controller: tighten up secrets handling (#7105)
- Remove code for using AWS secrets manager, as we're deploying with
k8s->env vars instead
- Load each secret independently, so that one can mix CLI args with
environment variables, rather than requiring that all secrets are loaded
with the same mechanism.
- Add a 'strict mode', enabled by default, which will refuse to start if
secrets are not loaded. This avoids the risk of accidentially disabling
auth by omitting the public key, for example
2024-03-25 11:52:33 +00:00
Vlad Lazar
3a4ebfb95d test: fix test_pageserver_recovery flakyness (#7207)
## Problem
We recently introduced log file validation for the storage controller.
The heartbeater will WARN when it fails
for a node, hence the test fails.

Closes https://github.com/neondatabase/neon/issues/7159

## Summary of changes
* Warn only once for each set of heartbeat retries
* Allow list heartbeat warns
2024-03-25 09:38:12 +00:00
Christian Schwarz
3220f830b7 pageserver: use a single tokio runtime (#6555)
Before this PR, each core had 3 executor threads from 3 different
runtimes. With this PR, we just have one runtime, with one thread per
core. Switching to a single tokio runtime should reduce that effective
over-commit of CPU and in theory help with tail latencies -- iff all
tokio tasks are well-behaved and yield to the runtime regularly.

Are All Tasks Well-Behaved? Are We Ready?
-----------------------------------------

Sadly there doesn't seem to be good out-of-the box tokio tooling to
answer this question.

We *believe* all tasks are well behaved in today's code base, as of the
switch to `virtual_file_io_engine = "tokio-epoll-uring"` in production
(https://github.com/neondatabase/aws/pull/1121).

The only remaining executor-thread-blocking code is walredo and some
filesystem namespace operations.

Filesystem namespace operations work is being tracked in #6663 and not
considered likely to actually block at this time.

Regarding walredo, it currently does a blocking `poll` for read/write to
the pipe file descriptors we use for IPC with the walredo process.
There is an ongoing experiment to make walredo async (#6628), but it
needs more time because there are surprisingly tricky trade-offs that
are articulated in that PR's description (which itself is still WIP).
What's relevant for *this* PR is that
1. walredo is always CPU-bound
2. production tail latencies for walredo request-response
(`pageserver_wal_redo_seconds_bucket`) are
  - p90: with few exceptions, low hundreds of micro-seconds
  - p95: except on very packed pageservers, below 1ms
  - p99: all below 50ms, vast majority below 1ms
  - p99.9: almost all around 50ms, rarely at >= 70ms
- [Dashboard
Link](https://neonprod.grafana.net/d/edgggcrmki3uof/2024-03-walredo-latency?orgId=1&var-ds=ZNX49CDVz&var-pXX_by_instance=0.9&var-pXX_by_instance=0.99&var-pXX_by_instance=0.95&var-adhoc=instance%7C%21%3D%7Cpageserver-30.us-west-2.aws.neon.tech&var-per_instance_pXX_max_seconds=0.0005&from=1711049688777&to=1711136088777)

The ones below 1ms are below our current threshold for when we start
thinking about yielding to the executor.
The tens of milliseconds stalls aren't great, but, not least because of
the implicit overcommit of CPU by the three runtimes, we can't be sure
whether these tens of milliseconds are inherently necessary to do the
walredo work or whether we could be faster if there was less contention
for CPU.

On the first item (walredo being always CPU-bound work): it means that
walredo processes will always compete with the executor threads.
We could yield, using async walredo, but then we hit the trade-offs
explained in that PR.

tl;dr: the risk of stalling executor threads through blocking walredo
seems low, and switching to one runtime cleans up one potential source
for higher-than-necessary stall times (explained in the previous
paragraphs).


Code Changes
------------

- Remove the 3 different runtime definitions.
- Add a new definition called `THE_RUNTIME`.
- Use it in all places that previously used one of the 3 removed
runtimes.
- Remove the argument from `task_mgr`.
- Fix failpoint usage where `pausable_failpoint!` should have been used.
We encountered some actual failures because of this, e.g., hung
`get_metric()` calls during test teardown that would client-timeout
after 300s.

As indicated by the comment above `THE_RUNTIME`, we could take this
clean-up further.
But before we create so much churn, let's first validate that there's no
perf regression.


Performance
-----------

We will test this in staging using the various nightly benchmark runs.

However, the worst-case impact of this change is likely compaction
(=>image layer creation) competing with compute requests.
Image layer creation work can't be easily generated & repeated quickly
by pagebench.
So, we'll simply watch getpage & basebackup tail latencies in staging.

Additionally, I have done manual benchmarking using pagebench.
Report:
https://neondatabase.notion.site/2024-03-23-oneruntime-change-benchmarking-22a399c411e24399a73311115fb703ec?pvs=4
Tail latencies and throughput are marginally better (no regression =
good).
Except in a workload with 128 clients against one tenant.
There, the p99.9 and p99.99 getpage latency is about 2x worse (at
slightly lower throughput).
A dip in throughput every 20s (compaction_period_ is clearly visible,
and probably responsible for that worse tail latency.
This has potential to improve with async walredo, and is an edge case
workload anyway.


Future Work
-----------

1. Once this change has shown satisfying results in production, change
the codebase to use the ambient runtime instead of explicitly
referencing `THE_RUNTIME`.
2. Have a mode where we run with a single-threaded runtime, so we
uncover executor stalls more quickly.
3. Switch or write our own failpoints library that is async-native:
https://github.com/neondatabase/neon/issues/7216
2024-03-23 19:25:11 +01:00
Conrad Ludgate
72103d481d proxy: fix stack overflow in cancel publisher (#7212)
## Problem

stack overflow in blanket impl for `CancellationPublisher`

## Summary of changes

Removes `async_trait` and fixes the impl order to make it non-recursive.
2024-03-23 06:36:58 +00:00
Alex Chi Z
643683f41a fixup(#7204 / postgres): revert IsPrimaryAlive checks (#7209)
Fix #7204.

https://github.com/neondatabase/postgres/pull/400
https://github.com/neondatabase/postgres/pull/401
https://github.com/neondatabase/postgres/pull/402

These commits never go into prod. Detailed investigation will be posted
in another issue. Reverting the commits so that things can keep running
in prod. This pull request adds the test to start two replicas. It fails
on the current main https://github.com/neondatabase/neon/pull/7210 but
passes in this pull request.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2024-03-23 01:01:51 +00:00
Konstantin Knizhnik
35f4c04c9b Remove Get/SetZenithCurrentClusterSize from Postgres core (#7196)
## Problem

See https://neondb.slack.com/archives/C04DGM6SMTM/p1711003752072899

## Summary of changes

Move keeping of cluster size to neon extension

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2024-03-22 13:14:31 -04:00
John Spray
1787cf19e3 pageserver: write consumption metrics to S3 (#7200)
## Problem

The service that receives consumption metrics has lower availability
than S3. Writing metrics to S3 improves their availability.

Closes: https://github.com/neondatabase/cloud/issues/9824

## Summary of changes

- The same data as consumption metrics POST bodies is also compressed
and written to an S3 object with a timestamp-formatted path.
- Set `metric_collection_bucket` (same format as `remote_storage`
config) to configure the location to write to
2024-03-22 14:52:14 +00:00
Alexander Bayandin
2668a1dfab CI: deploy release version to a preprod region (#6811)
## Problem

We want to deploy releases to a preprod region first to perform required
checks

## Summary of changes
- Deploy `release-XXX` / `release-proxy-YYY` docker tags to a preprod region
2024-03-22 14:42:10 +00:00
Conrad Ludgate
77f3a30440 proxy: unit tests for auth_quirks (#7199)
## Problem

I noticed code coverage for auth_quirks was pretty bare

## Summary of changes

Adds 3 happy path unit tests for auth_quirks
* scram
* cleartext (websockets)
* cleartext (password hack)
2024-03-22 13:31:10 +00:00
John Spray
62b318c928 Fix ephemeral file warning on secondaries (#7201)
A test was added which exercises secondary locations more, and there was
a location in the secondary downloader that warned on ephemeral files.

This was intended to be fixed in this faulty commit:
8cea866adf
2024-03-22 10:10:28 +00:00
Anna Khanova
6770ddba2e proxy: connect redis with AWS IAM (#7189)
## Problem

Support of IAM Roles for Service Accounts for authentication.

## Summary of changes

* Obtain aws 15m-long credentials
* Retrieve redis password from credentials
* Update every 1h to keep connection for more than 12h
* For now allow to have different endpoints for pubsub/stream redis.

TODOs: 
* PubSub doesn't support credentials refresh, consider using stream
instead.
* We need an AWS role for proxy to be able to connect to both: S3 and
elasticache.

Credentials obtaining and connection refresh was tested on xenon
preview.

https://github.com/neondatabase/cloud/issues/10365
2024-03-22 09:38:04 +01:00
Arpad Müller
3ee34a3f26 Update Rust to 1.77.0 (#7198)
Release notes: https://blog.rust-lang.org/2024/03/21/Rust-1.77.0.html

Thanks to #6886 the diff is reasonable, only for one new lint
`clippy::suspicious_open_options`. I added `truncate()` calls to the
places where it is obviously the right choice to me, and added allows
everywhere else, leaving it for followups.

I had to specify cargo install --locked because the build would fail otherwise.
This was also recommended by upstream.
2024-03-22 06:52:31 +00:00
Christian Schwarz
fb60278e02 walredo benchmark: throughput-oriented rewrite (#7190)
See the updated `bench_walredo.rs` module comment.

tl;dr: we measure avg latency of single redo operations issues against a
single redo manager from N tokio tasks.

part of https://github.com/neondatabase/neon/issues/6628
2024-03-21 15:24:56 +01:00
Conrad Ludgate
d5304337cf proxy: simplify password validation (#7188)
## Problem

for HTTP/WS/password hack flows we imitate SCRAM to validate passwords.
This code was unnecessarily complicated.

## Summary of changes

Copy in the `pbkdf2` and 'derive keys' steps from the
`postgres_protocol` crate in our `rust-postgres` fork. Derive the
`client_key`, `server_key` and `stored_key` from the password directly.
Use constant time equality to compare the `stored_key` and `server_key`
with the ones we are sent from cplane.
2024-03-21 13:54:06 +00:00
John Spray
06cb582d91 pageserver: extend /re-attach response to include tenant mode (#6941)
This change improves the resilience of the system to unclean restarts.

Previously, re-attach responses only included attached tenants
- If the pageserver had local state for a secondary location, it would
remain, but with no guarantee that it was still _meant_ to be there.
After this change, the pageserver will only retain secondary locations
if the /re-attach response indicates that they should still be there.
- If the pageserver had local state for an attached location that was
omitted from a re-attach response, it would be entirely detached. This
is wasteful in a typical HA setup, where an offline node's tenants might
have been re-attached elsewhere before it restarts, but the offline
node's location should revert to a secondary location rather than being
wiped. Including secondary tenants in the re-attach response enables the
pageserver to avoid throwing away local state unnecessarily.

In this PR:
- The re-attach items are extended with a 'mode' field.
- Storage controller populates 'mode'
- Pageserver interprets it (default is attached if missing) to construct
either a SecondaryTenant or a Tenant.
- A new test exercises both cases.
2024-03-21 13:39:23 +00:00
John Spray
bb47d536fb pageserver: quieten log on shutdown-while-attaching (#7177)
## Problem

If a shutdown happens when a tenant is attaching, we were logging at
ERROR severity and with a backtrace. Yuck.

## Summary of changes

- Pass a flag into `make_broken` to enable quietening this non-scary
case.
2024-03-21 12:56:13 +00:00
John Spray
59cdee749e storage controller: fixes to secondary location handling (#7169)
Stacks on:
- https://github.com/neondatabase/neon/pull/7165

Fixes while working on background optimization of scheduling after a
split:
- When a tenant has secondary locations, we weren't detaching the parent
shards' secondary locations when doing a split
- When a reconciler detaches a location, it was feeding back a
locationconf with `Detached` mode in its `observed` object, whereas it
should omit that location. This could cause the background reconcile
task to keep kicking off no-op reconcilers forever (harmless but
annoying).
- During shard split, we were scheduling secondary locations for the
child shards, but no reconcile was run for these until the next time the
background reconcile task ran. Creating these ASAP is useful, because
they'll be used shortly after a shard split as the destination locations
for migrating the new shards to different nodes.
2024-03-21 12:06:57 +00:00
Vlad Lazar
c75b584430 storage_controller: add metrics (#7178)
## Problem
Storage controller had basically no metrics.

## Summary of changes
1. Migrate the existing metrics to use Conrad's
[`measured`](https://docs.rs/measured/0.0.14/measured/) crate.
2. Add metrics for incoming http requests
3. Add metrics for outgoing http requests to the pageserver
4. Add metrics for outgoing pass through requests to the pageserver
5. Add metrics for database queries

Note that the metrics response for the attachment service does not use
chunked encoding like the rest of the metrics endpoints. Conrad has
kindly extended the crate such that it can now be done. Let's leave it
for a follow-up since the payload shouldn't be that big at this point.

Fixes https://github.com/neondatabase/neon/issues/6875
2024-03-21 12:00:20 +00:00
111 changed files with 3529 additions and 1336 deletions

View File

@@ -1121,10 +1121,16 @@ jobs:
run: |
if [[ "$GITHUB_REF_NAME" == "main" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false
# TODO: move deployPreprodRegion to release (`"$GITHUB_REF_NAME" == "release"` block), once Staging support different compute tag prefixes for different regions
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=true
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \
-f deployPgSniRouter=false \
-f deployProxy=false \
-f deployStorage=true \
-f deployStorageBroker=true \
-f branch=main \
-f dockerTag=${{needs.tag.outputs.build-tag}} \
-f deployPreprodRegion=true
gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main \
-f deployPgSniRouter=false \
-f deployProxy=false \
@@ -1133,6 +1139,15 @@ jobs:
-f branch=main \
-f dockerTag=${{needs.tag.outputs.build-tag}}
elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \
-f deployPgSniRouter=true \
-f deployProxy=true \
-f deployStorage=false \
-f deployStorageBroker=false \
-f branch=main \
-f dockerTag=${{needs.tag.outputs.build-tag}} \
-f deployPreprodRegion=true
gh workflow --repo neondatabase/aws run deploy-proxy-prod.yml --ref main \
-f deployPgSniRouter=true \
-f deployProxy=true \

176
Cargo.lock generated
View File

@@ -276,7 +276,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"aws-config",
"aws-sdk-secretsmanager",
"bytes",
"camino",
"clap",
"control_plane",
@@ -288,6 +288,8 @@ dependencies = [
"hex",
"humantime",
"hyper",
"lasso",
"measured",
"metrics",
"once_cell",
"pageserver_api",
@@ -295,6 +297,7 @@ dependencies = [
"postgres_connection",
"r2d2",
"reqwest",
"routerify",
"serde",
"serde_json",
"thiserror",
@@ -343,9 +346,9 @@ dependencies = [
[[package]]
name = "aws-credential-types"
version = "1.1.4"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33cc49dcdd31c8b6e79850a179af4c367669150c7ac0135f176c61bec81a70f7"
checksum = "fa8587ae17c8e967e4b05a62d495be2fb7701bec52a97f7acfe8a29f938384c8"
dependencies = [
"aws-smithy-async",
"aws-smithy-runtime-api",
@@ -355,9 +358,9 @@ dependencies = [
[[package]]
name = "aws-runtime"
version = "1.1.4"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb031bff99877c26c28895766f7bb8484a05e24547e370768d6cc9db514662aa"
checksum = "b13dc54b4b49f8288532334bba8f87386a40571c47c37b1304979b556dc613c8"
dependencies = [
"aws-credential-types",
"aws-sigv4",
@@ -377,6 +380,29 @@ dependencies = [
"uuid",
]
[[package]]
name = "aws-sdk-iam"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8ae76026bfb1b80a6aed0bb400c1139cd9c0563e26bce1986cd021c6a968c7b"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-query",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-smithy-xml",
"aws-types",
"http 0.2.9",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-s3"
version = "1.14.0"
@@ -406,29 +432,6 @@ dependencies = [
"url",
]
[[package]]
name = "aws-sdk-secretsmanager"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a0b64e61e7d632d9df90a2e0f32630c68c24960cab1d27d848718180af883d3"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand 2.0.0",
"http 0.2.9",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-sso"
version = "1.12.0"
@@ -498,9 +501,9 @@ dependencies = [
[[package]]
name = "aws-sigv4"
version = "1.1.4"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c371c6b0ac54d4605eb6f016624fb5c7c2925d315fdf600ac1bf21b19d5f1742"
checksum = "11d6f29688a4be9895c0ba8bef861ad0c0dac5c15e9618b9b7a6c233990fc263"
dependencies = [
"aws-credential-types",
"aws-smithy-eventstream",
@@ -513,7 +516,7 @@ dependencies = [
"hex",
"hmac",
"http 0.2.9",
"http 1.0.0",
"http 1.1.0",
"once_cell",
"p256",
"percent-encoding",
@@ -527,9 +530,9 @@ dependencies = [
[[package]]
name = "aws-smithy-async"
version = "1.1.4"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ee2d09cce0ef3ae526679b522835d63e75fb427aca5413cd371e490d52dcc6"
checksum = "d26ea8fa03025b2face2b3038a63525a10891e3d8829901d502e5384a0d8cd46"
dependencies = [
"futures-util",
"pin-project-lite",
@@ -570,9 +573,9 @@ dependencies = [
[[package]]
name = "aws-smithy-http"
version = "0.60.4"
version = "0.60.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dab56aea3cd9e1101a0a999447fb346afb680ab1406cebc44b32346e25b4117d"
checksum = "3f10fa66956f01540051b0aa7ad54574640f748f9839e843442d99b970d3aff9"
dependencies = [
"aws-smithy-eventstream",
"aws-smithy-runtime-api",
@@ -591,18 +594,18 @@ dependencies = [
[[package]]
name = "aws-smithy-json"
version = "0.60.4"
version = "0.60.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3898ca6518f9215f62678870064398f00031912390efd03f1f6ef56d83aa8e"
checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6"
dependencies = [
"aws-smithy-types",
]
[[package]]
name = "aws-smithy-query"
version = "0.60.4"
version = "0.60.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda4b1dfc9810e35fba8a620e900522cd1bd4f9578c446e82f49d1ce41d2e9f9"
checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb"
dependencies = [
"aws-smithy-types",
"urlencoding",
@@ -610,9 +613,9 @@ dependencies = [
[[package]]
name = "aws-smithy-runtime"
version = "1.1.4"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fafdab38f40ad7816e7da5dec279400dd505160780083759f01441af1bbb10ea"
checksum = "ec81002d883e5a7fd2bb063d6fb51c4999eb55d404f4fff3dd878bf4733b9f01"
dependencies = [
"aws-smithy-async",
"aws-smithy-http",
@@ -635,14 +638,15 @@ dependencies = [
[[package]]
name = "aws-smithy-runtime-api"
version = "1.1.4"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c18276dd28852f34b3bf501f4f3719781f4999a51c7bff1a5c6dc8c4529adc29"
checksum = "9acb931e0adaf5132de878f1398d83f8677f90ba70f01f65ff87f6d7244be1c5"
dependencies = [
"aws-smithy-async",
"aws-smithy-types",
"bytes",
"http 0.2.9",
"http 1.1.0",
"pin-project-lite",
"tokio",
"tracing",
@@ -651,9 +655,9 @@ dependencies = [
[[package]]
name = "aws-smithy-types"
version = "1.1.4"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb3e134004170d3303718baa2a4eb4ca64ee0a1c0a7041dca31b38be0fb414f3"
checksum = "abe14dceea1e70101d38fbf2a99e6a34159477c0fb95e68e05c66bd7ae4c3729"
dependencies = [
"base64-simd",
"bytes",
@@ -674,18 +678,18 @@ dependencies = [
[[package]]
name = "aws-smithy-xml"
version = "0.60.4"
version = "0.60.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8604a11b25e9ecaf32f9aa56b9fe253c5e2f606a3477f0071e96d3155a5ed218"
checksum = "872c68cf019c0e4afc5de7753c4f7288ce4b71663212771bf5e4542eb9346ca9"
dependencies = [
"xmlparser",
]
[[package]]
name = "aws-types"
version = "1.1.4"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "789bbe008e65636fe1b6dbbb374c40c8960d1232b96af5ff4aec349f9c4accf4"
checksum = "0dbf2f3da841a8930f159163175cf6a3d16ddde517c1b0fba7aa776822800f40"
dependencies = [
"aws-credential-types",
"aws-smithy-async",
@@ -2392,9 +2396,9 @@ dependencies = [
[[package]]
name = "http"
version = "1.0.0"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea"
checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258"
dependencies = [
"bytes",
"fnv",
@@ -2494,7 +2498,7 @@ dependencies = [
"hyper",
"log",
"rustls 0.21.9",
"rustls-native-certs",
"rustls-native-certs 0.6.2",
"tokio",
"tokio-rustls 0.24.0",
]
@@ -2880,6 +2884,35 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "measured"
version = "0.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f246648d027839a34b420e27c7de1165ace96e19ef894985d0a6ff89a7840a9f"
dependencies = [
"bytes",
"hashbrown 0.14.0",
"itoa",
"lasso",
"measured-derive",
"memchr",
"parking_lot 0.12.1",
"rustc-hash",
"ryu",
]
[[package]]
name = "measured-derive"
version = "0.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edaa5cc22d99d5d6d7d99c3b5b5f7e7f8034c22f1b5d62a1adecd2ed005d9b80"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "memchr"
version = "2.6.4"
@@ -4166,6 +4199,10 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"aws-config",
"aws-sdk-iam",
"aws-sigv4",
"aws-types",
"base64 0.13.1",
"bstr",
"bytes",
@@ -4176,6 +4213,7 @@ dependencies = [
"consumption_metrics",
"dashmap",
"env_logger",
"fallible-iterator",
"futures",
"git-version",
"hashbrown 0.13.2",
@@ -4183,6 +4221,7 @@ dependencies = [
"hex",
"hmac",
"hostname",
"http 1.1.0",
"humantime",
"hyper",
"hyper-tungstenite",
@@ -4226,6 +4265,7 @@ dependencies = [
"smallvec",
"smol_str",
"socket2 0.5.5",
"subtle",
"sync_wrapper",
"task-local-extensions",
"thiserror",
@@ -4397,9 +4437,9 @@ dependencies = [
[[package]]
name = "redis"
version = "0.24.0"
version = "0.25.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd"
checksum = "71d64e978fd98a0e6b105d066ba4889a7301fca65aeac850a877d8797343feeb"
dependencies = [
"async-trait",
"bytes",
@@ -4408,15 +4448,15 @@ dependencies = [
"itoa",
"percent-encoding",
"pin-project-lite",
"rustls 0.21.9",
"rustls-native-certs",
"rustls-pemfile 1.0.2",
"rustls-webpki 0.101.7",
"rustls 0.22.2",
"rustls-native-certs 0.7.0",
"rustls-pemfile 2.1.1",
"rustls-pki-types",
"ryu",
"sha1_smol",
"socket2 0.4.9",
"socket2 0.5.5",
"tokio",
"tokio-rustls 0.24.0",
"tokio-rustls 0.25.0",
"tokio-util",
"url",
]
@@ -4845,6 +4885,19 @@ dependencies = [
"security-framework",
]
[[package]]
name = "rustls-native-certs"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792"
dependencies = [
"openssl-probe",
"rustls-pemfile 2.1.1",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.2"
@@ -6112,7 +6165,7 @@ dependencies = [
"percent-encoding",
"pin-project",
"prost",
"rustls-native-certs",
"rustls-native-certs 0.6.2",
"rustls-pemfile 1.0.2",
"tokio",
"tokio-rustls 0.24.0",
@@ -6997,7 +7050,6 @@ dependencies = [
"aws-sigv4",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-runtime-api",
"aws-smithy-types",
"axum",
"base64 0.21.1",

View File

@@ -52,10 +52,12 @@ async-stream = "0.3"
async-trait = "0.1"
aws-config = { version = "1.1.4", default-features = false, features=["rustls"] }
aws-sdk-s3 = "1.14"
aws-sdk-secretsmanager = { version = "1.14.0" }
aws-sdk-iam = "1.15.0"
aws-smithy-async = { version = "1.1.4", default-features = false, features=["rt-tokio"] }
aws-smithy-types = "1.1.4"
aws-credential-types = "1.1.4"
aws-sigv4 = { version = "1.2.0", features = ["sign-http"] }
aws-types = "1.1.7"
axum = { version = "0.6.20", features = ["ws"] }
base64 = "0.13.0"
bincode = "1.3"
@@ -76,6 +78,7 @@ either = "1.8"
enum-map = "2.4.2"
enumset = "1.0.12"
fail = "0.5.0"
fallible-iterator = "0.2"
fs2 = "0.4.3"
futures = "0.3"
futures-core = "0.3"
@@ -88,6 +91,7 @@ hex = "0.4"
hex-literal = "0.4"
hmac = "0.12.1"
hostname = "0.3.1"
http = {version = "1.1.0", features = ["std"]}
http-types = { version = "2", default-features = false }
humantime = "2.1"
humantime-serde = "1.1.1"
@@ -101,6 +105,7 @@ lasso = "0.7"
leaky-bucket = "1.0.1"
libc = "0.2"
md5 = "0.7.0"
measured = { version = "0.0.13", features=["default", "lasso"] }
memoffset = "0.8"
native-tls = "0.2"
nix = { version = "0.27", features = ["fs", "process", "socket", "signal", "poll"] }
@@ -120,7 +125,7 @@ procfs = "0.14"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"
rand = "0.8"
redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] }
redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.4.7", features = ["opentelemetry_0_20"] }
@@ -148,6 +153,7 @@ smol_str = { version = "0.2.0", features = ["serde"] }
socket2 = "0.5"
strum = "0.24"
strum_macros = "0.24"
"subtle" = "2.5.0"
svg_fmt = "0.4.1"
sync_wrapper = "0.1.2"
tar = "0.4"

View File

@@ -135,7 +135,7 @@ WORKDIR /home/nonroot
# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
ENV RUSTC_VERSION=1.76.0
ENV RUSTC_VERSION=1.77.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \
@@ -149,7 +149,7 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
cargo install --git https://github.com/paritytech/cachepot && \
cargo install rustfilt && \
cargo install cargo-hakari && \
cargo install cargo-deny && \
cargo install cargo-deny --locked && \
cargo install cargo-hack && \
cargo install cargo-nextest && \
rm -rf /home/nonroot/.cargo/registry && \

View File

@@ -17,6 +17,7 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
.write(true)
.create(true)
.append(false)
.truncate(false)
.open(path)?;
let buf = io::BufReader::new(&file);
let mut count: usize = 0;

View File

@@ -16,7 +16,7 @@ testing = []
[dependencies]
anyhow.workspace = true
aws-config.workspace = true
aws-sdk-secretsmanager.workspace = true
bytes.workspace = true
camino.workspace = true
clap.workspace = true
fail.workspace = true
@@ -25,17 +25,20 @@ git-version.workspace = true
hex.workspace = true
hyper.workspace = true
humantime.workspace = true
lasso.workspace = true
once_cell.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_connection.workspace = true
reqwest.workspace = true
routerify.workspace = true
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-util.workspace = true
tracing.workspace = true
measured.workspace = true
diesel = { version = "2.1.4", features = ["serde_json", "postgres", "r2d2"] }
diesel_migrations = { version = "2.1.0" }

View File

@@ -139,7 +139,7 @@ impl HeartbeaterTask {
.with_client_retries(
|client| async move { client.get_utilization().await },
&jwt_token,
2,
3,
3,
Duration::from_secs(1),
&cancel,

View File

@@ -1,5 +1,11 @@
use crate::metrics::{
HttpRequestLatencyLabelGroup, HttpRequestStatusLabelGroup, PageserverRequestLabelGroup,
METRICS_REGISTRY,
};
use crate::reconciler::ReconcileError;
use crate::service::{Service, STARTUP_RECONCILE_TIMEOUT};
use futures::Future;
use hyper::header::CONTENT_TYPE;
use hyper::{Body, Request, Response};
use hyper::{StatusCode, Uri};
use pageserver_api::models::{
@@ -34,6 +40,8 @@ use pageserver_api::upcall_api::{ReAttachRequest, ValidateRequest};
use control_plane::storage_controller::{AttachHookRequest, InspectRequest};
use routerify::Middleware;
/// State available to HTTP request handlers
#[derive(Clone)]
pub struct HttpState {
@@ -313,7 +321,7 @@ async fn handle_tenant_timeline_passthrough(
tracing::info!("Proxying request for tenant {} ({})", tenant_id, path);
// Find the node that holds shard zero
let (base_url, tenant_shard_id) = service.tenant_shard0_baseurl(tenant_id)?;
let (node, tenant_shard_id) = service.tenant_shard0_node(tenant_id)?;
// Callers will always pass an unsharded tenant ID. Before proxying, we must
// rewrite this to a shard-aware shard zero ID.
@@ -322,12 +330,39 @@ async fn handle_tenant_timeline_passthrough(
let tenant_shard_str = format!("{}", tenant_shard_id);
let path = path.replace(&tenant_str, &tenant_shard_str);
let client = mgmt_api::Client::new(base_url, service.get_config().jwt_token.as_deref());
let latency = &METRICS_REGISTRY
.metrics_group
.storage_controller_passthrough_request_latency;
// This is a bit awkward. We remove the param from the request
// and join the words by '_' to get a label for the request.
let just_path = path.replace(&tenant_shard_str, "");
let path_label = just_path
.split('/')
.filter(|token| !token.is_empty())
.collect::<Vec<_>>()
.join("_");
let labels = PageserverRequestLabelGroup {
pageserver_id: &node.get_id().to_string(),
path: &path_label,
method: crate::metrics::Method::Get,
};
let _timer = latency.start_timer(labels.clone());
let client = mgmt_api::Client::new(node.base_url(), service.get_config().jwt_token.as_deref());
let resp = client.get_raw(path).await.map_err(|_e|
// FIXME: give APiError a proper Unavailable variant. We return 503 here because
// if we can't successfully send a request to the pageserver, we aren't available.
ApiError::ShuttingDown)?;
if !resp.status().is_success() {
let error_counter = &METRICS_REGISTRY
.metrics_group
.storage_controller_passthrough_request_error;
error_counter.inc(labels);
}
// We have a reqest::Response, would like a http::Response
let mut builder = hyper::Response::builder()
.status(resp.status())
@@ -498,7 +533,11 @@ impl From<ReconcileError> for ApiError {
/// Common wrapper for request handlers that call into Service and will operate on tenants: they must only
/// be allowed to run if Service has finished its initial reconciliation.
async fn tenant_service_handler<R, H>(request: Request<Body>, handler: H) -> R::Output
async fn tenant_service_handler<R, H>(
request: Request<Body>,
handler: H,
request_name: RequestName,
) -> R::Output
where
R: std::future::Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
H: FnOnce(Arc<Service>, Request<Body>) -> R + Send + Sync + 'static,
@@ -518,9 +557,10 @@ where
));
}
request_span(
named_request_span(
request,
|request| async move { handler(service, request).await },
request_name,
)
.await
}
@@ -531,11 +571,98 @@ fn check_permissions(request: &Request<Body>, required_scope: Scope) -> Result<(
})
}
#[derive(Clone, Debug)]
struct RequestMeta {
method: hyper::http::Method,
at: Instant,
}
fn prologue_metrics_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
) -> Middleware<B, ApiError> {
Middleware::pre(move |req| async move {
let meta = RequestMeta {
method: req.method().clone(),
at: Instant::now(),
};
req.set_context(meta);
Ok(req)
})
}
fn epilogue_metrics_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
) -> Middleware<B, ApiError> {
Middleware::post_with_info(move |resp, req_info| async move {
let request_name = match req_info.context::<RequestName>() {
Some(name) => name,
None => {
return Ok(resp);
}
};
if let Some(meta) = req_info.context::<RequestMeta>() {
let status = &crate::metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_http_request_status;
let latency = &crate::metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_http_request_latency;
status.inc(HttpRequestStatusLabelGroup {
path: request_name.0,
method: meta.method.clone().into(),
status: crate::metrics::StatusCode(resp.status()),
});
latency.observe(
HttpRequestLatencyLabelGroup {
path: request_name.0,
method: meta.method.into(),
},
meta.at.elapsed().as_secs_f64(),
);
}
Ok(resp)
})
}
pub async fn measured_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
pub const TEXT_FORMAT: &str = "text/plain; version=0.0.4";
let payload = crate::metrics::METRICS_REGISTRY.encode();
let response = Response::builder()
.status(200)
.header(CONTENT_TYPE, TEXT_FORMAT)
.body(payload.into())
.unwrap();
Ok(response)
}
#[derive(Clone)]
struct RequestName(&'static str);
async fn named_request_span<R, H>(
request: Request<Body>,
handler: H,
name: RequestName,
) -> R::Output
where
R: Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
H: FnOnce(Request<Body>) -> R + Send + Sync + 'static,
{
request.set_context(name);
request_span(request, handler).await
}
pub fn make_router(
service: Arc<Service>,
auth: Option<Arc<SwappableJwtAuth>>,
) -> RouterBuilder<hyper::Body, ApiError> {
let mut router = endpoint::make_router();
let mut router = endpoint::make_router()
.middleware(prologue_metrics_middleware())
.middleware(epilogue_metrics_middleware());
if auth.is_some() {
router = router.middleware(auth_middleware(|request| {
let state = get_state(request);
@@ -544,99 +671,166 @@ pub fn make_router(
} else {
state.auth.as_deref()
}
}))
}));
}
router
.data(Arc::new(HttpState::new(service, auth)))
.get("/metrics", |r| {
named_request_span(r, measured_metrics_handler, RequestName("metrics"))
})
// Non-prefixed generic endpoints (status, metrics)
.get("/status", |r| request_span(r, handle_status))
.get("/ready", |r| request_span(r, handle_ready))
.get("/status", |r| {
named_request_span(r, handle_status, RequestName("status"))
})
.get("/ready", |r| {
named_request_span(r, handle_ready, RequestName("ready"))
})
// Upcalls for the pageserver: point the pageserver's `control_plane_api` config to this prefix
.post("/upcall/v1/re-attach", |r| {
request_span(r, handle_re_attach)
named_request_span(r, handle_re_attach, RequestName("upcall_v1_reattach"))
})
.post("/upcall/v1/validate", |r| {
named_request_span(r, handle_validate, RequestName("upcall_v1_validate"))
})
.post("/upcall/v1/validate", |r| request_span(r, handle_validate))
// Test/dev/debug endpoints
.post("/debug/v1/attach-hook", |r| {
request_span(r, handle_attach_hook)
named_request_span(r, handle_attach_hook, RequestName("debug_v1_attach_hook"))
})
.post("/debug/v1/inspect", |r| {
named_request_span(r, handle_inspect, RequestName("debug_v1_inspect"))
})
.post("/debug/v1/inspect", |r| request_span(r, handle_inspect))
.post("/debug/v1/tenant/:tenant_id/drop", |r| {
request_span(r, handle_tenant_drop)
named_request_span(r, handle_tenant_drop, RequestName("debug_v1_tenant_drop"))
})
.post("/debug/v1/node/:node_id/drop", |r| {
request_span(r, handle_node_drop)
named_request_span(r, handle_node_drop, RequestName("debug_v1_node_drop"))
})
.get("/debug/v1/tenant", |r| {
named_request_span(r, handle_tenants_dump, RequestName("debug_v1_tenant"))
})
.get("/debug/v1/tenant", |r| request_span(r, handle_tenants_dump))
.get("/debug/v1/tenant/:tenant_id/locate", |r| {
tenant_service_handler(r, handle_tenant_locate)
tenant_service_handler(
r,
handle_tenant_locate,
RequestName("debug_v1_tenant_locate"),
)
})
.get("/debug/v1/scheduler", |r| {
request_span(r, handle_scheduler_dump)
named_request_span(r, handle_scheduler_dump, RequestName("debug_v1_scheduler"))
})
.post("/debug/v1/consistency_check", |r| {
request_span(r, handle_consistency_check)
named_request_span(
r,
handle_consistency_check,
RequestName("debug_v1_consistency_check"),
)
})
.put("/debug/v1/failpoints", |r| {
request_span(r, |r| failpoints_handler(r, CancellationToken::new()))
})
// Node operations
.post("/control/v1/node", |r| {
request_span(r, handle_node_register)
named_request_span(r, handle_node_register, RequestName("control_v1_node"))
})
.get("/control/v1/node", |r| {
named_request_span(r, handle_node_list, RequestName("control_v1_node"))
})
.get("/control/v1/node", |r| request_span(r, handle_node_list))
.put("/control/v1/node/:node_id/config", |r| {
request_span(r, handle_node_configure)
named_request_span(
r,
handle_node_configure,
RequestName("control_v1_node_config"),
)
})
// Tenant Shard operations
.put("/control/v1/tenant/:tenant_shard_id/migrate", |r| {
tenant_service_handler(r, handle_tenant_shard_migrate)
tenant_service_handler(
r,
handle_tenant_shard_migrate,
RequestName("control_v1_tenant_migrate"),
)
})
.put("/control/v1/tenant/:tenant_id/shard_split", |r| {
tenant_service_handler(r, handle_tenant_shard_split)
tenant_service_handler(
r,
handle_tenant_shard_split,
RequestName("control_v1_tenant_shard_split"),
)
})
.get("/control/v1/tenant/:tenant_id", |r| {
tenant_service_handler(r, handle_tenant_describe)
tenant_service_handler(
r,
handle_tenant_describe,
RequestName("control_v1_tenant_describe"),
)
})
// Tenant operations
// The ^/v1/ endpoints act as a "Virtual Pageserver", enabling shard-naive clients to call into
// this service to manage tenants that actually consist of many tenant shards, as if they are a single entity.
.post("/v1/tenant", |r| {
tenant_service_handler(r, handle_tenant_create)
tenant_service_handler(r, handle_tenant_create, RequestName("v1_tenant"))
})
.delete("/v1/tenant/:tenant_id", |r| {
tenant_service_handler(r, handle_tenant_delete)
tenant_service_handler(r, handle_tenant_delete, RequestName("v1_tenant"))
})
.put("/v1/tenant/config", |r| {
tenant_service_handler(r, handle_tenant_config_set)
tenant_service_handler(r, handle_tenant_config_set, RequestName("v1_tenant_config"))
})
.get("/v1/tenant/:tenant_id/config", |r| {
tenant_service_handler(r, handle_tenant_config_get)
tenant_service_handler(r, handle_tenant_config_get, RequestName("v1_tenant_config"))
})
.put("/v1/tenant/:tenant_shard_id/location_config", |r| {
tenant_service_handler(r, handle_tenant_location_config)
tenant_service_handler(
r,
handle_tenant_location_config,
RequestName("v1_tenant_location_config"),
)
})
.put("/v1/tenant/:tenant_id/time_travel_remote_storage", |r| {
tenant_service_handler(r, handle_tenant_time_travel_remote_storage)
tenant_service_handler(
r,
handle_tenant_time_travel_remote_storage,
RequestName("v1_tenant_time_travel_remote_storage"),
)
})
.post("/v1/tenant/:tenant_id/secondary/download", |r| {
tenant_service_handler(r, handle_tenant_secondary_download)
tenant_service_handler(
r,
handle_tenant_secondary_download,
RequestName("v1_tenant_secondary_download"),
)
})
// Timeline operations
.delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
tenant_service_handler(r, handle_tenant_timeline_delete)
tenant_service_handler(
r,
handle_tenant_timeline_delete,
RequestName("v1_tenant_timeline"),
)
})
.post("/v1/tenant/:tenant_id/timeline", |r| {
tenant_service_handler(r, handle_tenant_timeline_create)
tenant_service_handler(
r,
handle_tenant_timeline_create,
RequestName("v1_tenant_timeline"),
)
})
// Tenant detail GET passthrough to shard zero
.get("/v1/tenant/:tenant_id", |r| {
tenant_service_handler(r, handle_tenant_timeline_passthrough)
tenant_service_handler(
r,
handle_tenant_timeline_passthrough,
RequestName("v1_tenant_passthrough"),
)
})
// Timeline GET passthrough to shard zero. Note that the `*` in the URL is a wildcard: any future
// timeline GET APIs will be implicitly included.
.get("/v1/tenant/:tenant_id/timeline*", |r| {
tenant_service_handler(r, handle_tenant_timeline_passthrough)
tenant_service_handler(
r,
handle_tenant_timeline_passthrough,
RequestName("v1_tenant_timeline_passthrough"),
)
})
}

View File

@@ -8,6 +8,7 @@ pub mod http;
mod id_lock_map;
pub mod metrics;
mod node;
mod pageserver_client;
pub mod persistence;
mod reconciler;
mod scheduler;

View File

@@ -3,7 +3,6 @@ use attachment_service::http::make_router;
use attachment_service::metrics::preinitialize_metrics;
use attachment_service::persistence::Persistence;
use attachment_service::service::{Config, Service, MAX_UNAVAILABLE_INTERVAL_DEFAULT};
use aws_config::{BehaviorVersion, Region};
use camino::Utf8PathBuf;
use clap::Parser;
use diesel::Connection;
@@ -55,11 +54,31 @@ struct Cli {
#[arg(long)]
database_url: Option<String>,
/// Flag to enable dev mode, which permits running without auth
#[arg(long, default_value = "false")]
dev: bool,
/// Grace period before marking unresponsive pageserver offline
#[arg(long)]
max_unavailable_interval: Option<humantime::Duration>,
}
enum StrictMode {
/// In strict mode, we will require that all secrets are loaded, i.e. security features
/// may not be implicitly turned off by omitting secrets in the environment.
Strict,
/// In dev mode, secrets are optional, and omitting a particular secret will implicitly
/// disable the auth related to it (e.g. no pageserver jwt key -> send unauthenticated
/// requests, no public key -> don't authenticate incoming requests).
Dev,
}
impl Default for StrictMode {
fn default() -> Self {
Self::Strict
}
}
/// Secrets may either be provided on the command line (for testing), or loaded from AWS SecretManager: this
/// type encapsulates the logic to decide which and do the loading.
struct Secrets {
@@ -70,13 +89,6 @@ struct Secrets {
}
impl Secrets {
const DATABASE_URL_SECRET: &'static str = "rds-neon-storage-controller-url";
const PAGESERVER_JWT_TOKEN_SECRET: &'static str =
"neon-storage-controller-pageserver-jwt-token";
const CONTROL_PLANE_JWT_TOKEN_SECRET: &'static str =
"neon-storage-controller-control-plane-jwt-token";
const PUBLIC_KEY_SECRET: &'static str = "neon-storage-controller-public-key";
const DATABASE_URL_ENV: &'static str = "DATABASE_URL";
const PAGESERVER_JWT_TOKEN_ENV: &'static str = "PAGESERVER_JWT_TOKEN";
const CONTROL_PLANE_JWT_TOKEN_ENV: &'static str = "CONTROL_PLANE_JWT_TOKEN";
@@ -87,111 +99,41 @@ impl Secrets {
/// - Environment variables if DATABASE_URL is set.
/// - AWS Secrets Manager secrets
async fn load(args: &Cli) -> anyhow::Result<Self> {
match &args.database_url {
Some(url) => Self::load_cli(url, args),
None => match std::env::var(Self::DATABASE_URL_ENV) {
Ok(database_url) => Self::load_env(database_url),
Err(_) => Self::load_aws_sm().await,
},
}
}
fn load_env(database_url: String) -> anyhow::Result<Self> {
let public_key = match std::env::var(Self::PUBLIC_KEY_ENV) {
Ok(public_key) => Some(JwtAuth::from_key(public_key).context("Loading public key")?),
Err(_) => None,
};
Ok(Self {
database_url,
public_key,
jwt_token: std::env::var(Self::PAGESERVER_JWT_TOKEN_ENV).ok(),
control_plane_jwt_token: std::env::var(Self::CONTROL_PLANE_JWT_TOKEN_ENV).ok(),
})
}
async fn load_aws_sm() -> anyhow::Result<Self> {
let Ok(region) = std::env::var("AWS_REGION") else {
anyhow::bail!("AWS_REGION is not set, cannot load secrets automatically: either set this, or use CLI args to supply secrets");
};
let config = aws_config::defaults(BehaviorVersion::v2023_11_09())
.region(Region::new(region.clone()))
.load()
.await;
let asm = aws_sdk_secretsmanager::Client::new(&config);
let Some(database_url) = asm
.get_secret_value()
.secret_id(Self::DATABASE_URL_SECRET)
.send()
.await?
.secret_string()
.map(str::to_string)
let Some(database_url) =
Self::load_secret(&args.database_url, Self::DATABASE_URL_ENV).await
else {
anyhow::bail!(
"Database URL secret not found at {region}/{}",
Self::DATABASE_URL_SECRET
"Database URL is not set (set `--database-url`, or `DATABASE_URL` environment)"
)
};
let jwt_token = asm
.get_secret_value()
.secret_id(Self::PAGESERVER_JWT_TOKEN_SECRET)
.send()
.await?
.secret_string()
.map(str::to_string);
if jwt_token.is_none() {
tracing::warn!("No pageserver JWT token set: this will only work if authentication is disabled on the pageserver");
}
let control_plane_jwt_token = asm
.get_secret_value()
.secret_id(Self::CONTROL_PLANE_JWT_TOKEN_SECRET)
.send()
.await?
.secret_string()
.map(str::to_string);
if jwt_token.is_none() {
tracing::warn!("No control plane JWT token set: this will only work if authentication is disabled on the pageserver");
}
let public_key = asm
.get_secret_value()
.secret_id(Self::PUBLIC_KEY_SECRET)
.send()
.await?
.secret_string()
.map(str::to_string);
let public_key = match public_key {
Some(key) => Some(JwtAuth::from_key(key)?),
None => {
tracing::warn!(
"No public key set: inccoming HTTP requests will not be authenticated"
);
None
}
let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV).await {
Some(v) => Some(JwtAuth::from_key(v).context("Loading public key")?),
None => None,
};
Ok(Self {
let this = Self {
database_url,
public_key,
jwt_token,
control_plane_jwt_token,
})
jwt_token: Self::load_secret(&args.jwt_token, Self::PAGESERVER_JWT_TOKEN_ENV).await,
control_plane_jwt_token: Self::load_secret(
&args.control_plane_jwt_token,
Self::CONTROL_PLANE_JWT_TOKEN_ENV,
)
.await,
};
Ok(this)
}
fn load_cli(database_url: &str, args: &Cli) -> anyhow::Result<Self> {
let public_key = match &args.public_key {
None => None,
Some(key) => Some(JwtAuth::from_key(key.clone()).context("Loading public key")?),
};
Ok(Self {
database_url: database_url.to_owned(),
public_key,
jwt_token: args.jwt_token.clone(),
control_plane_jwt_token: args.control_plane_jwt_token.clone(),
})
async fn load_secret(cli: &Option<String>, env_name: &str) -> Option<String> {
if let Some(v) = cli {
Some(v.clone())
} else if let Ok(v) = std::env::var(env_name) {
Some(v)
} else {
None
}
}
}
@@ -247,8 +189,42 @@ async fn async_main() -> anyhow::Result<()> {
args.listen
);
let strict_mode = if args.dev {
StrictMode::Dev
} else {
StrictMode::Strict
};
let secrets = Secrets::load(&args).await?;
// Validate required secrets and arguments are provided in strict mode
match strict_mode {
StrictMode::Strict
if (secrets.public_key.is_none()
|| secrets.jwt_token.is_none()
|| secrets.control_plane_jwt_token.is_none()) =>
{
// Production systems should always have secrets configured: if public_key was not set
// then we would implicitly disable auth.
anyhow::bail!(
"Insecure config! One or more secrets is not set. This is only permitted in `--dev` mode"
);
}
StrictMode::Strict if args.compute_hook_url.is_none() => {
// Production systems should always have a compute hook set, to prevent falling
// back to trying to use neon_local.
anyhow::bail!(
"`--compute-hook-url` is not set: this is only permitted in `--dev` mode"
);
}
StrictMode::Strict => {
tracing::info!("Starting in strict mode: configuration is OK.")
}
StrictMode::Dev => {
tracing::warn!("Starting in dev mode: this may be an insecure configuration.")
}
}
let config = Config {
jwt_token: secrets.jwt_token,
control_plane_jwt_token: secrets.control_plane_jwt_token,

View File

@@ -1,32 +1,284 @@
use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
//!
//! This module provides metric definitions for the storage controller.
//!
//! All metrics are grouped in [`StorageControllerMetricGroup`]. [`StorageControllerMetrics`] holds
//! the mentioned metrics and their encoder. It's globally available via the [`METRICS_REGISTRY`]
//! constant.
//!
//! The rest of the code defines label group types and deals with converting outer types to labels.
//!
use bytes::Bytes;
use measured::{
label::{LabelValue, StaticLabelSet},
FixedCardinalityLabel, MetricGroup,
};
use once_cell::sync::Lazy;
use std::sync::Mutex;
pub(crate) struct ReconcilerMetrics {
pub(crate) spawned: IntCounter,
pub(crate) complete: IntCounterVec,
}
use crate::persistence::{DatabaseError, DatabaseOperation};
impl ReconcilerMetrics {
// Labels used on [`Self::complete`]
pub(crate) const SUCCESS: &'static str = "ok";
pub(crate) const ERROR: &'static str = "success";
pub(crate) const CANCEL: &'static str = "cancel";
}
pub(crate) static RECONCILER: Lazy<ReconcilerMetrics> = Lazy::new(|| ReconcilerMetrics {
spawned: register_int_counter!(
"storage_controller_reconcile_spawn",
"Count of how many times we spawn a reconcile task",
)
.expect("failed to define a metric"),
complete: register_int_counter_vec!(
"storage_controller_reconcile_complete",
"Reconciler tasks completed, broken down by success/failure/cancelled",
&["status"],
)
.expect("failed to define a metric"),
});
pub(crate) static METRICS_REGISTRY: Lazy<StorageControllerMetrics> =
Lazy::new(StorageControllerMetrics::default);
pub fn preinitialize_metrics() {
Lazy::force(&RECONCILER);
Lazy::force(&METRICS_REGISTRY);
}
pub(crate) struct StorageControllerMetrics {
pub(crate) metrics_group: StorageControllerMetricGroup,
encoder: Mutex<measured::text::TextEncoder>,
}
#[derive(measured::MetricGroup)]
pub(crate) struct StorageControllerMetricGroup {
/// Count of how many times we spawn a reconcile task
pub(crate) storage_controller_reconcile_spawn: measured::Counter,
/// Reconciler tasks completed, broken down by success/failure/cancelled
pub(crate) storage_controller_reconcile_complete:
measured::CounterVec<ReconcileCompleteLabelGroupSet>,
/// HTTP request status counters for handled requests
pub(crate) storage_controller_http_request_status:
measured::CounterVec<HttpRequestStatusLabelGroupSet>,
/// HTTP request handler latency across all status codes
pub(crate) storage_controller_http_request_latency:
measured::HistogramVec<HttpRequestLatencyLabelGroupSet, 5>,
/// Count of HTTP requests to the pageserver that resulted in an error,
/// broken down by the pageserver node id, request name and method
pub(crate) storage_controller_pageserver_request_error:
measured::CounterVec<PageserverRequestLabelGroupSet>,
/// Latency of HTTP requests to the pageserver, broken down by pageserver
/// node id, request name and method. This include both successful and unsuccessful
/// requests.
pub(crate) storage_controller_pageserver_request_latency:
measured::HistogramVec<PageserverRequestLabelGroupSet, 5>,
/// Count of pass-through HTTP requests to the pageserver that resulted in an error,
/// broken down by the pageserver node id, request name and method
pub(crate) storage_controller_passthrough_request_error:
measured::CounterVec<PageserverRequestLabelGroupSet>,
/// Latency of pass-through HTTP requests to the pageserver, broken down by pageserver
/// node id, request name and method. This include both successful and unsuccessful
/// requests.
pub(crate) storage_controller_passthrough_request_latency:
measured::HistogramVec<PageserverRequestLabelGroupSet, 5>,
/// Count of errors in database queries, broken down by error type and operation.
pub(crate) storage_controller_database_query_error:
measured::CounterVec<DatabaseQueryErrorLabelGroupSet>,
/// Latency of database queries, broken down by operation.
pub(crate) storage_controller_database_query_latency:
measured::HistogramVec<DatabaseQueryLatencyLabelGroupSet, 5>,
}
impl StorageControllerMetrics {
pub(crate) fn encode(&self) -> Bytes {
let mut encoder = self.encoder.lock().unwrap();
self.metrics_group.collect_into(&mut *encoder);
encoder.finish()
}
}
impl Default for StorageControllerMetrics {
fn default() -> Self {
Self {
metrics_group: StorageControllerMetricGroup::new(),
encoder: Mutex::new(measured::text::TextEncoder::new()),
}
}
}
impl StorageControllerMetricGroup {
pub(crate) fn new() -> Self {
Self {
storage_controller_reconcile_spawn: measured::Counter::new(),
storage_controller_reconcile_complete: measured::CounterVec::new(
ReconcileCompleteLabelGroupSet {
status: StaticLabelSet::new(),
},
),
storage_controller_http_request_status: measured::CounterVec::new(
HttpRequestStatusLabelGroupSet {
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
status: StaticLabelSet::new(),
},
),
storage_controller_http_request_latency: measured::HistogramVec::new(
measured::metric::histogram::Thresholds::exponential_buckets(0.1, 2.0),
),
storage_controller_pageserver_request_error: measured::CounterVec::new(
PageserverRequestLabelGroupSet {
pageserver_id: lasso::ThreadedRodeo::new(),
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
},
),
storage_controller_pageserver_request_latency: measured::HistogramVec::new(
measured::metric::histogram::Thresholds::exponential_buckets(0.1, 2.0),
),
storage_controller_passthrough_request_error: measured::CounterVec::new(
PageserverRequestLabelGroupSet {
pageserver_id: lasso::ThreadedRodeo::new(),
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
},
),
storage_controller_passthrough_request_latency: measured::HistogramVec::new(
measured::metric::histogram::Thresholds::exponential_buckets(0.1, 2.0),
),
storage_controller_database_query_error: measured::CounterVec::new(
DatabaseQueryErrorLabelGroupSet {
operation: StaticLabelSet::new(),
error_type: StaticLabelSet::new(),
},
),
storage_controller_database_query_latency: measured::HistogramVec::new(
measured::metric::histogram::Thresholds::exponential_buckets(0.1, 2.0),
),
}
}
}
#[derive(measured::LabelGroup)]
#[label(set = ReconcileCompleteLabelGroupSet)]
pub(crate) struct ReconcileCompleteLabelGroup {
pub(crate) status: ReconcileOutcome,
}
#[derive(measured::LabelGroup)]
#[label(set = HttpRequestStatusLabelGroupSet)]
pub(crate) struct HttpRequestStatusLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo)]
pub(crate) path: &'a str,
pub(crate) method: Method,
pub(crate) status: StatusCode,
}
#[derive(measured::LabelGroup)]
#[label(set = HttpRequestLatencyLabelGroupSet)]
pub(crate) struct HttpRequestLatencyLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo)]
pub(crate) path: &'a str,
pub(crate) method: Method,
}
impl Default for HttpRequestLatencyLabelGroupSet {
fn default() -> Self {
Self {
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
}
}
}
#[derive(measured::LabelGroup, Clone)]
#[label(set = PageserverRequestLabelGroupSet)]
pub(crate) struct PageserverRequestLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo)]
pub(crate) pageserver_id: &'a str,
#[label(dynamic_with = lasso::ThreadedRodeo)]
pub(crate) path: &'a str,
pub(crate) method: Method,
}
impl Default for PageserverRequestLabelGroupSet {
fn default() -> Self {
Self {
pageserver_id: lasso::ThreadedRodeo::new(),
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
}
}
}
#[derive(measured::LabelGroup)]
#[label(set = DatabaseQueryErrorLabelGroupSet)]
pub(crate) struct DatabaseQueryErrorLabelGroup {
pub(crate) error_type: DatabaseErrorLabel,
pub(crate) operation: DatabaseOperation,
}
#[derive(measured::LabelGroup)]
#[label(set = DatabaseQueryLatencyLabelGroupSet)]
pub(crate) struct DatabaseQueryLatencyLabelGroup {
pub(crate) operation: DatabaseOperation,
}
#[derive(FixedCardinalityLabel)]
pub(crate) enum ReconcileOutcome {
#[label(rename = "ok")]
Success,
Error,
Cancel,
}
#[derive(FixedCardinalityLabel, Clone)]
pub(crate) enum Method {
Get,
Put,
Post,
Delete,
Other,
}
impl From<hyper::Method> for Method {
fn from(value: hyper::Method) -> Self {
if value == hyper::Method::GET {
Method::Get
} else if value == hyper::Method::PUT {
Method::Put
} else if value == hyper::Method::POST {
Method::Post
} else if value == hyper::Method::DELETE {
Method::Delete
} else {
Method::Other
}
}
}
pub(crate) struct StatusCode(pub(crate) hyper::http::StatusCode);
impl LabelValue for StatusCode {
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
v.write_int(self.0.as_u16() as u64)
}
}
impl FixedCardinalityLabel for StatusCode {
fn cardinality() -> usize {
(100..1000).len()
}
fn encode(&self) -> usize {
self.0.as_u16() as usize
}
fn decode(value: usize) -> Self {
Self(hyper::http::StatusCode::from_u16(u16::try_from(value).unwrap()).unwrap())
}
}
#[derive(FixedCardinalityLabel)]
pub(crate) enum DatabaseErrorLabel {
Query,
Connection,
ConnectionPool,
Logical,
}
impl DatabaseError {
pub(crate) fn error_label(&self) -> DatabaseErrorLabel {
match self {
Self::Query(_) => DatabaseErrorLabel::Query,
Self::Connection(_) => DatabaseErrorLabel::Connection,
Self::ConnectionPool(_) => DatabaseErrorLabel::ConnectionPool,
Self::Logical(_) => DatabaseErrorLabel::Logical,
}
}
}

View File

@@ -12,7 +12,9 @@ use serde::Serialize;
use tokio_util::sync::CancellationToken;
use utils::{backoff, id::NodeId};
use crate::{persistence::NodePersistence, scheduler::MaySchedule};
use crate::{
pageserver_client::PageserverClient, persistence::NodePersistence, scheduler::MaySchedule,
};
/// Represents the in-memory description of a Node.
///
@@ -202,7 +204,7 @@ impl Node {
cancel: &CancellationToken,
) -> Option<mgmt_api::Result<T>>
where
O: FnMut(mgmt_api::Client) -> F,
O: FnMut(PageserverClient) -> F,
F: std::future::Future<Output = mgmt_api::Result<T>>,
{
fn is_fatal(e: &mgmt_api::Error) -> bool {
@@ -224,8 +226,12 @@ impl Node {
.build()
.expect("Failed to construct HTTP client");
let client =
mgmt_api::Client::from_client(http_client, self.base_url(), jwt.as_deref());
let client = PageserverClient::from_client(
self.get_id(),
http_client,
self.base_url(),
jwt.as_deref(),
);
let node_cancel_fut = self.cancel.cancelled();

View File

@@ -0,0 +1,203 @@
use pageserver_api::{
models::{
LocationConfig, LocationConfigListResponse, PageserverUtilization, SecondaryProgress,
TenantShardSplitRequest, TenantShardSplitResponse, TimelineCreateRequest, TimelineInfo,
},
shard::TenantShardId,
};
use pageserver_client::mgmt_api::{Client, Result};
use reqwest::StatusCode;
use utils::id::{NodeId, TimelineId};
/// Thin wrapper around [`pageserver_client::mgmt_api::Client`]. It allows the storage
/// controller to collect metrics in a non-intrusive manner.
#[derive(Debug, Clone)]
pub(crate) struct PageserverClient {
inner: Client,
node_id_label: String,
}
macro_rules! measured_request {
($name:literal, $method:expr, $node_id: expr, $invoke:expr) => {{
let labels = crate::metrics::PageserverRequestLabelGroup {
pageserver_id: $node_id,
path: $name,
method: $method,
};
let latency = &crate::metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_pageserver_request_latency;
let _timer_guard = latency.start_timer(labels.clone());
let res = $invoke;
if res.is_err() {
let error_counters = &crate::metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_pageserver_request_error;
error_counters.inc(labels)
}
res
}};
}
impl PageserverClient {
pub(crate) fn new(node_id: NodeId, mgmt_api_endpoint: String, jwt: Option<&str>) -> Self {
Self {
inner: Client::from_client(reqwest::Client::new(), mgmt_api_endpoint, jwt),
node_id_label: node_id.0.to_string(),
}
}
pub(crate) fn from_client(
node_id: NodeId,
raw_client: reqwest::Client,
mgmt_api_endpoint: String,
jwt: Option<&str>,
) -> Self {
Self {
inner: Client::from_client(raw_client, mgmt_api_endpoint, jwt),
node_id_label: node_id.0.to_string(),
}
}
pub(crate) async fn tenant_delete(&self, tenant_shard_id: TenantShardId) -> Result<StatusCode> {
measured_request!(
"tenant",
crate::metrics::Method::Delete,
&self.node_id_label,
self.inner.tenant_delete(tenant_shard_id).await
)
}
pub(crate) async fn tenant_time_travel_remote_storage(
&self,
tenant_shard_id: TenantShardId,
timestamp: &str,
done_if_after: &str,
) -> Result<()> {
measured_request!(
"tenant_time_travel_remote_storage",
crate::metrics::Method::Put,
&self.node_id_label,
self.inner
.tenant_time_travel_remote_storage(tenant_shard_id, timestamp, done_if_after)
.await
)
}
pub(crate) async fn tenant_secondary_download(
&self,
tenant_id: TenantShardId,
wait: Option<std::time::Duration>,
) -> Result<(StatusCode, SecondaryProgress)> {
measured_request!(
"tenant_secondary_download",
crate::metrics::Method::Post,
&self.node_id_label,
self.inner.tenant_secondary_download(tenant_id, wait).await
)
}
pub(crate) async fn location_config(
&self,
tenant_shard_id: TenantShardId,
config: LocationConfig,
flush_ms: Option<std::time::Duration>,
lazy: bool,
) -> Result<()> {
measured_request!(
"location_config",
crate::metrics::Method::Put,
&self.node_id_label,
self.inner
.location_config(tenant_shard_id, config, flush_ms, lazy)
.await
)
}
pub(crate) async fn list_location_config(&self) -> Result<LocationConfigListResponse> {
measured_request!(
"location_configs",
crate::metrics::Method::Get,
&self.node_id_label,
self.inner.list_location_config().await
)
}
pub(crate) async fn get_location_config(
&self,
tenant_shard_id: TenantShardId,
) -> Result<Option<LocationConfig>> {
measured_request!(
"location_config",
crate::metrics::Method::Get,
&self.node_id_label,
self.inner.get_location_config(tenant_shard_id).await
)
}
pub(crate) async fn timeline_create(
&self,
tenant_shard_id: TenantShardId,
req: &TimelineCreateRequest,
) -> Result<TimelineInfo> {
measured_request!(
"timeline",
crate::metrics::Method::Post,
&self.node_id_label,
self.inner.timeline_create(tenant_shard_id, req).await
)
}
pub(crate) async fn timeline_delete(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
) -> Result<StatusCode> {
measured_request!(
"timeline",
crate::metrics::Method::Delete,
&self.node_id_label,
self.inner
.timeline_delete(tenant_shard_id, timeline_id)
.await
)
}
pub(crate) async fn tenant_shard_split(
&self,
tenant_shard_id: TenantShardId,
req: TenantShardSplitRequest,
) -> Result<TenantShardSplitResponse> {
measured_request!(
"tenant_shard_split",
crate::metrics::Method::Put,
&self.node_id_label,
self.inner.tenant_shard_split(tenant_shard_id, req).await
)
}
pub(crate) async fn timeline_list(
&self,
tenant_shard_id: &TenantShardId,
) -> Result<Vec<TimelineInfo>> {
measured_request!(
"timelines",
crate::metrics::Method::Get,
&self.node_id_label,
self.inner.timeline_list(tenant_shard_id).await
)
}
pub(crate) async fn get_utilization(&self) -> Result<PageserverUtilization> {
measured_request!(
"utilization",
crate::metrics::Method::Get,
&self.node_id_label,
self.inner.get_utilization().await
)
}
}

View File

@@ -19,6 +19,9 @@ use serde::{Deserialize, Serialize};
use utils::generation::Generation;
use utils::id::{NodeId, TenantId};
use crate::metrics::{
DatabaseQueryErrorLabelGroup, DatabaseQueryLatencyLabelGroup, METRICS_REGISTRY,
};
use crate::node::Node;
/// ## What do we store?
@@ -75,6 +78,25 @@ pub(crate) enum DatabaseError {
Logical(String),
}
#[derive(measured::FixedCardinalityLabel, Clone)]
pub(crate) enum DatabaseOperation {
InsertNode,
UpdateNode,
DeleteNode,
ListNodes,
BeginShardSplit,
CompleteShardSplit,
AbortShardSplit,
Detach,
ReAttach,
IncrementGeneration,
ListTenantShards,
InsertTenantShards,
UpdateTenantShard,
DeleteTenant,
UpdateTenantConfig,
}
#[must_use]
pub(crate) enum AbortShardSplitStatus {
/// We aborted the split in the database by reverting to the parent shards
@@ -115,6 +137,34 @@ impl Persistence {
}
}
/// Wraps `with_conn` in order to collect latency and error metrics
async fn with_measured_conn<F, R>(&self, op: DatabaseOperation, func: F) -> DatabaseResult<R>
where
F: Fn(&mut PgConnection) -> DatabaseResult<R> + Send + 'static,
R: Send + 'static,
{
let latency = &METRICS_REGISTRY
.metrics_group
.storage_controller_database_query_latency;
let _timer = latency.start_timer(DatabaseQueryLatencyLabelGroup {
operation: op.clone(),
});
let res = self.with_conn(func).await;
if let Err(err) = &res {
let error_counter = &METRICS_REGISTRY
.metrics_group
.storage_controller_database_query_error;
error_counter.inc(DatabaseQueryErrorLabelGroup {
error_type: err.error_label(),
operation: op,
})
}
res
}
/// Call the provided function in a tokio blocking thread, with a Diesel database connection.
async fn with_conn<F, R>(&self, func: F) -> DatabaseResult<R>
where
@@ -130,21 +180,27 @@ impl Persistence {
/// When a node is first registered, persist it before using it for anything
pub(crate) async fn insert_node(&self, node: &Node) -> DatabaseResult<()> {
let np = node.to_persistent();
self.with_conn(move |conn| -> DatabaseResult<()> {
diesel::insert_into(crate::schema::nodes::table)
.values(&np)
.execute(conn)?;
Ok(())
})
self.with_measured_conn(
DatabaseOperation::InsertNode,
move |conn| -> DatabaseResult<()> {
diesel::insert_into(crate::schema::nodes::table)
.values(&np)
.execute(conn)?;
Ok(())
},
)
.await
}
/// At startup, populate the list of nodes which our shards may be placed on
pub(crate) async fn list_nodes(&self) -> DatabaseResult<Vec<NodePersistence>> {
let nodes: Vec<NodePersistence> = self
.with_conn(move |conn| -> DatabaseResult<_> {
Ok(crate::schema::nodes::table.load::<NodePersistence>(conn)?)
})
.with_measured_conn(
DatabaseOperation::ListNodes,
move |conn| -> DatabaseResult<_> {
Ok(crate::schema::nodes::table.load::<NodePersistence>(conn)?)
},
)
.await?;
tracing::info!("list_nodes: loaded {} nodes", nodes.len());
@@ -159,7 +215,7 @@ impl Persistence {
) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
let updated = self
.with_conn(move |conn| {
.with_measured_conn(DatabaseOperation::UpdateNode, move |conn| {
let updated = diesel::update(nodes)
.filter(node_id.eq(input_node_id.0 as i64))
.set((scheduling_policy.eq(String::from(input_scheduling)),))
@@ -181,9 +237,12 @@ impl Persistence {
/// be enriched at runtime with state discovered on pageservers.
pub(crate) async fn list_tenant_shards(&self) -> DatabaseResult<Vec<TenantShardPersistence>> {
let loaded = self
.with_conn(move |conn| -> DatabaseResult<_> {
Ok(crate::schema::tenant_shards::table.load::<TenantShardPersistence>(conn)?)
})
.with_measured_conn(
DatabaseOperation::ListTenantShards,
move |conn| -> DatabaseResult<_> {
Ok(crate::schema::tenant_shards::table.load::<TenantShardPersistence>(conn)?)
},
)
.await?;
if loaded.is_empty() {
@@ -260,17 +319,20 @@ impl Persistence {
shards: Vec<TenantShardPersistence>,
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> QueryResult<()> {
for tenant in &shards {
diesel::insert_into(tenant_shards)
.values(tenant)
.execute(conn)?;
}
self.with_measured_conn(
DatabaseOperation::InsertTenantShards,
move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> QueryResult<()> {
for tenant in &shards {
diesel::insert_into(tenant_shards)
.values(tenant)
.execute(conn)?;
}
Ok(())
})?;
Ok(())
})?;
Ok(())
})
},
)
.await
}
@@ -278,25 +340,31 @@ impl Persistence {
/// the tenant from memory on this server.
pub(crate) async fn delete_tenant(&self, del_tenant_id: TenantId) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
diesel::delete(tenant_shards)
.filter(tenant_id.eq(del_tenant_id.to_string()))
.execute(conn)?;
self.with_measured_conn(
DatabaseOperation::DeleteTenant,
move |conn| -> DatabaseResult<()> {
diesel::delete(tenant_shards)
.filter(tenant_id.eq(del_tenant_id.to_string()))
.execute(conn)?;
Ok(())
})
Ok(())
},
)
.await
}
pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> {
use crate::schema::nodes::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)?;
self.with_measured_conn(
DatabaseOperation::DeleteNode,
move |conn| -> DatabaseResult<()> {
diesel::delete(nodes)
.filter(node_id.eq(del_node_id.0 as i64))
.execute(conn)?;
Ok(())
})
Ok(())
},
)
.await
}
@@ -310,7 +378,7 @@ impl Persistence {
) -> DatabaseResult<HashMap<TenantShardId, Generation>> {
use crate::schema::tenant_shards::dsl::*;
let updated = self
.with_conn(move |conn| {
.with_measured_conn(DatabaseOperation::ReAttach, move |conn| {
let rows_updated = diesel::update(tenant_shards)
.filter(generation_pageserver.eq(node_id.0 as i64))
.set(generation.eq(generation + 1))
@@ -360,7 +428,7 @@ impl Persistence {
) -> anyhow::Result<Generation> {
use crate::schema::tenant_shards::dsl::*;
let updated = self
.with_conn(move |conn| {
.with_measured_conn(DatabaseOperation::IncrementGeneration, move |conn| {
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(tenant_shard_id.tenant_id.to_string()))
.filter(shard_number.eq(tenant_shard_id.shard_number.0 as i32))
@@ -404,7 +472,7 @@ impl Persistence {
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| {
self.with_measured_conn(DatabaseOperation::UpdateTenantShard, move |conn| {
let query = diesel::update(tenant_shards)
.filter(tenant_id.eq(tenant_shard_id.tenant_id.to_string()))
.filter(shard_number.eq(tenant_shard_id.shard_number.0 as i32))
@@ -445,7 +513,7 @@ impl Persistence {
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| {
self.with_measured_conn(DatabaseOperation::UpdateTenantConfig, move |conn| {
diesel::update(tenant_shards)
.filter(tenant_id.eq(input_tenant_id.to_string()))
.set((config.eq(serde_json::to_string(&input_config).unwrap()),))
@@ -460,7 +528,7 @@ impl Persistence {
pub(crate) async fn detach(&self, tenant_shard_id: TenantShardId) -> anyhow::Result<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| {
self.with_measured_conn(DatabaseOperation::Detach, move |conn| {
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(tenant_shard_id.tenant_id.to_string()))
.filter(shard_number.eq(tenant_shard_id.shard_number.0 as i32))
@@ -490,7 +558,7 @@ impl Persistence {
parent_to_children: Vec<(TenantShardId, Vec<TenantShardPersistence>)>,
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
self.with_measured_conn(DatabaseOperation::BeginShardSplit, move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> DatabaseResult<()> {
// Mark parent shards as splitting
@@ -554,26 +622,29 @@ impl Persistence {
old_shard_count: ShardCount,
) -> DatabaseResult<()> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> QueryResult<()> {
// Drop parent shards
diesel::delete(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.eq(old_shard_count.literal() as i32))
.execute(conn)?;
self.with_measured_conn(
DatabaseOperation::CompleteShardSplit,
move |conn| -> DatabaseResult<()> {
conn.transaction(|conn| -> QueryResult<()> {
// Drop parent shards
diesel::delete(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.eq(old_shard_count.literal() as i32))
.execute(conn)?;
// Clear sharding flag
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.set((splitting.eq(0),))
.execute(conn)?;
debug_assert!(updated > 0);
// Clear sharding flag
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.set((splitting.eq(0),))
.execute(conn)?;
debug_assert!(updated > 0);
Ok(())
})?;
Ok(())
})?;
Ok(())
})
},
)
.await
}
@@ -585,40 +656,44 @@ impl Persistence {
new_shard_count: ShardCount,
) -> DatabaseResult<AbortShardSplitStatus> {
use crate::schema::tenant_shards::dsl::*;
self.with_conn(move |conn| -> DatabaseResult<AbortShardSplitStatus> {
let aborted = conn.transaction(|conn| -> DatabaseResult<AbortShardSplitStatus> {
// Clear the splitting state on parent shards
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.ne(new_shard_count.literal() as i32))
.set((splitting.eq(0),))
.execute(conn)?;
self.with_measured_conn(
DatabaseOperation::AbortShardSplit,
move |conn| -> DatabaseResult<AbortShardSplitStatus> {
let aborted =
conn.transaction(|conn| -> DatabaseResult<AbortShardSplitStatus> {
// Clear the splitting state on parent shards
let updated = diesel::update(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.ne(new_shard_count.literal() as i32))
.set((splitting.eq(0),))
.execute(conn)?;
// Parent shards are already gone: we cannot abort.
if updated == 0 {
return Ok(AbortShardSplitStatus::Complete);
}
// Parent shards are already gone: we cannot abort.
if updated == 0 {
return Ok(AbortShardSplitStatus::Complete);
}
// Sanity check: if parent shards were present, their cardinality should
// be less than the number of child shards.
if updated >= new_shard_count.count() as usize {
return Err(DatabaseError::Logical(format!(
"Unexpected parent shard count {updated} while aborting split to \
// Sanity check: if parent shards were present, their cardinality should
// be less than the number of child shards.
if updated >= new_shard_count.count() as usize {
return Err(DatabaseError::Logical(format!(
"Unexpected parent shard count {updated} while aborting split to \
count {new_shard_count:?} on tenant {split_tenant_id}"
)));
}
)));
}
// Erase child shards
diesel::delete(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.eq(new_shard_count.literal() as i32))
.execute(conn)?;
// Erase child shards
diesel::delete(tenant_shards)
.filter(tenant_id.eq(split_tenant_id.to_string()))
.filter(shard_count.eq(new_shard_count.literal() as i32))
.execute(conn)?;
Ok(AbortShardSplitStatus::Aborted)
})?;
Ok(AbortShardSplitStatus::Aborted)
})?;
Ok(aborted)
})
Ok(aborted)
},
)
.await
}
}

View File

@@ -1,3 +1,4 @@
use crate::pageserver_client::PageserverClient;
use crate::persistence::Persistence;
use crate::service;
use hyper::StatusCode;
@@ -117,6 +118,15 @@ impl Reconciler {
flush_ms: Option<Duration>,
lazy: bool,
) -> Result<(), ReconcileError> {
if !node.is_available() && config.mode == LocationConfigMode::Detached {
// Attempts to detach from offline nodes may be imitated without doing I/O: a node which is offline
// will get fully reconciled wrt the shard's intent state when it is reactivated, irrespective of
// what we put into `observed`, in [`crate::service::Service::node_activate_reconcile`]
tracing::info!("Node {node} is unavailable during detach: proceeding anyway, it will be detached on next activation");
self.observed.locations.remove(&node.get_id());
return Ok(());
}
self.observed
.locations
.insert(node.get_id(), ObservedStateLocation { conf: None });
@@ -149,9 +159,16 @@ impl Reconciler {
};
tracing::info!("location_config({node}) complete: {:?}", config);
self.observed
.locations
.insert(node.get_id(), ObservedStateLocation { conf: Some(config) });
match config.mode {
LocationConfigMode::Detached => {
self.observed.locations.remove(&node.get_id());
}
_ => {
self.observed
.locations
.insert(node.get_id(), ObservedStateLocation { conf: Some(config) });
}
}
Ok(())
}
@@ -243,8 +260,11 @@ impl Reconciler {
tenant_shard_id: TenantShardId,
node: &Node,
) -> anyhow::Result<HashMap<TimelineId, Lsn>> {
let client =
mgmt_api::Client::new(node.base_url(), self.service_config.jwt_token.as_deref());
let client = PageserverClient::new(
node.get_id(),
node.base_url(),
self.service_config.jwt_token.as_deref(),
);
let timelines = client.timeline_list(&tenant_shard_id).await?;
Ok(timelines

View File

@@ -27,6 +27,7 @@ use pageserver_api::{
models::{SecondaryProgress, TenantConfigRequest},
};
use crate::pageserver_client::PageserverClient;
use pageserver_api::{
models::{
self, LocationConfig, LocationConfigListResponse, LocationConfigMode,
@@ -209,6 +210,7 @@ struct ShardSplitParams {
new_stripe_size: Option<ShardStripeSize>,
targets: Vec<ShardSplitTarget>,
policy: PlacementPolicy,
config: TenantConfig,
shard_ident: ShardIdentity,
}
@@ -551,7 +553,11 @@ impl Service {
break;
}
let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
let client = PageserverClient::new(
node.get_id(),
node.base_url(),
self.config.jwt_token.as_deref(),
);
match client
.location_config(
tenant_shard_id,
@@ -1388,7 +1394,8 @@ impl Service {
incremented_generations.len()
);
// Apply the updated generation to our in-memory state
// Apply the updated generation to our in-memory state, and
// gather discover secondary locations.
let mut locked = self.inner.write().unwrap();
let (nodes, tenants, scheduler) = locked.parts_mut();
@@ -1396,62 +1403,65 @@ impl Service {
tenants: Vec::new(),
};
for (tenant_shard_id, new_gen) in incremented_generations {
response.tenants.push(ReAttachResponseTenant {
id: tenant_shard_id,
gen: new_gen.into().unwrap(),
});
// Apply the new generation number to our in-memory state
let shard_state = tenants.get_mut(&tenant_shard_id);
let Some(shard_state) = shard_state else {
// Not fatal. This edge case requires a re-attach to happen
// between inserting a new tenant shard in to the database, and updating our in-memory
// state to know about the shard, _and_ that the state inserted to the database referenced
// a pageserver. Should never happen, but handle it rather than panicking, since it should
// be harmless.
tracing::error!(
"Shard {} is in database for node {} but not in-memory state",
tenant_shard_id,
reattach_req.node_id
);
continue;
};
// TODO: cancel/restart any running reconciliation for this tenant, it might be trying
// to call location_conf API with an old generation. Wait for cancellation to complete
// before responding to this request. Requires well implemented CancellationToken logic
// all the way to where we call location_conf. Even then, there can still be a location_conf
// request in flight over the network: TODO handle that by making location_conf API refuse
// to go backward in generations.
// If [`Persistence::re_attach`] selected this shard, it must have alread
// had a generation set.
debug_assert!(shard_state.generation.is_some());
let Some(old_gen) = shard_state.generation else {
// Should never happen: would only return incremented generation
// for a tenant that already had a non-null generation.
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"Generation must be set while re-attaching"
)));
};
shard_state.generation = Some(std::cmp::max(old_gen, new_gen));
if let Some(observed) = shard_state
.observed
.locations
.get_mut(&reattach_req.node_id)
{
if let Some(conf) = observed.conf.as_mut() {
conf.generation = new_gen.into();
// Scan through all shards, applying updates for ones where we updated generation
// and identifying shards that intend to have a secondary location on this node.
for (tenant_shard_id, shard) in tenants {
if let Some(new_gen) = incremented_generations.get(tenant_shard_id) {
let new_gen = *new_gen;
response.tenants.push(ReAttachResponseTenant {
id: *tenant_shard_id,
gen: Some(new_gen.into().unwrap()),
// A tenant is only put into multi or stale modes in the middle of a [`Reconciler::live_migrate`]
// execution. If a pageserver is restarted during that process, then the reconcile pass will
// fail, and start from scratch, so it doesn't make sense for us to try and preserve
// the stale/multi states at this point.
mode: LocationConfigMode::AttachedSingle,
});
shard.generation = std::cmp::max(shard.generation, Some(new_gen));
if let Some(observed) = shard.observed.locations.get_mut(&reattach_req.node_id) {
// Why can we update `observed` even though we're not sure our response will be received
// by the pageserver? Because the pageserver will not proceed with startup until
// it has processed response: if it loses it, we'll see another request and increment
// generation again, avoiding any uncertainty about dirtiness of tenant's state.
if let Some(conf) = observed.conf.as_mut() {
conf.generation = new_gen.into();
}
} else {
// This node has no observed state for the shard: perhaps it was offline
// when the pageserver restarted. Insert a None, so that the Reconciler
// will be prompted to learn the location's state before it makes changes.
shard
.observed
.locations
.insert(reattach_req.node_id, ObservedStateLocation { conf: None });
}
} else {
// This node has no observed state for the shard: perhaps it was offline
// when the pageserver restarted. Insert a None, so that the Reconciler
// will be prompted to learn the location's state before it makes changes.
shard_state
.observed
.locations
.insert(reattach_req.node_id, ObservedStateLocation { conf: None });
}
} else if shard.intent.get_secondary().contains(&reattach_req.node_id) {
// Ordering: pageserver will not accept /location_config requests until it has
// finished processing the response from re-attach. So we can update our in-memory state
// now, and be confident that we are not stamping on the result of some later location config.
// TODO: however, we are not strictly ordered wrt ReconcileResults queue,
// so we might update observed state here, and then get over-written by some racing
// ReconcileResult. The impact is low however, since we have set state on pageserver something
// that matches intent, so worst case if we race then we end up doing a spurious reconcile.
// TODO: cancel/restart any running reconciliation for this tenant, it might be trying
// to call location_conf API with an old generation. Wait for cancellation to complete
// before responding to this request. Requires well implemented CancellationToken logic
// all the way to where we call location_conf. Even then, there can still be a location_conf
// request in flight over the network: TODO handle that by making location_conf API refuse
// to go backward in generations.
response.tenants.push(ReAttachResponseTenant {
id: *tenant_shard_id,
gen: None,
mode: LocationConfigMode::Secondary,
});
// We must not update observed, because we have no guarantee that our
// response will be received by the pageserver. This could leave it
// falsely dirty, but the resulting reconcile should be idempotent.
}
}
// We consider a node Active once we have composed a re-attach response, but we
@@ -2096,8 +2106,11 @@ impl Service {
})
.collect::<Vec<_>>();
for tenant_shard_id in shard_ids {
let client =
mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
let client = PageserverClient::new(
node.get_id(),
node.base_url(),
self.config.jwt_token.as_deref(),
);
tracing::info!("Doing time travel recovery for shard {tenant_shard_id}",);
@@ -2149,7 +2162,11 @@ impl Service {
// Issue concurrent requests to all shards' locations
let mut futs = FuturesUnordered::new();
for (tenant_shard_id, node) in targets {
let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
let client = PageserverClient::new(
node.get_id(),
node.base_url(),
self.config.jwt_token.as_deref(),
);
futs.push(async move {
let result = client
.tenant_secondary_download(tenant_shard_id, wait)
@@ -2242,7 +2259,11 @@ impl Service {
// Phase 1: delete on the pageservers
let mut any_pending = false;
for (tenant_shard_id, node) in targets {
let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
let client = PageserverClient::new(
node.get_id(),
node.base_url(),
self.config.jwt_token.as_deref(),
);
// TODO: this, like many other places, requires proper retry handling for 503, timeout: those should not
// surface immediately as an error to our caller.
let status = client.tenant_delete(tenant_shard_id).await.map_err(|e| {
@@ -2354,7 +2375,7 @@ impl Service {
tenant_shard_id,
create_req.new_timeline_id,
);
let client = mgmt_api::Client::new(node.base_url(), jwt.as_deref());
let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref());
client
.timeline_create(tenant_shard_id, &create_req)
@@ -2478,7 +2499,7 @@ impl Service {
"Deleting timeline on shard {tenant_shard_id}/{timeline_id}, attached to node {node}",
);
let client = mgmt_api::Client::new(node.base_url(), jwt.as_deref());
let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref());
client
.timeline_delete(tenant_shard_id, timeline_id)
.await
@@ -2519,11 +2540,11 @@ impl Service {
}
/// When you need to send an HTTP request to the pageserver that holds shard0 of a tenant, this
/// function looks it up and returns the url. If the tenant isn't found, returns Err(ApiError::NotFound)
pub(crate) fn tenant_shard0_baseurl(
/// function looks up and returns node. If the tenant isn't found, returns Err(ApiError::NotFound)
pub(crate) fn tenant_shard0_node(
&self,
tenant_id: TenantId,
) -> Result<(String, TenantShardId), ApiError> {
) -> Result<(Node, TenantShardId), ApiError> {
let locked = self.inner.read().unwrap();
let Some((tenant_shard_id, shard)) = locked
.tenants
@@ -2555,7 +2576,7 @@ impl Service {
)));
};
Ok((node.base_url(), *tenant_shard_id))
Ok((node.clone(), *tenant_shard_id))
}
pub(crate) fn tenant_locate(
@@ -2725,7 +2746,7 @@ impl Service {
let detach_locations: Vec<(Node, TenantShardId)> = {
let mut detach_locations = Vec::new();
let mut locked = self.inner.write().unwrap();
let (nodes, tenants, _scheduler) = locked.parts_mut();
let (nodes, tenants, scheduler) = locked.parts_mut();
for (tenant_shard_id, shard) in
tenants.range_mut(TenantShardId::tenant_range(op.tenant_id))
@@ -2758,6 +2779,13 @@ impl Service {
tracing::info!("Restoring parent shard {tenant_shard_id}");
shard.splitting = SplitState::Idle;
if let Err(e) = shard.schedule(scheduler) {
// If this shard can't be scheduled now (perhaps due to offline nodes or
// capacity issues), that must not prevent us rolling back a split. In this
// case it should be eventually scheduled in the background.
tracing::warn!("Failed to schedule {tenant_shard_id} during shard abort: {e}")
}
self.maybe_reconcile_shard(shard, nodes);
}
@@ -2849,7 +2877,7 @@ impl Service {
.map(|(shard_id, _)| *shard_id)
.collect::<Vec<_>>();
let (_nodes, tenants, scheduler) = locked.parts_mut();
let (nodes, tenants, scheduler) = locked.parts_mut();
for parent_id in parent_ids {
let child_ids = parent_id.split(new_shard_count);
@@ -2916,6 +2944,8 @@ impl Service {
// find a secondary (e.g. because cluster is overloaded).
tracing::warn!("Failed to schedule child shard {child}: {e}");
}
// In the background, attach secondary locations for the new shards
self.maybe_reconcile_shard(&mut child_state, nodes);
tenants.insert(child, child_state);
response.new_shards.push(child);
@@ -2980,6 +3010,7 @@ impl Service {
)));
let mut policy = None;
let mut config = None;
let mut shard_ident = None;
// Validate input, and calculate which shards we will create
let (old_shard_count, targets) =
@@ -3036,6 +3067,9 @@ impl Service {
if shard_ident.is_none() {
shard_ident = Some(shard.shard);
}
if config.is_none() {
config = Some(shard.config.clone());
}
if tenant_shard_id.shard_count.count() == split_req.new_shard_count {
tracing::info!(
@@ -3054,8 +3088,6 @@ impl Service {
.get(&node_id)
.expect("Pageservers may not be deleted while referenced");
// TODO: if any reconciliation is currently in progress for this shard, wait for it.
targets.push(ShardSplitTarget {
parent_id: *tenant_shard_id,
node: node.clone(),
@@ -3098,6 +3130,7 @@ impl Service {
shard_ident.unwrap()
};
let policy = policy.unwrap();
let config = config.unwrap();
Ok(ShardSplitAction::Split(ShardSplitParams {
old_shard_count,
@@ -3105,6 +3138,7 @@ impl Service {
new_stripe_size: split_req.new_stripe_size,
targets,
policy,
config,
shard_ident,
}))
}
@@ -3124,11 +3158,49 @@ impl Service {
old_shard_count,
new_shard_count,
new_stripe_size,
targets,
mut targets,
policy,
config,
shard_ident,
} = params;
// Drop any secondary locations: pageservers do not support splitting these, and in any case the
// end-state for a split tenant will usually be to have secondary locations on different nodes.
// The reconciliation calls in this block also implicitly cancel+barrier wrt any ongoing reconciliation
// at the time of split.
let waiters = {
let mut locked = self.inner.write().unwrap();
let mut waiters = Vec::new();
let (nodes, tenants, scheduler) = locked.parts_mut();
for target in &mut targets {
let Some(shard) = tenants.get_mut(&target.parent_id) else {
// Paranoia check: this shouldn't happen: we have the oplock for this tenant ID.
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"Shard {} not found",
target.parent_id
)));
};
if shard.intent.get_attached() != &Some(target.node.get_id()) {
// Paranoia check: this shouldn't happen: we have the oplock for this tenant ID.
return Err(ApiError::Conflict(format!(
"Shard {} unexpectedly rescheduled during split",
target.parent_id
)));
}
// Irrespective of PlacementPolicy, clear secondary locations from intent
shard.intent.clear_secondary(scheduler);
// Run Reconciler to execute detach fo secondary locations.
if let Some(waiter) = self.maybe_reconcile_shard(shard, nodes) {
waiters.push(waiter);
}
}
waiters
};
self.await_waiters(waiters, RECONCILE_TIMEOUT).await?;
// Before creating any new child shards in memory or on the pageservers, persist them: this
// enables us to ensure that we will always be able to clean up if something goes wrong. This also
// acts as the protection against two concurrent attempts to split: one of them will get a database
@@ -3157,8 +3229,7 @@ impl Service {
generation: None,
generation_pageserver: Some(target.node.get_id().0 as i64),
placement_policy: serde_json::to_string(&policy).unwrap(),
// TODO: get the config out of the map
config: serde_json::to_string(&TenantConfig::default()).unwrap(),
config: serde_json::to_string(&config).unwrap(),
splitting: SplitState::Splitting,
});
}
@@ -3215,7 +3286,11 @@ impl Service {
node,
child_ids,
} = target;
let client = mgmt_api::Client::new(node.base_url(), self.config.jwt_token.as_deref());
let client = PageserverClient::new(
node.get_id(),
node.base_url(),
self.config.jwt_token.as_deref(),
);
let response = client
.tenant_shard_split(
*parent_id,
@@ -3343,6 +3418,11 @@ impl Service {
// If we were already attached to something, demote that to a secondary
if let Some(old_attached) = old_attached {
if n > 0 {
// Remove other secondaries to make room for the location we'll demote
while shard.intent.get_secondary().len() >= n {
shard.intent.pop_secondary(scheduler);
}
shard.intent.push_secondary(scheduler, old_attached);
}
}
@@ -3370,7 +3450,7 @@ impl Service {
if let Some(waiter) = waiter {
waiter.wait_timeout(RECONCILE_TIMEOUT).await?;
} else {
tracing::warn!("Migration is a no-op");
tracing::info!("Migration is a no-op");
}
Ok(TenantShardMigrateResponse {})

View File

@@ -4,7 +4,10 @@ use std::{
time::Duration,
};
use crate::{metrics, persistence::TenantShardPersistence};
use crate::{
metrics::{self, ReconcileCompleteLabelGroup, ReconcileOutcome},
persistence::TenantShardPersistence,
};
use pageserver_api::controller_api::PlacementPolicy;
use pageserver_api::{
models::{LocationConfig, LocationConfigMode, TenantConfig},
@@ -718,7 +721,10 @@ impl TenantState {
let reconciler_span = tracing::info_span!(parent: None, "reconciler", seq=%reconcile_seq,
tenant_id=%reconciler.tenant_shard_id.tenant_id,
shard_id=%reconciler.tenant_shard_id.shard_slug());
metrics::RECONCILER.spawned.inc();
metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_reconcile_spawn
.inc();
let result_tx = result_tx.clone();
let join_handle = tokio::task::spawn(
async move {
@@ -736,10 +742,12 @@ impl TenantState {
// TODO: wrap all remote API operations in cancellation check
// as well.
if reconciler.cancel.is_cancelled() {
metrics::RECONCILER
.complete
.with_label_values(&[metrics::ReconcilerMetrics::CANCEL])
.inc();
metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_reconcile_complete
.inc(ReconcileCompleteLabelGroup {
status: ReconcileOutcome::Cancel,
});
return;
}
@@ -754,18 +762,18 @@ impl TenantState {
}
// Update result counter
match &result {
Ok(_) => metrics::RECONCILER
.complete
.with_label_values(&[metrics::ReconcilerMetrics::SUCCESS]),
Err(ReconcileError::Cancel) => metrics::RECONCILER
.complete
.with_label_values(&[metrics::ReconcilerMetrics::CANCEL]),
Err(_) => metrics::RECONCILER
.complete
.with_label_values(&[metrics::ReconcilerMetrics::ERROR]),
}
.inc();
let outcome_label = match &result {
Ok(_) => ReconcileOutcome::Success,
Err(ReconcileError::Cancel) => ReconcileOutcome::Cancel,
Err(_) => ReconcileOutcome::Error,
};
metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_reconcile_complete
.inc(ReconcileCompleteLabelGroup {
status: outcome_label,
});
result_tx
.send(ReconcileResult {

View File

@@ -279,6 +279,7 @@ impl StorageController {
&self.listen,
"-p",
self.path.as_ref(),
"--dev",
"--database-url",
&database_url,
"--max-unavailable-interval",

View File

@@ -6,7 +6,9 @@
use serde::{Deserialize, Serialize};
use utils::id::NodeId;
use crate::{controller_api::NodeRegisterRequest, shard::TenantShardId};
use crate::{
controller_api::NodeRegisterRequest, models::LocationConfigMode, shard::TenantShardId,
};
/// Upcall message sent by the pageserver to the configured `control_plane_api` on
/// startup.
@@ -20,12 +22,20 @@ pub struct ReAttachRequest {
pub register: Option<NodeRegisterRequest>,
}
#[derive(Serialize, Deserialize)]
pub struct ReAttachResponseTenant {
pub id: TenantShardId,
pub gen: u32,
fn default_mode() -> LocationConfigMode {
LocationConfigMode::AttachedSingle
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ReAttachResponseTenant {
pub id: TenantShardId,
/// Mandatory if LocationConfigMode is None or set to an Attached* mode
pub gen: Option<u32>,
/// Default value only for backward compat: this field should be set
#[serde(default = "default_mode")]
pub mode: LocationConfigMode,
}
#[derive(Serialize, Deserialize)]
pub struct ReAttachResponse {
pub tenants: Vec<ReAttachResponseTenant>,

View File

@@ -1,5 +1,6 @@
use anyhow::*;
use clap::{value_parser, Arg, ArgMatches, Command};
use postgres::Client;
use std::{path::PathBuf, str::FromStr};
use wal_craft::*;
@@ -8,8 +9,8 @@ fn main() -> Result<()> {
.init();
let arg_matches = cli().get_matches();
let wal_craft = |arg_matches: &ArgMatches, client| {
let (intermediate_lsns, end_of_wal_lsn) = match arg_matches
let wal_craft = |arg_matches: &ArgMatches, client: &mut Client| {
let intermediate_lsns = match arg_matches
.get_one::<String>("type")
.map(|s| s.as_str())
.context("'type' is required")?
@@ -25,6 +26,7 @@ fn main() -> Result<()> {
LastWalRecordCrossingSegment::NAME => LastWalRecordCrossingSegment::craft(client)?,
a => panic!("Unknown --type argument: {a}"),
};
let end_of_wal_lsn = client.pg_current_wal_insert_lsn()?;
for lsn in intermediate_lsns {
println!("intermediate_lsn = {lsn}");
}

View File

@@ -5,7 +5,6 @@ use postgres::types::PgLsn;
use postgres::Client;
use postgres_ffi::{WAL_SEGMENT_SIZE, XLOG_BLCKSZ};
use postgres_ffi::{XLOG_SIZE_OF_XLOG_RECORD, XLOG_SIZE_OF_XLOG_SHORT_PHD};
use std::cmp::Ordering;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::time::{Duration, Instant};
@@ -232,59 +231,52 @@ pub fn ensure_server_config(client: &mut impl postgres::GenericClient) -> anyhow
pub trait Crafter {
const NAME: &'static str;
/// Generates WAL using the client `client`. Returns a pair of:
/// * A vector of some valid "interesting" intermediate LSNs which one may start reading from.
/// May include or exclude Lsn(0) and the end-of-wal.
/// * The expected end-of-wal LSN.
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec<PgLsn>, PgLsn)>;
/// Generates WAL using the client `client`. Returns a vector of some valid
/// "interesting" intermediate LSNs which one may start reading from.
/// test_end_of_wal uses this to check various starting points.
///
/// Note that postgres is generally keen about writing some WAL. While we
/// try to disable it (autovacuum, big wal_writer_delay, etc) it is always
/// possible, e.g. xl_running_xacts are dumped each 15s. So checks about
/// stable WAL end would be flaky unless postgres is shut down. For this
/// reason returning potential end of WAL here is pointless. Most of the
/// time this doesn't happen though, so it is reasonable to create needed
/// WAL structure and immediately kill postgres like test_end_of_wal does.
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<Vec<PgLsn>>;
}
/// Wraps some WAL craft function, providing current LSN to it before the
/// insertion and flushing WAL afterwards. Also pushes initial LSN to the
/// result.
fn craft_internal<C: postgres::GenericClient>(
client: &mut C,
f: impl Fn(&mut C, PgLsn) -> anyhow::Result<(Vec<PgLsn>, Option<PgLsn>)>,
) -> anyhow::Result<(Vec<PgLsn>, PgLsn)> {
f: impl Fn(&mut C, PgLsn) -> anyhow::Result<Vec<PgLsn>>,
) -> anyhow::Result<Vec<PgLsn>> {
ensure_server_config(client)?;
let initial_lsn = client.pg_current_wal_insert_lsn()?;
info!("LSN initial = {}", initial_lsn);
let (mut intermediate_lsns, last_lsn) = f(client, initial_lsn)?;
let last_lsn = match last_lsn {
None => client.pg_current_wal_insert_lsn()?,
Some(last_lsn) => {
let insert_lsn = client.pg_current_wal_insert_lsn()?;
match last_lsn.cmp(&insert_lsn) {
Ordering::Less => bail!(
"Some records were inserted after the crafted WAL: {} vs {}",
last_lsn,
insert_lsn
),
Ordering::Equal => last_lsn,
Ordering::Greater => bail!("Reported LSN is greater than insert_lsn"),
}
}
};
let mut intermediate_lsns = f(client, initial_lsn)?;
if !intermediate_lsns.starts_with(&[initial_lsn]) {
intermediate_lsns.insert(0, initial_lsn);
}
// Some records may be not flushed, e.g. non-transactional logical messages.
//
// Note: this is broken if pg_current_wal_insert_lsn is at page boundary
// because pg_current_wal_insert_lsn skips page headers.
client.execute("select neon_xlogflush(pg_current_wal_insert_lsn())", &[])?;
match last_lsn.cmp(&client.pg_current_wal_flush_lsn()?) {
Ordering::Less => bail!("Some records were flushed after the crafted WAL"),
Ordering::Equal => {}
Ordering::Greater => bail!("Reported LSN is greater than flush_lsn"),
}
Ok((intermediate_lsns, last_lsn))
Ok(intermediate_lsns)
}
pub struct Simple;
impl Crafter for Simple {
const NAME: &'static str = "simple";
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec<PgLsn>, PgLsn)> {
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<Vec<PgLsn>> {
craft_internal(client, |client, _| {
client.execute("CREATE table t(x int)", &[])?;
Ok((Vec::new(), None))
Ok(Vec::new())
})
}
}
@@ -292,29 +284,36 @@ impl Crafter for Simple {
pub struct LastWalRecordXlogSwitch;
impl Crafter for LastWalRecordXlogSwitch {
const NAME: &'static str = "last_wal_record_xlog_switch";
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec<PgLsn>, PgLsn)> {
// Do not use generate_internal because here we end up with flush_lsn exactly on
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<Vec<PgLsn>> {
// Do not use craft_internal because here we end up with flush_lsn exactly on
// the segment boundary and insert_lsn after the initial page header, which is unusual.
ensure_server_config(client)?;
client.execute("CREATE table t(x int)", &[])?;
let before_xlog_switch = client.pg_current_wal_insert_lsn()?;
let after_xlog_switch: PgLsn = client.query_one("SELECT pg_switch_wal()", &[])?.get(0);
let next_segment = PgLsn::from(0x0200_0000);
// pg_switch_wal returns end of last record of the switched segment,
// i.e. end of SWITCH itself.
let xlog_switch_record_end: PgLsn = client.query_one("SELECT pg_switch_wal()", &[])?.get(0);
let before_xlog_switch_u64 = u64::from(before_xlog_switch);
let next_segment = PgLsn::from(
before_xlog_switch_u64 - (before_xlog_switch_u64 % WAL_SEGMENT_SIZE as u64)
+ WAL_SEGMENT_SIZE as u64,
);
ensure!(
after_xlog_switch <= next_segment,
"XLOG_SWITCH message ended after the expected segment boundary: {} > {}",
after_xlog_switch,
xlog_switch_record_end <= next_segment,
"XLOG_SWITCH record ended after the expected segment boundary: {} > {}",
xlog_switch_record_end,
next_segment
);
Ok((vec![before_xlog_switch, after_xlog_switch], next_segment))
Ok(vec![before_xlog_switch, xlog_switch_record_end])
}
}
pub struct LastWalRecordXlogSwitchEndsOnPageBoundary;
/// Craft xlog SWITCH record ending at page boundary.
impl Crafter for LastWalRecordXlogSwitchEndsOnPageBoundary {
const NAME: &'static str = "last_wal_record_xlog_switch_ends_on_page_boundary";
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec<PgLsn>, PgLsn)> {
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<Vec<PgLsn>> {
// Do not use generate_internal because here we end up with flush_lsn exactly on
// the segment boundary and insert_lsn after the initial page header, which is unusual.
ensure_server_config(client)?;
@@ -361,28 +360,29 @@ impl Crafter for LastWalRecordXlogSwitchEndsOnPageBoundary {
// Emit the XLOG_SWITCH
let before_xlog_switch = client.pg_current_wal_insert_lsn()?;
let after_xlog_switch: PgLsn = client.query_one("SELECT pg_switch_wal()", &[])?.get(0);
let xlog_switch_record_end: PgLsn = client.query_one("SELECT pg_switch_wal()", &[])?.get(0);
let next_segment = PgLsn::from(0x0200_0000);
ensure!(
after_xlog_switch < next_segment,
"XLOG_SWITCH message ended on or after the expected segment boundary: {} > {}",
after_xlog_switch,
xlog_switch_record_end < next_segment,
"XLOG_SWITCH record ended on or after the expected segment boundary: {} > {}",
xlog_switch_record_end,
next_segment
);
ensure!(
u64::from(after_xlog_switch) as usize % XLOG_BLCKSZ == XLOG_SIZE_OF_XLOG_SHORT_PHD,
u64::from(xlog_switch_record_end) as usize % XLOG_BLCKSZ == XLOG_SIZE_OF_XLOG_SHORT_PHD,
"XLOG_SWITCH message ended not on page boundary: {}, offset = {}",
after_xlog_switch,
u64::from(after_xlog_switch) as usize % XLOG_BLCKSZ
xlog_switch_record_end,
u64::from(xlog_switch_record_end) as usize % XLOG_BLCKSZ
);
Ok((vec![before_xlog_switch, after_xlog_switch], next_segment))
Ok(vec![before_xlog_switch, xlog_switch_record_end])
}
}
fn craft_single_logical_message(
/// Write ~16MB logical message; it should cross WAL segment.
fn craft_seg_size_logical_message(
client: &mut impl postgres::GenericClient,
transactional: bool,
) -> anyhow::Result<(Vec<PgLsn>, PgLsn)> {
) -> anyhow::Result<Vec<PgLsn>> {
craft_internal(client, |client, initial_lsn| {
ensure!(
initial_lsn < PgLsn::from(0x0200_0000 - 1024 * 1024),
@@ -405,34 +405,24 @@ fn craft_single_logical_message(
"Logical message crossed two segments"
);
if transactional {
// Transactional logical messages are part of a transaction, so the one above is
// followed by a small COMMIT record.
let after_message_lsn = client.pg_current_wal_insert_lsn()?;
ensure!(
message_lsn < after_message_lsn,
"No record found after the emitted message"
);
Ok((vec![message_lsn], Some(after_message_lsn)))
} else {
Ok((Vec::new(), Some(message_lsn)))
}
Ok(vec![message_lsn])
})
}
pub struct WalRecordCrossingSegmentFollowedBySmallOne;
impl Crafter for WalRecordCrossingSegmentFollowedBySmallOne {
const NAME: &'static str = "wal_record_crossing_segment_followed_by_small_one";
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec<PgLsn>, PgLsn)> {
craft_single_logical_message(client, true)
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<Vec<PgLsn>> {
// Transactional message crossing WAL segment will be followed by small
// commit record.
craft_seg_size_logical_message(client, true)
}
}
pub struct LastWalRecordCrossingSegment;
impl Crafter for LastWalRecordCrossingSegment {
const NAME: &'static str = "last_wal_record_crossing_segment";
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec<PgLsn>, PgLsn)> {
craft_single_logical_message(client, false)
fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<Vec<PgLsn>> {
craft_seg_size_logical_message(client, false)
}
}

View File

@@ -11,13 +11,15 @@ use utils::const_assert;
use utils::lsn::Lsn;
fn init_logging() {
let _ = env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(
format!("crate=info,postgres_ffi::{PG_MAJORVERSION}::xlog_utils=trace"),
))
let _ = env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(format!(
"crate=info,postgres_ffi::{PG_MAJORVERSION}::xlog_utils=trace"
)))
.is_test(true)
.try_init();
}
/// Test that find_end_of_wal returns the same results as pg_dump on various
/// WALs created by Crafter.
fn test_end_of_wal<C: crate::Crafter>(test_name: &str) {
use crate::*;
@@ -38,13 +40,13 @@ fn test_end_of_wal<C: crate::Crafter>(test_name: &str) {
}
cfg.initdb().unwrap();
let srv = cfg.start_server().unwrap();
let (intermediate_lsns, expected_end_of_wal_partial) =
C::craft(&mut srv.connect_with_timeout().unwrap()).unwrap();
let intermediate_lsns = C::craft(&mut srv.connect_with_timeout().unwrap()).unwrap();
let intermediate_lsns: Vec<Lsn> = intermediate_lsns
.iter()
.map(|&lsn| u64::from(lsn).into())
.collect();
let expected_end_of_wal: Lsn = u64::from(expected_end_of_wal_partial).into();
// Kill postgres. Note that it might have inserted to WAL something after
// 'craft' did its job.
srv.kill();
// Check find_end_of_wal on the initial WAL
@@ -56,7 +58,7 @@ fn test_end_of_wal<C: crate::Crafter>(test_name: &str) {
.filter(|fname| IsXLogFileName(fname))
.max()
.unwrap();
check_pg_waldump_end_of_wal(&cfg, &last_segment, expected_end_of_wal);
let expected_end_of_wal = find_pg_waldump_end_of_wal(&cfg, &last_segment);
for start_lsn in intermediate_lsns
.iter()
.chain(std::iter::once(&expected_end_of_wal))
@@ -91,11 +93,7 @@ fn test_end_of_wal<C: crate::Crafter>(test_name: &str) {
}
}
fn check_pg_waldump_end_of_wal(
cfg: &crate::Conf,
last_segment: &str,
expected_end_of_wal: Lsn,
) {
fn find_pg_waldump_end_of_wal(cfg: &crate::Conf, last_segment: &str) -> Lsn {
// Get the actual end of WAL by pg_waldump
let waldump_output = cfg
.pg_waldump("000000010000000000000001", last_segment)
@@ -113,11 +111,8 @@ fn check_pg_waldump_end_of_wal(
}
};
let waldump_wal_end = Lsn::from_str(caps.get(1).unwrap().as_str()).unwrap();
info!(
"waldump erred on {}, expected wal end at {}",
waldump_wal_end, expected_end_of_wal
);
assert_eq!(waldump_wal_end, expected_end_of_wal);
info!("waldump erred on {}", waldump_wal_end);
waldump_wal_end
}
fn check_end_of_wal(
@@ -210,9 +205,9 @@ pub fn test_update_next_xid() {
#[test]
pub fn test_encode_logical_message() {
let expected = [
64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 170, 34, 166, 227, 255,
38, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 112, 114,
101, 102, 105, 120, 0, 109, 101, 115, 115, 97, 103, 101,
64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 170, 34, 166, 227, 255, 38,
0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 112, 114, 101, 102,
105, 120, 0, 109, 101, 115, 115, 97, 103, 101,
];
let actual = encode_logical_message("prefix", "message");
assert_eq!(expected, actual[..]);

View File

@@ -198,6 +198,7 @@ impl LocalFs {
fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&temp_file_path)
.await
.with_context(|| {

View File

@@ -245,7 +245,7 @@ impl std::io::Write for ChannelWriter {
}
}
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
pub async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
SERVE_METRICS_COUNT.inc();
let started_at = std::time::Instant::now();
@@ -367,7 +367,6 @@ pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
.middleware(Middleware::post_with_info(
add_request_id_header_to_response,
))
.get("/metrics", |r| request_span(r, prometheus_metrics_handler))
.err_handler(route_error_handler)
}

View File

@@ -63,6 +63,7 @@ impl UnwrittenLockFile {
pub fn create_exclusive(lock_file_path: &Utf8Path) -> anyhow::Result<UnwrittenLockFile> {
let lock_file = fs::OpenOptions::new()
.create(true) // O_CREAT
.truncate(true)
.write(true)
.open(lock_file_path)
.context("open lock file")?;

View File

@@ -1,160 +1,156 @@
//! Simple benchmarking around walredo.
//! Quantify a single walredo manager's throughput under N concurrent callers.
//!
//! Right now they hope to just set a baseline. Later we can try to expand into latency and
//! throughput after figuring out the coordinated omission problems below.
//! The benchmark implementation ([`bench_impl`]) is parametrized by
//! - `redo_work` => [`Request::short_request`] or [`Request::medium_request`]
//! - `n_redos` => number of times the benchmark shell execute the `redo_work`
//! - `nclients` => number of clients (more on this shortly).
//!
//! There are two sets of inputs; `short` and `medium`. They were collected on postgres v14 by
//! logging what happens when a sequential scan is requested on a small table, then picking out two
//! suitable from logs.
//! The benchmark impl sets up a multi-threaded tokio runtime with default parameters.
//! It spawns `nclients` times [`client`] tokio tasks.
//! Each task executes the `redo_work` `n_redos/nclients` times.
//!
//! We exercise the following combinations:
//! - `redo_work = short / medium``
//! - `nclients = [1, 2, 4, 8, 16, 32, 64, 128]`
//!
//! Reference data (git blame to see commit) on an i3en.3xlarge
// ```text
//! short/short/1 time: [39.175 µs 39.348 µs 39.536 µs]
//! short/short/2 time: [51.227 µs 51.487 µs 51.755 µs]
//! short/short/4 time: [76.048 µs 76.362 µs 76.674 µs]
//! short/short/8 time: [128.94 µs 129.82 µs 130.74 µs]
//! short/short/16 time: [227.84 µs 229.00 µs 230.28 µs]
//! short/short/32 time: [455.97 µs 457.81 µs 459.90 µs]
//! short/short/64 time: [902.46 µs 904.84 µs 907.32 µs]
//! short/short/128 time: [1.7416 ms 1.7487 ms 1.7561 ms]
//! ``
use std::sync::Arc;
//! We let `criterion` determine the `n_redos` using `iter_custom`.
//! The idea is that for each `(redo_work, nclients)` combination,
//! criterion will run the `bench_impl` multiple times with different `n_redos`.
//! The `bench_impl` reports the aggregate wall clock time from the clients' perspective.
//! Criterion will divide that by `n_redos` to compute the "time per iteration".
//! In our case, "time per iteration" means "time per redo_work execution".
//!
//! NB: the way by which `iter_custom` determines the "number of iterations"
//! is called sampling. Apparently the idea here is to detect outliers.
//! We're not sure whether the current choice of sampling method makes sense.
//! See https://bheisler.github.io/criterion.rs/book/user_guide/command_line_output.html#collecting-samples
//!
//! # Reference Numbers
//!
//! 2024-03-20 on i3en.3xlarge
//!
//! ```text
//! short/1 time: [26.483 µs 26.614 µs 26.767 µs]
//! short/2 time: [32.223 µs 32.465 µs 32.767 µs]
//! short/4 time: [47.203 µs 47.583 µs 47.984 µs]
//! short/8 time: [89.135 µs 89.612 µs 90.139 µs]
//! short/16 time: [190.12 µs 191.52 µs 192.88 µs]
//! short/32 time: [380.96 µs 382.63 µs 384.20 µs]
//! short/64 time: [736.86 µs 741.07 µs 745.03 µs]
//! short/128 time: [1.4106 ms 1.4206 ms 1.4294 ms]
//! medium/1 time: [111.81 µs 112.25 µs 112.79 µs]
//! medium/2 time: [158.26 µs 159.13 µs 160.21 µs]
//! medium/4 time: [334.65 µs 337.14 µs 340.07 µs]
//! medium/8 time: [675.32 µs 679.91 µs 685.25 µs]
//! medium/16 time: [1.2929 ms 1.2996 ms 1.3067 ms]
//! medium/32 time: [2.4295 ms 2.4461 ms 2.4623 ms]
//! medium/64 time: [4.3973 ms 4.4458 ms 4.4875 ms]
//! medium/128 time: [7.5955 ms 7.7847 ms 7.9481 ms]
//! ```
use bytes::{Buf, Bytes};
use pageserver::{
config::PageServerConf, repository::Key, walrecord::NeonWalRecord, walredo::PostgresRedoManager,
use criterion::{BenchmarkId, Criterion};
use pageserver::{config::PageServerConf, walrecord::NeonWalRecord, walredo::PostgresRedoManager};
use pageserver_api::{key::Key, shard::TenantShardId};
use std::{
sync::Arc,
time::{Duration, Instant},
};
use pageserver_api::shard::TenantShardId;
use tokio::task::JoinSet;
use tokio::{sync::Barrier, task::JoinSet};
use utils::{id::TenantId, lsn::Lsn};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
fn bench(c: &mut Criterion) {
{
let nclients = [1, 2, 4, 8, 16, 32, 64, 128];
for nclients in nclients {
let mut group = c.benchmark_group("short");
group.bench_with_input(
BenchmarkId::from_parameter(nclients),
&nclients,
|b, nclients| {
let redo_work = Arc::new(Request::short_input());
b.iter_custom(|iters| bench_impl(Arc::clone(&redo_work), iters, *nclients));
},
);
}
}
fn redo_scenarios(c: &mut Criterion) {
// logging should be enabled when adding more inputs, since walredo will only report malformed
// input to the stderr.
// utils::logging::init(utils::logging::LogFormat::Plain).unwrap();
{
let nclients = [1, 2, 4, 8, 16, 32, 64, 128];
for nclients in nclients {
let mut group = c.benchmark_group("medium");
group.bench_with_input(
BenchmarkId::from_parameter(nclients),
&nclients,
|b, nclients| {
let redo_work = Arc::new(Request::medium_input());
b.iter_custom(|iters| bench_impl(Arc::clone(&redo_work), iters, *nclients));
},
);
}
}
}
criterion::criterion_group!(benches, bench);
criterion::criterion_main!(benches);
// Returns the sum of each client's wall-clock time spent executing their share of the n_redos.
fn bench_impl(redo_work: Arc<Request>, n_redos: u64, nclients: u64) -> Duration {
let repo_dir = camino_tempfile::tempdir_in(env!("CARGO_TARGET_TMPDIR")).unwrap();
let conf = PageServerConf::dummy_conf(repo_dir.path().to_path_buf());
let conf = Box::leak(Box::new(conf));
let tenant_shard_id = TenantShardId::unsharded(TenantId::generate());
let manager = PostgresRedoManager::new(conf, tenant_shard_id);
let manager = Arc::new(manager);
{
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
tracing::info!("executing first");
rt.block_on(short().execute(&manager)).unwrap();
tracing::info!("first executed");
}
let thread_counts = [1, 2, 4, 8, 16, 32, 64, 128];
let mut group = c.benchmark_group("short");
group.sampling_mode(criterion::SamplingMode::Flat);
for thread_count in thread_counts {
group.bench_with_input(
BenchmarkId::new("short", thread_count),
&thread_count,
|b, thread_count| {
add_multithreaded_walredo_requesters(b, *thread_count, &manager, short);
},
);
}
drop(group);
let mut group = c.benchmark_group("medium");
group.sampling_mode(criterion::SamplingMode::Flat);
for thread_count in thread_counts {
group.bench_with_input(
BenchmarkId::new("medium", thread_count),
&thread_count,
|b, thread_count| {
add_multithreaded_walredo_requesters(b, *thread_count, &manager, medium);
},
);
}
drop(group);
}
/// Sets up a multi-threaded tokio runtime with default worker thread count,
/// then, spawn `requesters` tasks that repeatedly:
/// - get input from `input_factor()`
/// - call `manager.request_redo()` with their input
///
/// This stress-tests the scalability of a single walredo manager at high tokio-level concurrency.
///
/// Using tokio's default worker thread count means the results will differ on machines
/// with different core countrs. We don't care about that, the performance will always
/// be different on different hardware. To compare performance of different software versions,
/// use the same hardware.
fn add_multithreaded_walredo_requesters(
b: &mut criterion::Bencher,
nrequesters: usize,
manager: &Arc<PostgresRedoManager>,
input_factory: fn() -> Request,
) {
assert_ne!(nrequesters, 0);
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let barrier = Arc::new(tokio::sync::Barrier::new(nrequesters + 1));
let start = Arc::new(Barrier::new(nclients as usize));
let mut requesters = JoinSet::new();
for _ in 0..nrequesters {
let _entered = rt.enter();
let manager = manager.clone();
let barrier = barrier.clone();
requesters.spawn(async move {
loop {
let input = input_factory();
barrier.wait().await;
let page = input.execute(&manager).await.unwrap();
assert_eq!(page.remaining(), 8192);
barrier.wait().await;
}
let mut tasks = JoinSet::new();
let manager = PostgresRedoManager::new(conf, tenant_shard_id);
let manager = Arc::new(manager);
for _ in 0..nclients {
rt.block_on(async {
tasks.spawn(client(
Arc::clone(&manager),
Arc::clone(&start),
Arc::clone(&redo_work),
// divide the amount of work equally among the clients
n_redos / nclients,
))
});
}
let do_one_iteration = || {
rt.block_on(async {
barrier.wait().await;
// wait for work to complete
barrier.wait().await;
})
};
b.iter_batched(
|| {
// warmup
do_one_iteration();
},
|()| {
// work loop
do_one_iteration();
},
criterion::BatchSize::PerIteration,
);
rt.block_on(requesters.shutdown());
rt.block_on(async move {
let mut total_wallclock_time = std::time::Duration::from_millis(0);
while let Some(res) = tasks.join_next().await {
total_wallclock_time += res.unwrap();
}
total_wallclock_time
})
}
criterion_group!(benches, redo_scenarios);
criterion_main!(benches);
async fn client(
mgr: Arc<PostgresRedoManager>,
start: Arc<Barrier>,
redo_work: Arc<Request>,
n_redos: u64,
) -> Duration {
start.wait().await;
let start = Instant::now();
for _ in 0..n_redos {
let page = redo_work.execute(&mgr).await.unwrap();
assert_eq!(page.remaining(), 8192);
// The real pageserver will rarely if ever do 2 walredos in a row without
// yielding to the executor.
tokio::task::yield_now().await;
}
start.elapsed()
}
macro_rules! lsn {
($input:expr) => {{
@@ -166,12 +162,46 @@ macro_rules! lsn {
}};
}
/// Short payload, 1132 bytes.
// pg_records are copypasted from log, where they are put with Debug impl of Bytes, which uses \0
// for null bytes.
#[allow(clippy::octal_escapes)]
fn short() -> Request {
Request {
/// Simple wrapper around `WalRedoManager::request_redo`.
///
/// In benchmarks this is cloned around.
#[derive(Clone)]
struct Request {
key: Key,
lsn: Lsn,
base_img: Option<(Lsn, Bytes)>,
records: Vec<(Lsn, NeonWalRecord)>,
pg_version: u32,
}
impl Request {
async fn execute(&self, manager: &PostgresRedoManager) -> anyhow::Result<Bytes> {
let Request {
key,
lsn,
base_img,
records,
pg_version,
} = self;
// TODO: avoid these clones
manager
.request_redo(*key, *lsn, base_img.clone(), records.clone(), *pg_version)
.await
}
fn pg_record(will_init: bool, bytes: &'static [u8]) -> NeonWalRecord {
let rec = Bytes::from_static(bytes);
NeonWalRecord::Postgres { will_init, rec }
}
/// Short payload, 1132 bytes.
// pg_records are copypasted from log, where they are put with Debug impl of Bytes, which uses \0
// for null bytes.
#[allow(clippy::octal_escapes)]
pub fn short_input() -> Request {
let pg_record = Self::pg_record;
Request {
key: Key {
field1: 0,
field2: 1663,
@@ -194,13 +224,14 @@ fn short() -> Request {
],
pg_version: 14,
}
}
}
/// Medium sized payload, serializes as 26393 bytes.
// see [`short`]
#[allow(clippy::octal_escapes)]
fn medium() -> Request {
Request {
/// Medium sized payload, serializes as 26393 bytes.
// see [`short`]
#[allow(clippy::octal_escapes)]
pub fn medium_input() -> Request {
let pg_record = Self::pg_record;
Request {
key: Key {
field1: 0,
field2: 1663,
@@ -442,37 +473,5 @@ fn medium() -> Request {
],
pg_version: 14,
}
}
fn pg_record(will_init: bool, bytes: &'static [u8]) -> NeonWalRecord {
let rec = Bytes::from_static(bytes);
NeonWalRecord::Postgres { will_init, rec }
}
/// Simple wrapper around `WalRedoManager::request_redo`.
///
/// In benchmarks this is cloned around.
#[derive(Clone)]
struct Request {
key: Key,
lsn: Lsn,
base_img: Option<(Lsn, Bytes)>,
records: Vec<(Lsn, NeonWalRecord)>,
pg_version: u32,
}
impl Request {
async fn execute(self, manager: &PostgresRedoManager) -> anyhow::Result<Bytes> {
let Request {
key,
lsn,
base_img,
records,
pg_version,
} = self;
manager
.request_redo(key, lsn, base_img, records, pg_version)
.await
}
}

View File

@@ -15,9 +15,9 @@ use metrics::launch_timestamp::{set_launch_timestamp_metric, LaunchTimestamp};
use pageserver::control_plane_client::ControlPlaneClient;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::task_mgr::WALRECEIVER_RUNTIME;
use pageserver::tenant::{secondary, TenantSharedResources};
use remote_storage::GenericRemoteStorage;
use tokio::signal::unix::SignalKind;
use tokio::time::Instant;
use tracing::*;
@@ -28,7 +28,7 @@ use pageserver::{
deletion_queue::DeletionQueue,
http, page_cache, page_service, task_mgr,
task_mgr::TaskKind,
task_mgr::{BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME},
task_mgr::THE_RUNTIME,
tenant::mgr,
virtual_file,
};
@@ -323,7 +323,7 @@ fn start_pageserver(
// Launch broker client
// The storage_broker::connect call needs to happen inside a tokio runtime thread.
let broker_client = WALRECEIVER_RUNTIME
let broker_client = THE_RUNTIME
.block_on(async {
// Note: we do not attempt connecting here (but validate endpoints sanity).
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)
@@ -391,7 +391,7 @@ fn start_pageserver(
conf,
);
if let Some(deletion_workers) = deletion_workers {
deletion_workers.spawn_with(BACKGROUND_RUNTIME.handle());
deletion_workers.spawn_with(THE_RUNTIME.handle());
}
// Up to this point no significant I/O has been done: this should have been fast. Record
@@ -423,7 +423,7 @@ fn start_pageserver(
// Scan the local 'tenants/' directory and start loading the tenants
let deletion_queue_client = deletion_queue.new_client();
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
let tenant_manager = THE_RUNTIME.block_on(mgr::init_tenant_mgr(
conf,
TenantSharedResources {
broker_client: broker_client.clone(),
@@ -435,7 +435,7 @@ fn start_pageserver(
))?;
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.spawn({
THE_RUNTIME.spawn({
let shutdown_pageserver = shutdown_pageserver.clone();
let drive_init = async move {
// NOTE: unlike many futures in pageserver, this one is cancellation-safe
@@ -545,7 +545,7 @@ fn start_pageserver(
// Start up the service to handle HTTP mgmt API request. We created the
// listener earlier already.
{
let _rt_guard = MGMT_REQUEST_RUNTIME.enter();
let _rt_guard = THE_RUNTIME.enter();
let router_state = Arc::new(
http::routes::State::new(
@@ -569,7 +569,6 @@ fn start_pageserver(
.with_graceful_shutdown(task_mgr::shutdown_watcher());
task_mgr::spawn(
MGMT_REQUEST_RUNTIME.handle(),
TaskKind::HttpEndpointListener,
None,
None,
@@ -594,7 +593,6 @@ fn start_pageserver(
let local_disk_storage = conf.workdir.join("last_consumption_metrics.json");
task_mgr::spawn(
crate::BACKGROUND_RUNTIME.handle(),
TaskKind::MetricsCollection,
None,
None,
@@ -615,6 +613,7 @@ fn start_pageserver(
pageserver::consumption_metrics::collect_metrics(
metric_collection_endpoint,
&conf.metric_collection_bucket,
conf.metric_collection_interval,
conf.cached_metric_collection_interval,
conf.synthetic_size_calculation_interval,
@@ -642,7 +641,6 @@ fn start_pageserver(
DownloadBehavior::Error,
);
task_mgr::spawn(
COMPUTE_REQUEST_RUNTIME.handle(),
TaskKind::LibpqEndpointListener,
None,
None,
@@ -666,42 +664,37 @@ fn start_pageserver(
let mut shutdown_pageserver = Some(shutdown_pageserver.drop_guard());
// All started up! Now just sit and wait for shutdown signal.
{
use signal_hook::consts::*;
let signal_handler = BACKGROUND_RUNTIME.spawn_blocking(move || {
let mut signals =
signal_hook::iterator::Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap();
return signals
.forever()
.next()
.expect("forever() never returns None unless explicitly closed");
});
let signal = BACKGROUND_RUNTIME
.block_on(signal_handler)
.expect("join error");
match signal {
SIGQUIT => {
info!("Got signal {signal}. Terminating in immediate shutdown mode",);
std::process::exit(111);
}
SIGINT | SIGTERM => {
info!("Got signal {signal}. Terminating gracefully in fast shutdown mode",);
// This cancels the `shutdown_pageserver` cancellation tree.
// Right now that tree doesn't reach very far, and `task_mgr` is used instead.
// The plan is to change that over time.
shutdown_pageserver.take();
let bg_remote_storage = remote_storage.clone();
let bg_deletion_queue = deletion_queue.clone();
BACKGROUND_RUNTIME.block_on(pageserver::shutdown_pageserver(
&tenant_manager,
bg_remote_storage.map(|_| bg_deletion_queue),
0,
));
unreachable!()
}
_ => unreachable!(),
}
{
THE_RUNTIME.block_on(async move {
let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt()).unwrap();
let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate()).unwrap();
let mut sigquit = tokio::signal::unix::signal(SignalKind::quit()).unwrap();
let signal = tokio::select! {
_ = sigquit.recv() => {
info!("Got signal SIGQUIT. Terminating in immediate shutdown mode",);
std::process::exit(111);
}
_ = sigint.recv() => { "SIGINT" },
_ = sigterm.recv() => { "SIGTERM" },
};
info!("Got signal {signal}. Terminating gracefully in fast shutdown mode",);
// This cancels the `shutdown_pageserver` cancellation tree.
// Right now that tree doesn't reach very far, and `task_mgr` is used instead.
// The plan is to change that over time.
shutdown_pageserver.take();
let bg_remote_storage = remote_storage.clone();
let bg_deletion_queue = deletion_queue.clone();
pageserver::shutdown_pageserver(
&tenant_manager,
bg_remote_storage.map(|_| bg_deletion_queue),
0,
)
.await;
unreachable!()
})
}
}

View File

@@ -234,6 +234,7 @@ pub struct PageServerConf {
// How often to send unchanged cached metrics to the metrics endpoint.
pub cached_metric_collection_interval: Duration,
pub metric_collection_endpoint: Option<Url>,
pub metric_collection_bucket: Option<RemoteStorageConfig>,
pub synthetic_size_calculation_interval: Duration,
pub disk_usage_based_eviction: Option<DiskUsageEvictionTaskConfig>,
@@ -373,6 +374,7 @@ struct PageServerConfigBuilder {
cached_metric_collection_interval: BuilderValue<Duration>,
metric_collection_endpoint: BuilderValue<Option<Url>>,
synthetic_size_calculation_interval: BuilderValue<Duration>,
metric_collection_bucket: BuilderValue<Option<RemoteStorageConfig>>,
disk_usage_based_eviction: BuilderValue<Option<DiskUsageEvictionTaskConfig>>,
@@ -455,6 +457,8 @@ impl PageServerConfigBuilder {
.expect("cannot parse default synthetic size calculation interval")),
metric_collection_endpoint: Set(DEFAULT_METRIC_COLLECTION_ENDPOINT),
metric_collection_bucket: Set(None),
disk_usage_based_eviction: Set(None),
test_remote_failures: Set(0),
@@ -586,6 +590,13 @@ impl PageServerConfigBuilder {
self.metric_collection_endpoint = BuilderValue::Set(metric_collection_endpoint)
}
pub fn metric_collection_bucket(
&mut self,
metric_collection_bucket: Option<RemoteStorageConfig>,
) {
self.metric_collection_bucket = BuilderValue::Set(metric_collection_bucket)
}
pub fn synthetic_size_calculation_interval(
&mut self,
synthetic_size_calculation_interval: Duration,
@@ -694,6 +705,7 @@ impl PageServerConfigBuilder {
metric_collection_interval,
cached_metric_collection_interval,
metric_collection_endpoint,
metric_collection_bucket,
synthetic_size_calculation_interval,
disk_usage_based_eviction,
test_remote_failures,
@@ -942,6 +954,9 @@ impl PageServerConf {
let endpoint = parse_toml_string(key, item)?.parse().context("failed to parse metric_collection_endpoint")?;
builder.metric_collection_endpoint(Some(endpoint));
},
"metric_collection_bucket" => {
builder.metric_collection_bucket(RemoteStorageConfig::from_toml(item)?)
}
"synthetic_size_calculation_interval" =>
builder.synthetic_size_calculation_interval(parse_toml_duration(key, item)?),
"test_remote_failures" => builder.test_remote_failures(parse_toml_u64(key, item)?),
@@ -1057,6 +1072,7 @@ impl PageServerConf {
metric_collection_interval: Duration::from_secs(60),
cached_metric_collection_interval: Duration::from_secs(60 * 60),
metric_collection_endpoint: defaults::DEFAULT_METRIC_COLLECTION_ENDPOINT,
metric_collection_bucket: None,
synthetic_size_calculation_interval: Duration::from_secs(60),
disk_usage_based_eviction: None,
test_remote_failures: 0,
@@ -1289,6 +1305,7 @@ background_task_maximum_delay = '334 s'
defaults::DEFAULT_CACHED_METRIC_COLLECTION_INTERVAL
)?,
metric_collection_endpoint: defaults::DEFAULT_METRIC_COLLECTION_ENDPOINT,
metric_collection_bucket: None,
synthetic_size_calculation_interval: humantime::parse_duration(
defaults::DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL
)?,
@@ -1363,6 +1380,7 @@ background_task_maximum_delay = '334 s'
metric_collection_interval: Duration::from_secs(222),
cached_metric_collection_interval: Duration::from_secs(22200),
metric_collection_endpoint: Some(Url::parse("http://localhost:80/metrics")?),
metric_collection_bucket: None,
synthetic_size_calculation_interval: Duration::from_secs(333),
disk_usage_based_eviction: None,
test_remote_failures: 0,

View File

@@ -1,12 +1,13 @@
//! Periodically collect consumption metrics for all active tenants
//! and push them to a HTTP endpoint.
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::{self, TaskKind, BACKGROUND_RUNTIME};
use crate::task_mgr::{self, TaskKind};
use crate::tenant::tasks::BackgroundLoopKind;
use crate::tenant::{mgr, LogicalSizeCalculationCause, PageReconstructError, Tenant};
use camino::Utf8PathBuf;
use consumption_metrics::EventType;
use pageserver_api::models::TenantState;
use remote_storage::{GenericRemoteStorage, RemoteStorageConfig};
use reqwest::Url;
use std::collections::HashMap;
use std::sync::Arc;
@@ -41,6 +42,7 @@ type Cache = HashMap<MetricsKey, (EventType, u64)>;
#[allow(clippy::too_many_arguments)]
pub async fn collect_metrics(
metric_collection_endpoint: &Url,
metric_collection_bucket: &Option<RemoteStorageConfig>,
metric_collection_interval: Duration,
_cached_metric_collection_interval: Duration,
synthetic_size_calculation_interval: Duration,
@@ -59,7 +61,6 @@ pub async fn collect_metrics(
let worker_ctx =
ctx.detached_child(TaskKind::CalculateSyntheticSize, DownloadBehavior::Download);
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::CalculateSyntheticSize,
None,
None,
@@ -94,6 +95,20 @@ pub async fn collect_metrics(
.build()
.expect("Failed to create http client with timeout");
let bucket_client = if let Some(bucket_config) = metric_collection_bucket {
match GenericRemoteStorage::from_config(bucket_config) {
Ok(client) => Some(client),
Err(e) => {
// Non-fatal error: if we were given an invalid config, we will proceed
// with sending metrics over the network, but not to S3.
tracing::warn!("Invalid configuration for metric_collection_bucket: {e}");
None
}
}
} else {
None
};
let node_id = node_id.to_string();
loop {
@@ -118,10 +133,18 @@ pub async fn collect_metrics(
tracing::error!("failed to persist metrics to {path:?}: {e:#}");
}
}
if let Some(bucket_client) = &bucket_client {
let res =
upload::upload_metrics_bucket(bucket_client, &cancel, &node_id, &metrics).await;
if let Err(e) = res {
tracing::error!("failed to upload to S3: {e:#}");
}
}
};
let upload = async {
let res = upload::upload_metrics(
let res = upload::upload_metrics_http(
&client,
metric_collection_endpoint,
&cancel,
@@ -132,7 +155,7 @@ pub async fn collect_metrics(
.await;
if let Err(e) = res {
// serialization error which should never happen
tracing::error!("failed to upload due to {e:#}");
tracing::error!("failed to upload via HTTP due to {e:#}");
}
};

View File

@@ -1,4 +1,9 @@
use std::time::SystemTime;
use chrono::{DateTime, Utc};
use consumption_metrics::{Event, EventChunk, IdempotencyKey, CHUNK_SIZE};
use remote_storage::{GenericRemoteStorage, RemotePath};
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
@@ -13,8 +18,9 @@ struct Ids {
pub(super) timeline_id: Option<TimelineId>,
}
/// Serialize and write metrics to an HTTP endpoint
#[tracing::instrument(skip_all, fields(metrics_total = %metrics.len()))]
pub(super) async fn upload_metrics(
pub(super) async fn upload_metrics_http(
client: &reqwest::Client,
metric_collection_endpoint: &reqwest::Url,
cancel: &CancellationToken,
@@ -74,6 +80,60 @@ pub(super) async fn upload_metrics(
Ok(())
}
/// Serialize and write metrics to a remote storage object
#[tracing::instrument(skip_all, fields(metrics_total = %metrics.len()))]
pub(super) async fn upload_metrics_bucket(
client: &GenericRemoteStorage,
cancel: &CancellationToken,
node_id: &str,
metrics: &[RawMetric],
) -> anyhow::Result<()> {
if metrics.is_empty() {
// Skip uploads if we have no metrics, so that readers don't have to handle the edge case
// of an empty object.
return Ok(());
}
// Compose object path
let datetime: DateTime<Utc> = SystemTime::now().into();
let ts_prefix = datetime.format("year=%Y/month=%m/day=%d/%H:%M:%SZ");
let path = RemotePath::from_string(&format!("{ts_prefix}_{node_id}.ndjson.gz"))?;
// Set up a gzip writer into a buffer
let mut compressed_bytes: Vec<u8> = Vec::new();
let compressed_writer = std::io::Cursor::new(&mut compressed_bytes);
let mut gzip_writer = async_compression::tokio::write::GzipEncoder::new(compressed_writer);
// Serialize and write into compressed buffer
let started_at = std::time::Instant::now();
for res in serialize_in_chunks(CHUNK_SIZE, metrics, node_id) {
let (_chunk, body) = res?;
gzip_writer.write_all(&body).await?;
}
gzip_writer.flush().await?;
gzip_writer.shutdown().await?;
let compressed_length = compressed_bytes.len();
// Write to remote storage
client
.upload_storage_object(
futures::stream::once(futures::future::ready(Ok(compressed_bytes.into()))),
compressed_length,
&path,
cancel,
)
.await?;
let elapsed = started_at.elapsed();
tracing::info!(
compressed_length,
elapsed_ms = elapsed.as_millis(),
"write metrics bucket at {path}",
);
Ok(())
}
// The return type is quite ugly, but we gain testability in isolation
fn serialize_in_chunks<'a, F>(
chunk_size: usize,

View File

@@ -5,7 +5,8 @@ use pageserver_api::{
controller_api::NodeRegisterRequest,
shard::TenantShardId,
upcall_api::{
ReAttachRequest, ReAttachResponse, ValidateRequest, ValidateRequestTenant, ValidateResponse,
ReAttachRequest, ReAttachResponse, ReAttachResponseTenant, ValidateRequest,
ValidateRequestTenant, ValidateResponse,
},
};
use serde::{de::DeserializeOwned, Serialize};
@@ -37,7 +38,9 @@ pub trait ControlPlaneGenerationsApi {
fn re_attach(
&self,
conf: &PageServerConf,
) -> impl Future<Output = Result<HashMap<TenantShardId, Generation>, RetryForeverError>> + Send;
) -> impl Future<
Output = Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError>,
> + Send;
fn validate(
&self,
tenants: Vec<(TenantShardId, Generation)>,
@@ -118,7 +121,7 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient {
async fn re_attach(
&self,
conf: &PageServerConf,
) -> Result<HashMap<TenantShardId, Generation>, RetryForeverError> {
) -> Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError> {
let re_attach_path = self
.base_url
.join("re-attach")
@@ -170,8 +173,6 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient {
register,
};
fail::fail_point!("control-plane-client-re-attach");
let response: ReAttachResponse = self.retry_http_forever(&re_attach_path, request).await?;
tracing::info!(
"Received re-attach response with {} tenants",
@@ -181,7 +182,7 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient {
Ok(response
.tenants
.into_iter()
.map(|t| (t.id, Generation::new(t.gen)))
.map(|rart| (rart.id, rart))
.collect::<HashMap<_, _>>())
}
@@ -207,7 +208,7 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient {
.collect(),
};
fail::fail_point!("control-plane-client-validate");
crate::tenant::pausable_failpoint!("control-plane-client-validate");
let response: ValidateResponse = self.retry_http_forever(&re_attach_path, request).await?;

View File

@@ -724,8 +724,8 @@ impl DeletionQueue {
mod test {
use camino::Utf8Path;
use hex_literal::hex;
use pageserver_api::shard::ShardIndex;
use std::io::ErrorKind;
use pageserver_api::{shard::ShardIndex, upcall_api::ReAttachResponseTenant};
use std::{io::ErrorKind, time::Duration};
use tracing::info;
use remote_storage::{RemoteStorageConfig, RemoteStorageKind};
@@ -834,9 +834,10 @@ mod test {
async fn re_attach(
&self,
_conf: &PageServerConf,
) -> Result<HashMap<TenantShardId, Generation>, RetryForeverError> {
) -> Result<HashMap<TenantShardId, ReAttachResponseTenant>, RetryForeverError> {
unimplemented!()
}
async fn validate(
&self,
tenants: Vec<(TenantShardId, Generation)>,

View File

@@ -59,7 +59,7 @@ use utils::{completion, id::TimelineId};
use crate::{
config::PageServerConf,
metrics::disk_usage_based_eviction::METRICS,
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
task_mgr::{self, TaskKind},
tenant::{
self,
mgr::TenantManager,
@@ -202,7 +202,6 @@ pub fn launch_disk_usage_global_eviction_task(
info!("launching disk usage based eviction task");
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::DiskUsageEviction,
None,
None,

View File

@@ -36,6 +36,7 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::auth::JwtAuth;
use utils::failpoint_support::failpoints_handler;
use utils::http::endpoint::prometheus_metrics_handler;
use utils::http::endpoint::request_span;
use utils::http::json::json_request_or_empty_body;
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
@@ -2266,6 +2267,7 @@ pub fn make_router(
Ok(router
.data(state)
.get("/metrics", |r| request_span(r, prometheus_metrics_handler))
.get("/v1/status", |r| api_handler(r, status_handler))
.put("/v1/failpoints", |r| {
testing_api_handler("manage failpoints", r, failpoints_handler)

View File

@@ -699,6 +699,14 @@ pub static STARTUP_IS_LOADING: Lazy<UIntGauge> = Lazy::new(|| {
.expect("Failed to register pageserver_startup_is_loading")
});
pub(crate) static TIMELINE_EPHEMERAL_BYTES: Lazy<UIntGauge> = Lazy::new(|| {
register_uint_gauge!(
"pageserver_timeline_ephemeral_bytes",
"Total number of bytes in ephemeral layers, summed for all timelines. Approximate, lazily updated."
)
.expect("Failed to register metric")
});
/// Metrics related to the lifecycle of a [`crate::tenant::Tenant`] object: things
/// like how long it took to load.
///

View File

@@ -180,7 +180,6 @@ pub async fn libpq_listener_main(
// only deal with a particular timeline, but we don't know which one
// yet.
task_mgr::spawn(
&tokio::runtime::Handle::current(),
TaskKind::PageRequestHandler,
None,
None,

View File

@@ -98,42 +98,22 @@ use utils::id::TimelineId;
// other operations, if the upload tasks e.g. get blocked on locks. It shouldn't
// happen, but still.
//
pub static COMPUTE_REQUEST_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("compute request worker")
.enable_all()
.build()
.expect("Failed to create compute request runtime")
});
pub static MGMT_REQUEST_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
/// The single tokio runtime used by all pageserver code.
/// In the past, we had multiple runtimes, and in the future we should weed out
/// remaining references to this global field and rely on ambient runtime instead,
/// i.e., use `tokio::spawn` instead of `THE_RUNTIME.spawn()`, etc.
pub static THE_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("mgmt request worker")
.enable_all()
.build()
.expect("Failed to create mgmt request runtime")
});
pub static WALRECEIVER_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("walreceiver worker")
.enable_all()
.build()
.expect("Failed to create walreceiver runtime")
});
pub static BACKGROUND_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("background op worker")
// if you change the number of worker threads please change the constant below
.enable_all()
.build()
.expect("Failed to create background op runtime")
});
pub(crate) static BACKGROUND_RUNTIME_WORKER_THREADS: Lazy<usize> = Lazy::new(|| {
pub(crate) static THE_RUNTIME_WORKER_THREADS: Lazy<usize> = Lazy::new(|| {
// force init and thus panics
let _ = BACKGROUND_RUNTIME.handle();
let _ = THE_RUNTIME.handle();
// replicates tokio-1.28.1::loom::sys::num_cpus which is not available publicly
// tokio would had already panicked for parsing errors or NotUnicode
//
@@ -325,7 +305,6 @@ struct PageServerTask {
/// Note: if shutdown_process_on_error is set to true failure
/// of the task will lead to shutdown of entire process
pub fn spawn<F>(
runtime: &tokio::runtime::Handle,
kind: TaskKind,
tenant_shard_id: Option<TenantShardId>,
timeline_id: Option<TimelineId>,
@@ -354,7 +333,7 @@ where
let task_name = name.to_string();
let task_cloned = Arc::clone(&task);
let join_handle = runtime.spawn(task_wrapper(
let join_handle = THE_RUNTIME.spawn(task_wrapper(
task_name,
task_id,
task_cloned,

View File

@@ -144,6 +144,7 @@ macro_rules! pausable_failpoint {
}
};
}
pub(crate) use pausable_failpoint;
pub mod blob_io;
pub mod block_io;
@@ -202,6 +203,13 @@ pub(super) struct AttachedTenantConf {
}
impl AttachedTenantConf {
fn new(tenant_conf: TenantConfOpt, location: AttachedLocationConfig) -> Self {
Self {
tenant_conf,
location,
}
}
fn try_from(location_conf: LocationConf) -> anyhow::Result<Self> {
match &location_conf.mode {
LocationMode::Attached(attach_conf) => Ok(Self {
@@ -654,7 +662,6 @@ impl Tenant {
let tenant_clone = Arc::clone(&tenant);
let ctx = ctx.detached_child(TaskKind::Attach, DownloadBehavior::Warn);
task_mgr::spawn(
&tokio::runtime::Handle::current(),
TaskKind::Attach,
Some(tenant_shard_id),
None,
@@ -678,9 +685,20 @@ impl Tenant {
}
// Ideally we should use Tenant::set_broken_no_wait, but it is not supposed to be used when tenant is in loading state.
enum BrokenVerbosity {
Error,
Info
}
let make_broken =
|t: &Tenant, err: anyhow::Error| {
error!("attach failed, setting tenant state to Broken: {err:?}");
|t: &Tenant, err: anyhow::Error, verbosity: BrokenVerbosity| {
match verbosity {
BrokenVerbosity::Info => {
info!("attach cancelled, setting tenant state to Broken: {err}");
},
BrokenVerbosity::Error => {
error!("attach failed, setting tenant state to Broken: {err:?}");
}
}
t.state.send_modify(|state| {
// The Stopping case is for when we have passed control on to DeleteTenantFlow:
// if it errors, we will call make_broken when tenant is already in Stopping.
@@ -744,7 +762,7 @@ impl Tenant {
// Make the tenant broken so that set_stopping will not hang waiting for it to leave
// the Attaching state. This is an over-reaction (nothing really broke, the tenant is
// just shutting down), but ensures progress.
make_broken(&tenant_clone, anyhow::anyhow!("Shut down while Attaching"));
make_broken(&tenant_clone, anyhow::anyhow!("Shut down while Attaching"), BrokenVerbosity::Info);
return Ok(());
},
)
@@ -766,7 +784,7 @@ impl Tenant {
match res {
Ok(p) => Some(p),
Err(e) => {
make_broken(&tenant_clone, anyhow::anyhow!(e));
make_broken(&tenant_clone, anyhow::anyhow!(e), BrokenVerbosity::Error);
return Ok(());
}
}
@@ -790,7 +808,7 @@ impl Tenant {
{
Ok(should_resume_deletion) => should_resume_deletion,
Err(err) => {
make_broken(&tenant_clone, anyhow::anyhow!(err));
make_broken(&tenant_clone, anyhow::anyhow!(err), BrokenVerbosity::Error);
return Ok(());
}
}
@@ -820,7 +838,7 @@ impl Tenant {
.await;
if let Err(e) = deleted {
make_broken(&tenant_clone, anyhow::anyhow!(e));
make_broken(&tenant_clone, anyhow::anyhow!(e), BrokenVerbosity::Error);
}
return Ok(());
@@ -841,7 +859,7 @@ impl Tenant {
tenant_clone.activate(broker_client, None, &ctx);
}
Err(e) => {
make_broken(&tenant_clone, anyhow::anyhow!(e));
make_broken(&tenant_clone, anyhow::anyhow!(e), BrokenVerbosity::Error);
}
}

View File

@@ -196,16 +196,17 @@ impl LocationConf {
/// For use when attaching/re-attaching: update the generation stored in this
/// structure. If we were in a secondary state, promote to attached (posession
/// of a fresh generation implies this).
pub(crate) fn attach_in_generation(&mut self, generation: Generation) {
pub(crate) fn attach_in_generation(&mut self, mode: AttachmentMode, generation: Generation) {
match &mut self.mode {
LocationMode::Attached(attach_conf) => {
attach_conf.generation = generation;
attach_conf.attach_mode = mode;
}
LocationMode::Secondary(_) => {
// We are promoted to attached by the control plane's re-attach response
self.mode = LocationMode::Attached(AttachedLocationConfig {
generation,
attach_mode: AttachmentMode::Single,
attach_mode: mode,
})
}
}

View File

@@ -111,6 +111,7 @@ async fn create_local_delete_mark(
let _ = std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&marker_path)
.with_context(|| format!("could not create delete marker file {marker_path:?}"))?;
@@ -481,7 +482,6 @@ impl DeleteTenantFlow {
let tenant_shard_id = tenant.tenant_shard_id;
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
TaskKind::TimelineDeletionWorker,
Some(tenant_shard_id),
None,

View File

@@ -4,10 +4,11 @@
use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf};
use itertools::Itertools;
use pageserver_api::key::Key;
use pageserver_api::models::ShardParameters;
use pageserver_api::models::{LocationConfigMode, ShardParameters};
use pageserver_api::shard::{
ShardCount, ShardIdentity, ShardNumber, ShardStripeSize, TenantShardId,
};
use pageserver_api::upcall_api::ReAttachResponseTenant;
use rand::{distributions::Alphanumeric, Rng};
use std::borrow::Cow;
use std::cmp::Ordering;
@@ -124,6 +125,46 @@ pub(crate) enum ShardSelector {
Page(Key),
}
/// A convenience for use with the re_attach ControlPlaneClient function: rather
/// than the serializable struct, we build this enum that encapsulates
/// the invariant that attached tenants always have generations.
///
/// This represents the subset of a LocationConfig that we receive during re-attach.
pub(crate) enum TenantStartupMode {
Attached((AttachmentMode, Generation)),
Secondary,
}
impl TenantStartupMode {
/// Return the generation & mode that should be used when starting
/// this tenant.
///
/// If this returns None, the re-attach struct is in an invalid state and
/// should be ignored in the response.
fn from_reattach_tenant(rart: ReAttachResponseTenant) -> Option<Self> {
match (rart.mode, rart.gen) {
(LocationConfigMode::Detached, _) => None,
(LocationConfigMode::Secondary, _) => Some(Self::Secondary),
(LocationConfigMode::AttachedMulti, Some(g)) => {
Some(Self::Attached((AttachmentMode::Multi, Generation::new(g))))
}
(LocationConfigMode::AttachedSingle, Some(g)) => {
Some(Self::Attached((AttachmentMode::Single, Generation::new(g))))
}
(LocationConfigMode::AttachedStale, Some(g)) => {
Some(Self::Attached((AttachmentMode::Stale, Generation::new(g))))
}
_ => {
tracing::warn!(
"Received invalid re-attach state for tenant {}: {rart:?}",
rart.id
);
None
}
}
}
}
impl TenantsMap {
/// Convenience function for typical usage, where we want to get a `Tenant` object, for
/// working with attached tenants. If the TenantId is in the map but in Secondary state,
@@ -270,7 +311,7 @@ pub struct TenantManager {
fn emergency_generations(
tenant_confs: &HashMap<TenantShardId, anyhow::Result<LocationConf>>,
) -> HashMap<TenantShardId, Generation> {
) -> HashMap<TenantShardId, TenantStartupMode> {
tenant_confs
.iter()
.filter_map(|(tid, lc)| {
@@ -278,12 +319,15 @@ fn emergency_generations(
Ok(lc) => lc,
Err(_) => return None,
};
let gen = match &lc.mode {
LocationMode::Attached(alc) => Some(alc.generation),
LocationMode::Secondary(_) => None,
};
gen.map(|g| (*tid, g))
Some((
*tid,
match &lc.mode {
LocationMode::Attached(alc) => {
TenantStartupMode::Attached((alc.attach_mode, alc.generation))
}
LocationMode::Secondary(_) => TenantStartupMode::Secondary,
},
))
})
.collect()
}
@@ -293,7 +337,7 @@ async fn init_load_generations(
tenant_confs: &HashMap<TenantShardId, anyhow::Result<LocationConf>>,
resources: &TenantSharedResources,
cancel: &CancellationToken,
) -> anyhow::Result<Option<HashMap<TenantShardId, Generation>>> {
) -> anyhow::Result<Option<HashMap<TenantShardId, TenantStartupMode>>> {
let generations = if conf.control_plane_emergency_mode {
error!(
"Emergency mode! Tenants will be attached unsafely using their last known generation"
@@ -303,7 +347,12 @@ async fn init_load_generations(
info!("Calling control plane API to re-attach tenants");
// If we are configured to use the control plane API, then it is the source of truth for what tenants to load.
match client.re_attach(conf).await {
Ok(tenants) => tenants,
Ok(tenants) => tenants
.into_iter()
.flat_map(|(id, rart)| {
TenantStartupMode::from_reattach_tenant(rart).map(|tsm| (id, tsm))
})
.collect(),
Err(RetryForeverError::ShuttingDown) => {
anyhow::bail!("Shut down while waiting for control plane re-attach response")
}
@@ -321,9 +370,17 @@ async fn init_load_generations(
// Must only do this if remote storage is enabled, otherwise deletion queue
// is not running and channel push will fail.
if resources.remote_storage.is_some() {
resources
.deletion_queue_client
.recover(generations.clone())?;
let attached_tenants = generations
.iter()
.flat_map(|(id, start_mode)| {
match start_mode {
TenantStartupMode::Attached((_mode, generation)) => Some(generation),
TenantStartupMode::Secondary => None,
}
.map(|gen| (*id, *gen))
})
.collect();
resources.deletion_queue_client.recover(attached_tenants)?;
}
Ok(Some(generations))
@@ -489,9 +546,8 @@ pub async fn init_tenant_mgr(
// Scan local filesystem for attached tenants
let tenant_configs = init_load_tenant_configs(conf).await?;
// Determine which tenants are to be attached
let tenant_generations =
init_load_generations(conf, &tenant_configs, &resources, &cancel).await?;
// Determine which tenants are to be secondary or attached, and in which generation
let tenant_modes = init_load_generations(conf, &tenant_configs, &resources, &cancel).await?;
tracing::info!(
"Attaching {} tenants at startup, warming up {} at a time",
@@ -521,97 +577,102 @@ pub async fn init_tenant_mgr(
}
};
let generation = if let Some(generations) = &tenant_generations {
// FIXME: if we were attached, and get demoted to secondary on re-attach, we
// don't have a place to get a config.
// (https://github.com/neondatabase/neon/issues/5377)
const DEFAULT_SECONDARY_CONF: SecondaryLocationConfig =
SecondaryLocationConfig { warm: true };
// Update the location config according to the re-attach response
if let Some(tenant_modes) = &tenant_modes {
// We have a generation map: treat it as the authority for whether
// this tenant is really attached.
if let Some(gen) = generations.get(&tenant_shard_id) {
if let LocationMode::Attached(attached) = &location_conf.mode {
if attached.generation > *gen {
match tenant_modes.get(&tenant_shard_id) {
None => {
info!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), "Detaching tenant, control plane omitted it in re-attach response");
if let Err(e) = safe_remove_tenant_dir_all(&tenant_dir_path).await {
error!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),
"Failed to remove detached tenant directory '{tenant_dir_path}': {e:?}",
);
}
// We deleted local content: move on to next tenant, don't try and spawn this one.
continue;
}
Some(TenantStartupMode::Secondary) => {
if !matches!(location_conf.mode, LocationMode::Secondary(_)) {
location_conf.mode = LocationMode::Secondary(DEFAULT_SECONDARY_CONF);
}
}
Some(TenantStartupMode::Attached((attach_mode, generation))) => {
let old_gen_higher = match &location_conf.mode {
LocationMode::Attached(AttachedLocationConfig {
generation: old_generation,
attach_mode: _attach_mode,
}) => {
if old_generation > generation {
Some(old_generation)
} else {
None
}
}
_ => None,
};
if let Some(old_generation) = old_gen_higher {
tracing::error!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),
"Control plane gave decreasing generation ({gen:?}) in re-attach response for tenant that was attached in generation {:?}, demoting to secondary",
attached.generation
"Control plane gave decreasing generation ({generation:?}) in re-attach response for tenant that was attached in generation {:?}, demoting to secondary",
old_generation
);
// We cannot safely attach this tenant given a bogus generation number, but let's avoid throwing away
// local disk content: demote to secondary rather than detaching.
tenants.insert(
tenant_shard_id,
TenantSlot::Secondary(SecondaryTenant::new(
tenant_shard_id,
location_conf.shard,
location_conf.tenant_conf.clone(),
&SecondaryLocationConfig { warm: false },
)),
);
location_conf.mode = LocationMode::Secondary(DEFAULT_SECONDARY_CONF);
} else {
location_conf.attach_in_generation(*attach_mode, *generation);
}
}
*gen
} else {
match &location_conf.mode {
LocationMode::Secondary(secondary_config) => {
// We do not require the control plane's permission for secondary mode
// tenants, because they do no remote writes and hence require no
// generation number
info!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), "Loaded tenant in secondary mode");
tenants.insert(
tenant_shard_id,
TenantSlot::Secondary(SecondaryTenant::new(
tenant_shard_id,
location_conf.shard,
location_conf.tenant_conf,
secondary_config,
)),
);
}
LocationMode::Attached(_) => {
// TODO: augment re-attach API to enable the control plane to
// instruct us about secondary attachments. That way, instead of throwing
// away local state, we can gracefully fall back to secondary here, if the control
// plane tells us so.
// (https://github.com/neondatabase/neon/issues/5377)
info!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), "Detaching tenant, control plane omitted it in re-attach response");
if let Err(e) = safe_remove_tenant_dir_all(&tenant_dir_path).await {
error!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),
"Failed to remove detached tenant directory '{tenant_dir_path}': {e:?}",
);
}
}
};
continue;
}
} else {
// Legacy mode: no generation information, any tenant present
// on local disk may activate
info!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), "Starting tenant in legacy mode, no generation",);
Generation::none()
};
// Presence of a generation number implies attachment: attach the tenant
// if it wasn't already, and apply the generation number.
location_conf.attach_in_generation(generation);
Tenant::persist_tenant_config(conf, &tenant_shard_id, &location_conf).await?;
let shard_identity = location_conf.shard;
match tenant_spawn(
conf,
tenant_shard_id,
&tenant_dir_path,
resources.clone(),
AttachedTenantConf::try_from(location_conf)?,
shard_identity,
Some(init_order.clone()),
&TENANTS,
SpawnMode::Lazy,
&ctx,
) {
Ok(tenant) => {
tenants.insert(tenant_shard_id, TenantSlot::Attached(tenant));
let slot = match location_conf.mode {
LocationMode::Attached(attached_conf) => {
match tenant_spawn(
conf,
tenant_shard_id,
&tenant_dir_path,
resources.clone(),
AttachedTenantConf::new(location_conf.tenant_conf, attached_conf),
shard_identity,
Some(init_order.clone()),
&TENANTS,
SpawnMode::Lazy,
&ctx,
) {
Ok(tenant) => TenantSlot::Attached(tenant),
Err(e) => {
error!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), "Failed to start tenant: {e:#}");
continue;
}
}
}
Err(e) => {
error!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), "Failed to start tenant: {e:#}");
}
}
LocationMode::Secondary(secondary_conf) => TenantSlot::Secondary(SecondaryTenant::new(
tenant_shard_id,
shard_identity,
location_conf.tenant_conf,
&secondary_conf,
)),
};
tenants.insert(tenant_shard_id, slot);
}
info!("Processed {} local tenants at startup", tenants.len());
@@ -1789,7 +1850,6 @@ impl TenantManager {
let task_tenant_id = None;
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
TaskKind::MgmtRequest,
task_tenant_id,
None,
@@ -2142,7 +2202,7 @@ pub(crate) async fn load_tenant(
let mut location_conf =
Tenant::load_tenant_config(conf, &tenant_shard_id).map_err(TenantMapInsertError::Other)?;
location_conf.attach_in_generation(generation);
location_conf.attach_in_generation(AttachmentMode::Single, generation);
Tenant::persist_tenant_config(conf, &tenant_shard_id, &location_conf).await?;
@@ -2755,15 +2815,12 @@ pub(crate) fn immediate_gc(
// TODO: spawning is redundant now, need to hold the gate
task_mgr::spawn(
&tokio::runtime::Handle::current(),
TaskKind::GarbageCollector,
Some(tenant_shard_id),
Some(timeline_id),
&format!("timeline_gc_handler garbage collection run for tenant {tenant_shard_id} timeline {timeline_id}"),
false,
async move {
fail::fail_point!("immediate_gc_task_pre");
#[allow(unused_mut)]
let mut result = tenant
.gc_iteration(Some(timeline_id), gc_horizon, pitr, &cancel, &ctx)

View File

@@ -223,7 +223,6 @@ use crate::{
config::PageServerConf,
task_mgr,
task_mgr::TaskKind,
task_mgr::BACKGROUND_RUNTIME,
tenant::metadata::TimelineMetadata,
tenant::upload_queue::{
UploadOp, UploadQueue, UploadQueueInitialized, UploadQueueStopped, UploadTask,
@@ -307,8 +306,6 @@ pub enum PersistIndexPartWithDeletedFlagError {
pub struct RemoteTimelineClient {
conf: &'static PageServerConf,
runtime: tokio::runtime::Handle,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
generation: Generation,
@@ -341,12 +338,6 @@ impl RemoteTimelineClient {
) -> RemoteTimelineClient {
RemoteTimelineClient {
conf,
runtime: if cfg!(test) {
// remote_timeline_client.rs tests rely on current-thread runtime
tokio::runtime::Handle::current()
} else {
BACKGROUND_RUNTIME.handle().clone()
},
tenant_shard_id,
timeline_id,
generation,
@@ -1281,7 +1272,6 @@ impl RemoteTimelineClient {
let tenant_shard_id = self.tenant_shard_id;
let timeline_id = self.timeline_id;
task_mgr::spawn(
&self.runtime,
TaskKind::RemoteUploadTask,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -1876,7 +1866,6 @@ mod tests {
fn build_client(&self, generation: Generation) -> Arc<RemoteTimelineClient> {
Arc::new(RemoteTimelineClient {
conf: self.harness.conf,
runtime: tokio::runtime::Handle::current(),
tenant_shard_id: self.harness.tenant_shard_id,
timeline_id: TIMELINE_ID,
generation,

View File

@@ -8,7 +8,7 @@ use std::{sync::Arc, time::SystemTime};
use crate::{
config::PageServerConf,
disk_usage_eviction_task::DiskUsageEvictionInfo,
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
task_mgr::{self, TaskKind},
virtual_file::MaybeFatalIo,
};
@@ -317,7 +317,6 @@ pub fn spawn_tasks(
tokio::sync::mpsc::channel::<CommandRequest<UploadCommand>>(16);
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::SecondaryDownloads,
None,
None,
@@ -338,7 +337,6 @@ pub fn spawn_tasks(
);
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::SecondaryUploads,
None,
None,

View File

@@ -15,6 +15,7 @@ use crate::{
tenant::{
config::SecondaryLocationConfig,
debug_assert_current_span_has_tenant_and_timeline_id,
ephemeral_file::is_ephemeral_file,
remote_timeline_client::{
index::LayerFileMetadata, is_temp_download_file, FAILED_DOWNLOAD_WARN_THRESHOLD,
FAILED_REMOTE_OP_RETRIES,
@@ -961,7 +962,10 @@ async fn init_timeline_state(
// Secondary mode doesn't use local metadata files, but they might have been left behind by an attached tenant.
warn!(path=?dentry.path(), "found legacy metadata file, these should have been removed in load_tenant_config");
continue;
} else if crate::is_temporary(&file_path) || is_temp_download_file(&file_path) {
} else if crate::is_temporary(&file_path)
|| is_temp_download_file(&file_path)
|| is_ephemeral_file(file_name)
{
// Temporary files are frequently left behind from restarting during downloads
tracing::info!("Cleaning up temporary file {file_path}");
if let Err(e) = tokio::fs::remove_file(&file_path)

View File

@@ -23,8 +23,12 @@ use tracing::*;
use utils::{bin_ser::BeSer, id::TimelineId, lsn::Lsn, vec_map::VecMap};
// avoid binding to Write (conflicts with std::io::Write)
// while being able to use std::fmt::Write's methods
use crate::metrics::TIMELINE_EPHEMERAL_BYTES;
use std::cmp::Ordering;
use std::fmt::Write as _;
use std::ops::Range;
use std::sync::atomic::Ordering as AtomicOrdering;
use std::sync::atomic::{AtomicU64, AtomicUsize};
use tokio::sync::{RwLock, RwLockWriteGuard};
use super::{
@@ -70,6 +74,8 @@ pub struct InMemoryLayerInner {
/// Each serialized Value is preceded by a 'u32' length field.
/// PerSeg::page_versions map stores offsets into this file.
file: EphemeralFile,
resource_units: GlobalResourceUnits,
}
impl std::fmt::Debug for InMemoryLayerInner {
@@ -78,6 +84,101 @@ impl std::fmt::Debug for InMemoryLayerInner {
}
}
/// State shared by all in-memory (ephemeral) layers. Updated infrequently during background ticks in Timeline,
/// to minimize contention.
///
/// This global state is used to implement behaviors that require a global view of the system, e.g.
/// rolling layers proactively to limit the total amount of dirty data.
struct GlobalResources {
// How many bytes are in all EphemeralFile objects
dirty_bytes: AtomicU64,
// How many layers are contributing to dirty_bytes
dirty_layers: AtomicUsize,
}
// Per-timeline RAII struct for its contribution to [`GlobalResources`]
struct GlobalResourceUnits {
// How many dirty bytes have I added to the global dirty_bytes: this guard object is responsible
// for decrementing the global counter by this many bytes when dropped.
dirty_bytes: u64,
}
impl GlobalResourceUnits {
// Hint for the layer append path to update us when the layer size differs from the last
// call to update_size by this much. If we don't reach this threshold, we'll still get
// updated when the Timeline "ticks" in the background.
const MAX_SIZE_DRIFT: u64 = 10 * 1024 * 1024;
fn new() -> Self {
GLOBAL_RESOURCES
.dirty_layers
.fetch_add(1, AtomicOrdering::Relaxed);
Self { dirty_bytes: 0 }
}
/// Do not call this frequently: all timelines will write to these same global atomics,
/// so this is a relatively expensive operation. Wait at least a few seconds between calls.
fn publish_size(&mut self, size: u64) {
let new_global_dirty_bytes = match size.cmp(&self.dirty_bytes) {
Ordering::Equal => {
return;
}
Ordering::Greater => {
let delta = size - self.dirty_bytes;
let old = GLOBAL_RESOURCES
.dirty_bytes
.fetch_add(delta, AtomicOrdering::Relaxed);
old + delta
}
Ordering::Less => {
let delta = self.dirty_bytes - size;
let old = GLOBAL_RESOURCES
.dirty_bytes
.fetch_sub(delta, AtomicOrdering::Relaxed);
old - delta
}
};
// This is a sloppy update: concurrent updates to the counter will race, and the exact
// value of the metric might not be the exact latest value of GLOBAL_RESOURCES::dirty_bytes.
// That's okay: as long as the metric contains some recent value, it doesn't have to always
// be literally the last update.
TIMELINE_EPHEMERAL_BYTES.set(new_global_dirty_bytes);
self.dirty_bytes = size;
}
// Call publish_size if the input size differs from last published size by more than
// the drift limit
fn maybe_publish_size(&mut self, size: u64) {
let publish = match size.cmp(&self.dirty_bytes) {
Ordering::Equal => false,
Ordering::Greater => size - self.dirty_bytes > Self::MAX_SIZE_DRIFT,
Ordering::Less => self.dirty_bytes - size > Self::MAX_SIZE_DRIFT,
};
if publish {
self.publish_size(size);
}
}
}
impl Drop for GlobalResourceUnits {
fn drop(&mut self) {
GLOBAL_RESOURCES
.dirty_layers
.fetch_sub(1, AtomicOrdering::Relaxed);
// Subtract our contribution to the global total dirty bytes
self.publish_size(0);
}
}
static GLOBAL_RESOURCES: GlobalResources = GlobalResources {
dirty_bytes: AtomicU64::new(0),
dirty_layers: AtomicUsize::new(0),
};
impl InMemoryLayer {
pub(crate) fn get_timeline_id(&self) -> TimelineId {
self.timeline_id
@@ -328,6 +429,7 @@ impl InMemoryLayer {
inner: RwLock::new(InMemoryLayerInner {
index: HashMap::new(),
file,
resource_units: GlobalResourceUnits::new(),
}),
})
}
@@ -378,9 +480,18 @@ impl InMemoryLayer {
warn!("Key {} at {} already exists", key, lsn);
}
let size = locked_inner.file.len();
locked_inner.resource_units.maybe_publish_size(size);
Ok(())
}
pub(crate) async fn tick(&self) {
let mut inner = self.inner.write().await;
let size = inner.file.len();
inner.resource_units.publish_size(size);
}
pub(crate) async fn put_tombstones(&self, _key_ranges: &[(Range<Key>, Lsn)]) -> Result<()> {
// TODO: Currently, we just leak the storage for any deleted keys
Ok(())

View File

@@ -1447,7 +1447,7 @@ impl LayerInner {
#[cfg(test)]
tokio::task::spawn(fut);
#[cfg(not(test))]
crate::task_mgr::BACKGROUND_RUNTIME.spawn(fut);
crate::task_mgr::THE_RUNTIME.spawn(fut);
}
/// Needed to use entered runtime in tests, but otherwise use BACKGROUND_RUNTIME.
@@ -1458,7 +1458,7 @@ impl LayerInner {
#[cfg(test)]
tokio::task::spawn_blocking(f);
#[cfg(not(test))]
crate::task_mgr::BACKGROUND_RUNTIME.spawn_blocking(f);
crate::task_mgr::THE_RUNTIME.spawn_blocking(f);
}
}

View File

@@ -8,7 +8,7 @@ use std::time::{Duration, Instant};
use crate::context::{DownloadBehavior, RequestContext};
use crate::metrics::TENANT_TASK_EVENTS;
use crate::task_mgr;
use crate::task_mgr::{TaskKind, BACKGROUND_RUNTIME};
use crate::task_mgr::TaskKind;
use crate::tenant::throttle::Stats;
use crate::tenant::timeline::CompactionError;
use crate::tenant::{Tenant, TenantState};
@@ -18,7 +18,7 @@ use utils::{backoff, completion};
static CONCURRENT_BACKGROUND_TASKS: once_cell::sync::Lazy<tokio::sync::Semaphore> =
once_cell::sync::Lazy::new(|| {
let total_threads = *task_mgr::BACKGROUND_RUNTIME_WORKER_THREADS;
let total_threads = *crate::task_mgr::THE_RUNTIME_WORKER_THREADS;
let permits = usize::max(
1,
// while a lot of the work is done on spawn_blocking, we still do
@@ -85,7 +85,6 @@ pub fn start_background_loops(
) {
let tenant_shard_id = tenant.tenant_shard_id;
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::Compaction,
Some(tenant_shard_id),
None,
@@ -109,7 +108,6 @@ pub fn start_background_loops(
},
);
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::GarbageCollector,
Some(tenant_shard_id),
None,

View File

@@ -1723,7 +1723,6 @@ impl Timeline {
initdb_optimization_count: 0,
};
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::LayerFlushTask,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -2086,7 +2085,6 @@ impl Timeline {
DownloadBehavior::Download,
);
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::InitialLogicalSizeCalculation,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -2264,7 +2262,6 @@ impl Timeline {
DownloadBehavior::Download,
);
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::OndemandLogicalSizeCalculation,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -3840,7 +3837,7 @@ impl Timeline {
};
let timer = self.metrics.garbage_collect_histo.start_timer();
fail_point!("before-timeline-gc");
pausable_failpoint!("before-timeline-gc");
// Is the timeline being deleted?
if self.is_stopping() {
@@ -4151,7 +4148,6 @@ impl Timeline {
let self_clone = Arc::clone(&self);
let task_id = task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::DownloadAllRemoteLayers,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -4469,6 +4465,9 @@ impl<'a> TimelineWriter<'a> {
let action = self.get_open_layer_action(last_record_lsn, 0);
if action == OpenLayerAction::Roll {
self.roll_layer(last_record_lsn).await?;
} else if let Some(writer_state) = &mut *self.write_guard {
// Periodic update of statistics
writer_state.open_layer.tick().await;
}
Ok(())

View File

@@ -263,6 +263,25 @@ impl Timeline {
}
}
// drop the readlock for now; in theory, gc could also remove the same layers as we are now
// compacting. FIXME: how to prepare such a test case?
// 0. tenant with minimal pitr
// 1. create 10 layers
// 2. await on pausable_failpoint after dropping the read guard
// 3. delete all data, vacuum, checkpoint
// 4. gc
//
// now gc deletes the layers and when we finally get to writing our results back in
// finish_compact_batch, LayerManager::finish_compact_l0 will panic when deleting a layer
// which does not exist in LayerMap::remove_historic_noflush or LayerFileManager::remove.
//
// is the easy solution just to make the deletions from compaction more lenient? currently
// gc holds a write lock, so it cannot have this problem right now. if gc were to be loosened to take the
// read lock only momentarily and write lock for applying, it would have a similar issue in
// trying to gc layers which have just been compacted.
stats.read_lock_held_compute_holes_micros = stats.read_lock_held_key_sort_micros.till_now();
drop_rlock(guard);
// Gather the files to compact in this iteration.
//
// Start with the oldest Level 0 delta file, and collect any other
@@ -347,6 +366,9 @@ impl Timeline {
stats.read_lock_held_key_sort_micros = stats.read_lock_held_prerequisites_micros.till_now();
let guard = self.layers.read().await;
let layers = guard.layer_map();
for &DeltaEntry { key: next_key, .. } in all_keys.iter() {
if let Some(prev_key) = prev {
// just first fast filter
@@ -370,8 +392,6 @@ impl Timeline {
}
prev = Some(next_key.next());
}
stats.read_lock_held_compute_holes_micros = stats.read_lock_held_key_sort_micros.till_now();
drop_rlock(guard);
stats.read_lock_drop_micros = stats.read_lock_held_compute_holes_micros.till_now();
let mut holes = heap.into_vec();
holes.sort_unstable_by_key(|hole| hole.key_range.start);

View File

@@ -443,7 +443,6 @@ impl DeleteTimelineFlow {
let timeline_id = timeline.timeline_id;
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
TaskKind::TimelineDeletionWorker,
Some(tenant_shard_id),
Some(timeline_id),

View File

@@ -28,7 +28,7 @@ use tracing::{debug, error, info, info_span, instrument, warn, Instrument};
use crate::{
context::{DownloadBehavior, RequestContext},
pgdatadir_mapping::CollectKeySpaceError,
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
task_mgr::{self, TaskKind},
tenant::{
tasks::BackgroundLoopKind, timeline::EvictionError, LogicalSizeCalculationCause, Tenant,
},
@@ -56,7 +56,6 @@ impl Timeline {
let self_clone = Arc::clone(self);
let background_tasks_can_start = background_tasks_can_start.cloned();
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::Eviction,
Some(self.tenant_shard_id),
Some(self.timeline_id),

View File

@@ -24,7 +24,7 @@ mod connection_manager;
mod walreceiver_connection;
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::{self, TaskKind, WALRECEIVER_RUNTIME};
use crate::task_mgr::{self, TaskKind};
use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id;
use crate::tenant::timeline::walreceiver::connection_manager::{
connection_manager_loop_step, ConnectionManagerState,
@@ -82,7 +82,6 @@ impl WalReceiver {
let loop_status = Arc::new(std::sync::RwLock::new(None));
let manager_status = Arc::clone(&loop_status);
task_mgr::spawn(
WALRECEIVER_RUNTIME.handle(),
TaskKind::WalReceiverManager,
Some(timeline.tenant_shard_id),
Some(timeline_id),
@@ -181,7 +180,7 @@ impl<E: Clone> TaskHandle<E> {
let (events_sender, events_receiver) = watch::channel(TaskStateUpdate::Started);
let cancellation_clone = cancellation.clone();
let join_handle = WALRECEIVER_RUNTIME.spawn(async move {
let join_handle = tokio::spawn(async move {
events_sender.send(TaskStateUpdate::Started).ok();
task(events_sender, cancellation_clone).await
// events_sender is dropped at some point during the .await above.

View File

@@ -11,7 +11,6 @@ use std::{
use anyhow::{anyhow, Context};
use bytes::BytesMut;
use chrono::{NaiveDateTime, Utc};
use fail::fail_point;
use futures::StreamExt;
use postgres::{error::SqlState, SimpleQueryMessage, SimpleQueryRow};
use postgres_ffi::WAL_SEGMENT_SIZE;
@@ -27,9 +26,7 @@ use super::TaskStateUpdate;
use crate::{
context::RequestContext,
metrics::{LIVE_CONNECTIONS_COUNT, WALRECEIVER_STARTED_CONNECTIONS, WAL_INGEST},
task_mgr,
task_mgr::TaskKind,
task_mgr::WALRECEIVER_RUNTIME,
task_mgr::{self, TaskKind},
tenant::{debug_assert_current_span_has_tenant_and_timeline_id, Timeline, WalReceiverInfo},
walingest::WalIngest,
walrecord::DecodedWALRecord,
@@ -163,7 +160,6 @@ pub(super) async fn handle_walreceiver_connection(
);
let connection_cancellation = cancellation.clone();
task_mgr::spawn(
WALRECEIVER_RUNTIME.handle(),
TaskKind::WalReceiverConnectionPoller,
Some(timeline.tenant_shard_id),
Some(timeline.timeline_id),
@@ -329,7 +325,17 @@ pub(super) async fn handle_walreceiver_connection(
filtered_records += 1;
}
fail_point!("walreceiver-after-ingest");
// don't simply use pausable_failpoint here because its spawn_blocking slows
// slows down the tests too much.
fail::fail_point!("walreceiver-after-ingest-blocking");
if let Err(()) = (|| {
fail::fail_point!("walreceiver-after-ingest-pause-activate", |_| {
Err(())
});
Ok(())
})() {
pausable_failpoint!("walreceiver-after-ingest-pause");
}
last_rec_lsn = lsn;

View File

@@ -312,7 +312,7 @@ pg_cluster_size(PG_FUNCTION_ARGS)
{
int64 size;
size = GetZenithCurrentClusterSize();
size = GetNeonCurrentClusterSize();
if (size == 0)
PG_RETURN_NULL();

View File

@@ -26,6 +26,8 @@ extern void pg_init_libpagestore(void);
extern void pg_init_walproposer(void);
extern uint64 BackpressureThrottlingTime(void);
extern void SetNeonCurrentClusterSize(uint64 size);
extern uint64 GetNeonCurrentClusterSize(void);
extern void replication_feedback_get_lsns(XLogRecPtr *writeLsn, XLogRecPtr *flushLsn, XLogRecPtr *applyLsn);
extern void PGDLLEXPORT WalProposerSync(int argc, char *argv[]);

View File

@@ -1831,7 +1831,7 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
reln->smgr_relpersistence == RELPERSISTENCE_PERMANENT &&
!IsAutoVacuumWorkerProcess())
{
uint64 current_size = GetZenithCurrentClusterSize();
uint64 current_size = GetNeonCurrentClusterSize();
if (current_size >= ((uint64) max_cluster_size) * 1024 * 1024)
ereport(ERROR,
@@ -1912,7 +1912,7 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
reln->smgr_relpersistence == RELPERSISTENCE_PERMANENT &&
!IsAutoVacuumWorkerProcess())
{
uint64 current_size = GetZenithCurrentClusterSize();
uint64 current_size = GetNeonCurrentClusterSize();
if (current_size >= ((uint64) max_cluster_size) * 1024 * 1024)
ereport(ERROR,

View File

@@ -287,6 +287,7 @@ typedef struct WalproposerShmemState
slock_t mutex;
term_t mineLastElectedTerm;
pg_atomic_uint64 backpressureThrottlingTime;
pg_atomic_uint64 currentClusterSize;
/* last feedback from each shard */
PageserverFeedback shard_ps_feedback[MAX_SHARDS];

View File

@@ -282,6 +282,7 @@ WalproposerShmemInit(void)
memset(walprop_shared, 0, WalproposerShmemSize());
SpinLockInit(&walprop_shared->mutex);
pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0);
pg_atomic_init_u64(&walprop_shared->currentClusterSize, 0);
}
LWLockRelease(AddinShmemInitLock);
@@ -1972,7 +1973,7 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk)
/* Only one main shard sends non-zero currentClusterSize */
if (sk->appendResponse.ps_feedback.currentClusterSize > 0)
SetZenithCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize);
SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize);
if (min_feedback.disk_consistent_lsn != standby_apply_lsn)
{
@@ -2094,6 +2095,18 @@ GetLogRepRestartLSN(WalProposer *wp)
return lrRestartLsn;
}
void SetNeonCurrentClusterSize(uint64 size)
{
pg_atomic_write_u64(&walprop_shared->currentClusterSize, size);
}
uint64 GetNeonCurrentClusterSize(void)
{
return pg_atomic_read_u64(&walprop_shared->currentClusterSize);
}
uint64 GetNeonCurrentClusterSize(void);
static const walproposer_api walprop_pg = {
.get_shmem_state = walprop_pg_get_shmem_state,
.start_streaming = walprop_pg_start_streaming,

View File

@@ -11,6 +11,10 @@ testing = []
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
aws-config.workspace = true
aws-sdk-iam.workspace = true
aws-sigv4.workspace = true
aws-types.workspace = true
base64.workspace = true
bstr.workspace = true
bytes = { workspace = true, features = ["serde"] }
@@ -27,6 +31,7 @@ hashlink.workspace = true
hex.workspace = true
hmac.workspace = true
hostname.workspace = true
http.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
@@ -63,6 +68,7 @@ sha2 = { workspace = true, features = ["asm"] }
smol_str.workspace = true
smallvec.workspace = true
socket2.workspace = true
subtle.workspace = true
sync_wrapper.workspace = true
task-local-extensions.workspace = true
thiserror.workspace = true
@@ -91,6 +97,7 @@ workspace_hack.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true
fallible-iterator.workspace = true
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -408,3 +408,228 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use postgres_protocol::{
authentication::sasl::{ChannelBinding, ScramSha256},
message::{backend::Message as PgMessage, frontend},
};
use provider::AuthSecret;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use crate::{
auth::{ComputeUserInfoMaybeEndpoint, IpPattern},
config::AuthenticationConfig,
console::{
self,
provider::{self, CachedAllowedIps, CachedRoleSecret},
CachedNodeInfo,
},
context::RequestMonitoring,
proxy::NeonOptions,
scram::ServerSecret,
stream::{PqStream, Stream},
};
use super::auth_quirks;
struct Auth {
ips: Vec<IpPattern>,
secret: AuthSecret,
}
impl console::Api for Auth {
async fn get_role_secret(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
}
async fn get_allowed_ips_and_secret(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
{
Ok((
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
))
}
async fn wake_compute(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
unimplemented!()
}
}
static CONFIG: &AuthenticationConfig = &AuthenticationConfig {
scram_protocol_timeout: std::time::Duration::from_secs(5),
};
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
loop {
r.read_buf(&mut *b).await.unwrap();
if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
break m;
}
}
}
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: Some("endpoint".into()),
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
let mut read = BytesMut::new();
// server should offer scram
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSasl(a) => {
let options: Vec<&str> = a.mechanisms().collect().unwrap();
assert_eq!(options, ["SCRAM-SHA-256"]);
}
_ => panic!("wrong message"),
}
// client sends client-first-message
let mut write = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
client.write_all(&write).await.unwrap();
// server response with server-first-message
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSaslContinue(a) => {
scram.update(a.data()).await.unwrap();
}
_ => panic!("wrong message"),
}
// client response with client-final-message
write.clear();
frontend::sasl_response(scram.message(), &mut write).unwrap();
client.write_all(&write).await.unwrap();
// server response with server-final-message
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSaslFinal(a) => {
scram.finish(a.data()).unwrap();
}
_ => panic!("wrong message"),
}
});
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, CONFIG)
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: Some("endpoint".into()),
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut read = BytesMut::new();
let mut write = BytesMut::new();
// server should offer cleartext
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationCleartextPassword => {}
_ => panic!("wrong message"),
}
// client responds with password
write.clear();
frontend::password_message(b"my-secret-password", &mut write).unwrap();
client.write_all(&write).await.unwrap();
});
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: None,
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut read = BytesMut::new();
// server should offer cleartext
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationCleartextPassword => {}
_ => panic!("wrong message"),
}
// client responds with password
let mut write = BytesMut::new();
frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
.unwrap();
client.write_all(&write).await.unwrap();
});
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
.await
.unwrap();
assert_eq!(creds.info.endpoint, "my-endpoint");
handle.await.unwrap();
}
}

View File

@@ -194,14 +194,7 @@ pub(crate) async fn validate_password_and_exchange(
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
let sasl_client = ScramSha256::new(password, ChannelBinding::unsupported());
let outcome = crate::scram::exchange(
&scram_secret,
sasl_client,
crate::config::TlsServerEndPoint::Undefined,
)
.await?;
let outcome = crate::scram::exchange(&scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -1,3 +1,10 @@
use aws_config::environment::EnvironmentVariableCredentialsProvider;
use aws_config::imds::credentials::ImdsCredentialsProvider;
use aws_config::meta::credentials::CredentialsProviderChain;
use aws_config::meta::region::RegionProviderChain;
use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
@@ -10,11 +17,14 @@ use proxy::config::ProjectInfoCacheOptions;
use proxy::console;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http;
use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT;
use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::redis::publisher::RedisPublisherClient;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -150,9 +160,24 @@ struct ProxyCliArgs {
/// disable ip check for http requests. If it is too time consuming, it could be turned off.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_ip_check_for_http: bool,
/// redis url for notifications.
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
#[clap(long)]
redis_notifications: Option<String>,
/// redis host for streaming connections (might be different from the notifications host)
#[clap(long)]
redis_host: Option<String>,
/// redis port for streaming connections (might be different from the notifications host)
#[clap(long)]
redis_port: Option<u16>,
/// redis cluster name, used in aws elasticache
#[clap(long)]
redis_cluster_name: Option<String>,
/// redis user_id, used in aws elasticache
#[clap(long)]
redis_user_id: Option<String>,
/// aws region to retrieve credentials
#[clap(long, default_value_t = String::new())]
aws_region: String,
/// cache for `project_info` (use `size=0` to disable)
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
@@ -216,6 +241,61 @@ async fn main() -> anyhow::Result<()> {
let config = build_config(&args)?;
info!("Authentication backend: {}", config.auth_backend);
info!("Using region: {}", config.aws_region);
let region_provider = RegionProviderChain::default_provider().or_else(&*config.aws_region); // Replace with your Redis region if needed
let provider_conf =
ProviderConfig::without_region().with_region(region_provider.region().await);
let aws_credentials_provider = {
// uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
CredentialsProviderChain::first_try("env", EnvironmentVariableCredentialsProvider::new())
// uses "AWS_PROFILE" / `aws sso login --profile <profile>`
.or_else(
"profile-sso",
ProfileFileCredentialsProvider::builder()
.configure(&provider_conf)
.build(),
)
// uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME"
// needed to access remote extensions bucket
.or_else(
"token",
WebIdentityTokenCredentialsProvider::builder()
.configure(&provider_conf)
.build(),
)
// uses imds v2
.or_else("imds", ImdsCredentialsProvider::builder().build())
};
let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new(
elasticache::AWSIRSAConfig::new(
config.aws_region.clone(),
args.redis_cluster_name,
args.redis_user_id,
),
aws_credentials_provider,
));
let redis_notifications_client =
match (args.redis_notifications, (args.redis_host, args.redis_port)) {
(Some(url), _) => {
info!("Starting redis notifications listener ({url})");
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url))
}
(None, (Some(host), Some(port))) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host,
port,
elasticache_credentials_provider.clone(),
),
),
(None, (None, None)) => {
warn!("Redis is disabled");
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
};
// Check that we can bind to address before further initialization
let http_address: SocketAddr = args.http.parse()?;
@@ -233,17 +313,22 @@ async fn main() -> anyhow::Result<()> {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &args.redis_notifications {
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
url,
// let redis_notifications_client = redis_notifications_client.map(|x| Box::leak(Box::new(x)));
let redis_publisher = match &redis_notifications_client {
Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
redis_publisher.clone(),
args.region.clone(),
&config.redis_rps_limit,
)?))),
None => None,
};
let cancellation_handler = Arc::new(CancellationHandler::new(
let cancellation_handler = Arc::new(CancellationHandler::<
Option<Arc<tokio::sync::Mutex<RedisPublisherClient>>>,
>::new(
cancel_map.clone(),
redis_publisher,
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT,
));
// client facing tasks. these will exit on error or on cancellation
@@ -290,17 +375,16 @@ async fn main() -> anyhow::Result<()> {
if let auth::BackendType::Console(api, _) = &config.auth_backend {
if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
if let Some(redis_notifications_client) = redis_notifications_client {
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(
url.to_owned(),
redis_notifications_client.clone(),
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
@@ -445,8 +529,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
endpoint_rps_limit,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
// TODO: add this argument
region: args.region.clone(),
aws_region: args.aws_region.clone(),
}));
Ok(config)

View File

@@ -1,4 +1,3 @@
use async_trait::async_trait;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
@@ -10,18 +9,26 @@ use tracing::info;
use uuid::Uuid;
use crate::{
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
redis::publisher::RedisPublisherClient,
error::ReportableError,
metrics::NUM_CANCELLATION_REQUESTS,
redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
},
};
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
pub type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
/// Enables serving `CancelRequest`s.
///
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler<P> {
map: CancelMap,
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
client: P,
/// This field used for the monitoring purposes.
/// Represents the source of the cancellation request.
from: &'static str,
}
#[derive(Debug, Error)]
@@ -44,49 +51,9 @@ impl ReportableError for CancelError {
}
}
impl CancellationHandler {
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
Self { map, redis_client }
}
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
let from = "from_client";
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
if let Some(redis_client) = &self.redis_client {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
info!("publishing cancellation key to Redis");
match redis_client.lock().await.try_publish(key, session_id).await {
Ok(()) => {
info!("cancellation key successfuly published to Redis");
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
}
return Ok(());
};
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
impl<P: CancellationPublisher> CancellationHandler<P> {
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub fn get_session(self: Arc<Self>) -> Session {
pub fn get_session(self: Arc<Self>) -> Session<P> {
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
@@ -112,9 +79,39 @@ impl CancellationHandler {
cancellation_handler: self,
}
}
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
NUM_CANCELLATION_REQUESTS
.with_label_values(&[self.from, "not_found"])
.inc();
match self.client.try_publish(key, session_id).await {
Ok(()) => {} // do nothing
Err(e) => {
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
return Ok(());
};
NUM_CANCELLATION_REQUESTS
.with_label_values(&[self.from, "found"])
.inc();
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
fn contains(&self, session: &Session<P>) -> bool {
self.map.contains_key(&session.key)
}
@@ -124,31 +121,19 @@ impl CancellationHandler {
}
}
#[async_trait]
pub trait NotificationsCancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
impl CancellationHandler<()> {
pub fn new(map: CancelMap, from: &'static str) -> Self {
Self {
map,
client: (),
from,
}
}
}
#[async_trait]
impl NotificationsCancellationHandler for CancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
let from = "from_redis";
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
match cancel_closure {
Some(cancel_closure) => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
cancel_closure.try_cancel_query().await
}
None => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
tracing::warn!("query cancellation key not found: {key}");
Ok(())
}
}
impl<P: CancellationPublisherMut> CancellationHandler<Option<Arc<Mutex<P>>>> {
pub fn new(map: CancelMap, client: Option<Arc<Mutex<P>>>, from: &'static str) -> Self {
Self { map, client, from }
}
}
@@ -178,14 +163,14 @@ impl CancelClosure {
}
/// Helper for registering query cancellation tokens.
pub struct Session {
pub struct Session<P> {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandler<P>>,
}
impl Session {
impl<P> Session<P> {
/// Store the cancel token for the given session.
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
@@ -198,7 +183,7 @@ impl Session {
}
}
impl Drop for Session {
impl<P> Drop for Session<P> {
fn drop(&mut self) {
self.cancellation_handler.map.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
@@ -207,14 +192,16 @@ impl Drop for Session {
#[cfg(test)]
mod tests {
use crate::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS;
use super::*;
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancellation_handler = Arc::new(CancellationHandler {
map: CancelMap::default(),
redis_client: None,
});
let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
CancelMap::default(),
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS,
));
let session = cancellation_handler.clone().get_session();
assert!(cancellation_handler.contains(&session));
@@ -224,4 +211,19 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn cancel_session_noop_regression() {
let handler = CancellationHandler::<()>::new(Default::default(), "local");
handler
.cancel_session(
CancelKeyData {
backend_pid: 0,
cancel_key: 0,
},
Uuid::new_v4(),
)
.await
.unwrap();
}
}

View File

@@ -82,14 +82,13 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone)]
#[repr(transparent)]
#[derive(Clone, Default)]
pub struct ConnCfg(Box<tokio_postgres::Config>);
/// Creation and initialization routines.
impl ConnCfg {
pub fn new() -> Self {
Self(Default::default())
Self::default()
}
/// Reuse password or auth keys from the other config.
@@ -165,12 +164,6 @@ impl std::ops::DerefMut for ConnCfg {
}
}
impl Default for ConnCfg {
fn default() -> Self {
Self::new()
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {

View File

@@ -28,6 +28,7 @@ pub struct ProxyConfig {
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
pub aws_region: String,
}
#[derive(Debug)]

View File

@@ -6,7 +6,7 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
/// Various cache-related types.
pub mod caches {

View File

@@ -14,7 +14,6 @@ use crate::{
context::RequestMonitoring,
scram, EndpointCacheKey, ProjectId,
};
use async_trait::async_trait;
use dashmap::DashMap;
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
@@ -326,8 +325,7 @@ pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPatt
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
#[async_trait]
pub trait Api {
pub(crate) trait Api {
/// Get the client's auth secret for authentication.
/// Returns option because user not found situation is special.
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
@@ -363,7 +361,6 @@ pub enum ConsoleBackend {
Test(Box<dyn crate::auth::backend::TestBackend>),
}
#[async_trait]
impl Api for ConsoleBackend {
async fn get_role_secret(
&self,

View File

@@ -8,7 +8,6 @@ use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
use crate::context::RequestMonitoring;
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use async_trait::async_trait;
use futures::TryFutureExt;
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
@@ -144,7 +143,6 @@ async fn get_execute_postgres_query(
Ok(Some(entry))
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_role_secret(

View File

@@ -14,7 +14,6 @@ use crate::{
context::RequestMonitoring,
metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
};
use async_trait::async_trait;
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::time::Instant;
@@ -168,7 +167,6 @@ impl Api {
}
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_role_secret(

View File

@@ -2,14 +2,21 @@ use anyhow::{anyhow, bail};
use hyper::{Body, Request, Response, StatusCode};
use std::{convert::Infallible, net::TcpListener};
use tracing::info;
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
use utils::http::{
endpoint::{self, prometheus_metrics_handler, request_span},
error::ApiError,
json::json_response,
RouterBuilder, RouterService,
};
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, "")
}
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
endpoint::make_router().get("/v1/status", status_handler)
endpoint::make_router()
.get("/metrics", |r| request_span(r, prometheus_metrics_handler))
.get("/v1/status", status_handler)
}
pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<Infallible> {

View File

@@ -161,6 +161,9 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client";
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis";
pub enum Waiting {
Cplane,
Client,

View File

@@ -10,7 +10,7 @@ pub mod wake_compute;
use crate::{
auth,
cancellation::{self, CancellationHandler},
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
@@ -62,7 +62,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -233,12 +233,12 @@ impl ReportableError for ClientRequestError {
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: IntCounterPairGuard,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!("handling interactive connection from client");
let proto = ctx.protocol;
@@ -338,9 +338,9 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
async fn prepare_client_connection<P>(
node: &compute::PostgresConnection,
session: &cancellation::Session,
session: &cancellation::Session<P>,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<(), std::io::Error> {
// Register compute's query cancellation token and produce a new, unique one.

View File

@@ -55,17 +55,17 @@ pub async fn proxy_pass(
Ok(())
}
pub struct ProxyPassthrough<S> {
pub struct ProxyPassthrough<P, S> {
pub client: Stream<S>,
pub compute: PostgresConnection,
pub aux: MetricsAuxInfo,
pub req: IntCounterPairGuard,
pub conn: IntCounterPairGuard,
pub cancel: cancellation::Session,
pub cancel: cancellation::Session<P>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub async fn proxy_pass(self) -> anyhow::Result<()> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
self.compute.cancel_closure.try_cancel_query().await?;

View File

@@ -1,2 +1,4 @@
pub mod cancellation_publisher;
pub mod connection_with_credentials_provider;
pub mod elasticache;
pub mod notifications;
pub mod publisher;

View File

@@ -0,0 +1,161 @@
use std::sync::Arc;
use pq_proto::CancelKeyData;
use redis::AsyncCommands;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::{
connection_with_credentials_provider::ConnectionWithCredentialsProvider,
notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME},
};
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()>;
}
pub trait CancellationPublisher: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()>;
}
impl CancellationPublisher for () {
async fn try_publish(
&self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
) -> anyhow::Result<()> {
Ok(())
}
}
impl<P: CancellationPublisher> CancellationPublisherMut for P {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id).await
}
}
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if let Some(p) = self {
p.try_publish(cancel_key_data, session_id).await
} else {
Ok(())
}
}
}
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
self.lock()
.await
.try_publish(cancel_key_data, session_id)
.await
}
}
pub struct RedisPublisherClient {
client: ConnectionWithCredentialsProvider,
region_id: String,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
pub fn new(
client: ConnectionWithCredentialsProvider,
region_id: String,
info: &'static [RateBucketInfo],
) -> anyhow::Result<Self> {
Ok(Self {
client,
region_id,
limiter: RedisRateLimiter::new(info),
})
}
async fn publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
region_id: Some(self.region_id.clone()),
cancel_key_data,
session_id,
}))?;
self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
Ok(())
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.connect().await {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e);
}
}
Ok(())
}
async fn try_publish_internal(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping cancellation message");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.publish(cancel_key_data, session_id).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to publish a message: {e}");
}
}
tracing::info!("Publisher is disconnected. Reconnectiong...");
self.try_connect().await?;
self.publish(cancel_key_data, session_id).await
}
}
impl CancellationPublisherMut for RedisPublisherClient {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
tracing::info!("publishing cancellation key to Redis");
match self.try_publish_internal(cancel_key_data, session_id).await {
Ok(()) => {
tracing::info!("cancellation key successfuly published to Redis");
Ok(())
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
Err(e)
}
}
}
}

View File

@@ -0,0 +1,225 @@
use std::{sync::Arc, time::Duration};
use futures::FutureExt;
use redis::{
aio::{ConnectionLike, MultiplexedConnection},
ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult,
};
use tokio::task::JoinHandle;
use tracing::{error, info};
use super::elasticache::CredentialsProvider;
enum Credentials {
Static(ConnectionInfo),
Dynamic(Arc<CredentialsProvider>, redis::ConnectionAddr),
}
impl Clone for Credentials {
fn clone(&self) -> Self {
match self {
Credentials::Static(info) => Credentials::Static(info.clone()),
Credentials::Dynamic(provider, addr) => {
Credentials::Dynamic(Arc::clone(provider), addr.clone())
}
}
}
}
/// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token.
/// Provides PubSub connection without credentials refresh.
pub struct ConnectionWithCredentialsProvider {
credentials: Credentials,
con: Option<MultiplexedConnection>,
refresh_token_task: Option<JoinHandle<()>>,
mutex: tokio::sync::Mutex<()>,
}
impl Clone for ConnectionWithCredentialsProvider {
fn clone(&self) -> Self {
Self {
credentials: self.credentials.clone(),
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
}
}
}
impl ConnectionWithCredentialsProvider {
pub fn new_with_credentials_provider(
host: String,
port: u16,
credentials_provider: Arc<CredentialsProvider>,
) -> Self {
Self {
credentials: Credentials::Dynamic(
credentials_provider,
redis::ConnectionAddr::TcpTls {
host,
port,
insecure: false,
tls_params: None,
},
),
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
}
}
pub fn new_with_static_credentials<T: IntoConnectionInfo>(params: T) -> Self {
Self {
credentials: Credentials::Static(params.into_connection_info().unwrap()),
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
}
}
pub async fn connect(&mut self) -> anyhow::Result<()> {
let _guard = self.mutex.lock().await;
if let Some(con) = self.con.as_mut() {
match redis::cmd("PING").query_async(con).await {
Ok(()) => {
return Ok(());
}
Err(e) => {
error!("Error during PING: {e:?}");
}
}
} else {
info!("Connection is not established");
}
info!("Establishing a new connection...");
self.con = None;
if let Some(f) = self.refresh_token_task.take() {
f.abort()
}
let con = self
.get_client()
.await?
.get_multiplexed_tokio_connection()
.await?;
if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
let credentials_provider = credentials_provider.clone();
let con2 = con.clone();
let f = tokio::spawn(async move {
let _ = Self::keep_connection(con2, credentials_provider).await;
});
self.refresh_token_task = Some(f);
}
self.con = Some(con);
Ok(())
}
async fn get_connection_info(&self) -> anyhow::Result<ConnectionInfo> {
match &self.credentials {
Credentials::Static(info) => Ok(info.clone()),
Credentials::Dynamic(provider, addr) => {
let (username, password) = provider.provide_credentials().await?;
Ok(ConnectionInfo {
addr: addr.clone(),
redis: RedisConnectionInfo {
db: 0,
username: Some(username),
password: Some(password.clone()),
},
})
}
}
}
async fn get_client(&self) -> anyhow::Result<redis::Client> {
let client = redis::Client::open(self.get_connection_info().await?)?;
Ok(client)
}
// PubSub does not support credentials refresh.
// Requires manual reconnection every 12h.
pub async fn get_async_pubsub(&self) -> anyhow::Result<redis::aio::PubSub> {
Ok(self.get_client().await?.get_async_pubsub().await?)
}
// The connection lives for 12h.
// It can be prolonged with sending `AUTH` commands with the refreshed token.
// https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits
async fn keep_connection(
mut con: MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
) -> anyhow::Result<()> {
loop {
// The connection lives for 12h, for the sanity check we refresh it every hour.
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
match Self::refresh_token(&mut con, credentials_provider.clone()).await {
Ok(()) => {
info!("Token refreshed");
}
Err(e) => {
error!("Error during token refresh: {e:?}");
}
}
}
}
async fn refresh_token(
con: &mut MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
) -> anyhow::Result<()> {
let (user, password) = credentials_provider.provide_credentials().await?;
redis::cmd("AUTH")
.arg(user)
.arg(password)
.query_async(con)
.await?;
Ok(())
}
/// Sends an already encoded (packed) command into the TCP socket and
/// reads the single response from it.
pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult<redis::Value> {
// Clone connection to avoid having to lock the ArcSwap in write mode
let con = self.con.as_mut().ok_or(redis::RedisError::from((
redis::ErrorKind::IoError,
"Connection not established",
)))?;
con.send_packed_command(cmd).await
}
/// Sends multiple already encoded (packed) command into the TCP socket
/// and reads `count` responses from it. This is used to implement
/// pipelining.
pub async fn send_packed_commands(
&mut self,
cmd: &redis::Pipeline,
offset: usize,
count: usize,
) -> RedisResult<Vec<redis::Value>> {
// Clone shared connection future to avoid having to lock the ArcSwap in write mode
let con = self.con.as_mut().ok_or(redis::RedisError::from((
redis::ErrorKind::IoError,
"Connection not established",
)))?;
con.send_packed_commands(cmd, offset, count).await
}
}
impl ConnectionLike for ConnectionWithCredentialsProvider {
fn req_packed_command<'a>(
&'a mut self,
cmd: &'a redis::Cmd,
) -> redis::RedisFuture<'a, redis::Value> {
(async move { self.send_packed_command(cmd).await }).boxed()
}
fn req_packed_commands<'a>(
&'a mut self,
cmd: &'a redis::Pipeline,
offset: usize,
count: usize,
) -> redis::RedisFuture<'a, Vec<redis::Value>> {
(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
}
fn get_db(&self) -> i64 {
0
}
}

View File

@@ -0,0 +1,110 @@
use std::time::{Duration, SystemTime};
use aws_config::meta::credentials::CredentialsProviderChain;
use aws_sdk_iam::config::ProvideCredentials;
use aws_sigv4::http_request::{
self, SignableBody, SignableRequest, SignatureLocation, SigningSettings,
};
use tracing::info;
#[derive(Debug)]
pub struct AWSIRSAConfig {
region: String,
service_name: String,
cluster_name: String,
user_id: String,
token_ttl: Duration,
action: String,
}
impl AWSIRSAConfig {
pub fn new(region: String, cluster_name: Option<String>, user_id: Option<String>) -> Self {
AWSIRSAConfig {
region,
service_name: "elasticache".to_string(),
cluster_name: cluster_name.unwrap_or_default(),
user_id: user_id.unwrap_or_default(),
// "The IAM authentication token is valid for 15 minutes"
// https://docs.aws.amazon.com/memorydb/latest/devguide/auth-iam.html#auth-iam-limits
token_ttl: Duration::from_secs(15 * 60),
action: "connect".to_string(),
}
}
}
/// Credentials provider for AWS elasticache authentication.
///
/// Official documentation:
/// <https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html>
///
/// Useful resources:
/// <https://aws.amazon.com/blogs/database/simplify-managing-access-to-amazon-elasticache-for-redis-clusters-with-iam/>
pub struct CredentialsProvider {
config: AWSIRSAConfig,
credentials_provider: CredentialsProviderChain,
}
impl CredentialsProvider {
pub fn new(config: AWSIRSAConfig, credentials_provider: CredentialsProviderChain) -> Self {
CredentialsProvider {
config,
credentials_provider,
}
}
pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> {
let aws_credentials = self
.credentials_provider
.provide_credentials()
.await?
.into();
info!("AWS credentials successfully obtained");
info!("Connecting to Redis with configuration: {:?}", self.config);
let mut settings = SigningSettings::default();
settings.signature_location = SignatureLocation::QueryParams;
settings.expires_in = Some(self.config.token_ttl);
let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
.identity(&aws_credentials)
.region(&self.config.region)
.name(&self.config.service_name)
.time(SystemTime::now())
.settings(settings)
.build()?
.into();
let auth_params = [
("Action", &self.config.action),
("User", &self.config.user_id),
];
let auth_params = url::form_urlencoded::Serializer::new(String::new())
.extend_pairs(auth_params)
.finish();
let auth_uri = http::Uri::builder()
.scheme("http")
.authority(self.config.cluster_name.as_bytes())
.path_and_query(format!("/?{auth_params}"))
.build()?;
info!("{}", auth_uri);
// Convert the HTTP request into a signable request
let signable_request = SignableRequest::new(
"GET",
auth_uri.to_string(),
std::iter::empty(),
SignableBody::Bytes(&[]),
)?;
// Sign and then apply the signature to the request
let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts();
let mut signable_request = http::Request::builder()
.method("GET")
.uri(auth_uri)
.body(())?;
si.apply_to_request_http1x(&mut signable_request);
Ok((
self.config.user_id.clone(),
signable_request
.uri()
.to_string()
.replacen("http://", "", 1),
))
}
}

View File

@@ -6,11 +6,12 @@ use redis::aio::PubSub;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::{
cache::project_info::ProjectInfoCache,
cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
cancellation::{CancelMap, CancellationHandler},
intern::{ProjectIdInt, RoleNameInt},
metrics::REDIS_BROKEN_MESSAGES,
metrics::{NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, REDIS_BROKEN_MESSAGES},
};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
@@ -18,23 +19,13 @@ pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
struct RedisConsumerClient {
client: redis::Client,
}
impl RedisConsumerClient {
pub fn new(url: &str) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self { client })
}
async fn try_connect(&self) -> anyhow::Result<PubSub> {
let mut conn = self.client.get_async_connection().await?.into_pubsub();
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
conn.subscribe(PROXY_CHANNEL_NAME).await?;
Ok(conn)
}
async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
let mut conn = client.get_async_pubsub().await?;
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
conn.subscribe(PROXY_CHANNEL_NAME).await?;
Ok(conn)
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
@@ -80,21 +71,18 @@ where
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}
struct MessageHandler<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> {
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
cancellation_handler: Arc<H>,
cancellation_handler: Arc<CancellationHandler<()>>,
region_id: String,
}
impl<
C: ProjectInfoCache + Send + Sync + 'static,
H: NotificationsCancellationHandler + Send + Sync + 'static,
> MessageHandler<C, H>
{
pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub fn new(
cache: Arc<C>,
cancellation_handler: Arc<CancellationHandler<()>>,
region_id: String,
) -> Self {
Self {
cache,
cancellation_handler,
@@ -139,7 +127,7 @@ impl<
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
match self
.cancellation_handler
.cancel_session_no_publish(cancel_session.cancel_key_data)
.cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil())
.await
{
Ok(()) => {}
@@ -182,7 +170,7 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main<C>(
url: String,
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
cancel_map: CancelMap,
region_id: String,
@@ -193,13 +181,15 @@ where
cache.enable_ttl();
let handler = MessageHandler::new(
cache,
Arc::new(CancellationHandler::new(cancel_map, None)),
Arc::new(CancellationHandler::<()>::new(
cancel_map,
NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS,
)),
region_id,
);
loop {
let redis = RedisConsumerClient::new(&url)?;
let conn = match redis.try_connect().await {
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.disable_ttl();
conn
@@ -212,7 +202,7 @@ where
continue;
}
};
let mut stream = conn.into_on_message();
let mut stream = conn.on_message();
while let Some(msg) = stream.next().await {
match handler.handle_message(msg).await {
Ok(()) => {}

View File

@@ -1,80 +0,0 @@
use pq_proto::CancelKeyData;
use redis::AsyncCommands;
use uuid::Uuid;
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
pub struct RedisPublisherClient {
client: redis::Client,
publisher: Option<redis::aio::Connection>,
region_id: String,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
pub fn new(
url: &str,
region_id: String,
info: &'static [RateBucketInfo],
) -> anyhow::Result<Self> {
let client = redis::Client::open(url)?;
Ok(Self {
client,
publisher: None,
region_id,
limiter: RedisRateLimiter::new(info),
})
}
pub async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping cancellation message");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.publish(cancel_key_data, session_id).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to publish a message: {e}");
self.publisher = None;
}
}
tracing::info!("Publisher is disconnected. Reconnectiong...");
self.try_connect().await?;
self.publish(cancel_key_data, session_id).await
}
async fn publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
) -> anyhow::Result<()> {
let conn = self
.publisher
.as_mut()
.ok_or_else(|| anyhow::anyhow!("not connected"))?;
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
region_id: Some(self.region_id.clone()),
cancel_key_data,
session_id,
}))?;
conn.publish(PROXY_CHANNEL_NAME, payload).await?;
Ok(())
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
match self.client.get_async_connection().await {
Ok(conn) => {
self.publisher = Some(conn);
}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e.into());
}
}
Ok(())
}
}

View File

@@ -33,6 +33,9 @@ pub enum Error {
#[error("Internal error: missing digest")]
MissingBinding,
#[error("could not decode salt: {0}")]
Base64(#[from] base64::DecodeError),
#[error(transparent)]
Io(#[from] io::Error),
}
@@ -55,6 +58,7 @@ impl ReportableError for Error {
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
Error::MissingBinding => crate::error::ErrorKind::Service,
Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}

View File

@@ -56,8 +56,6 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
#[cfg(test)]
mod tests {
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
use crate::sasl::{Mechanism, Step};
use super::{Exchange, ServerSecret};
@@ -115,16 +113,9 @@ mod tests {
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let sasl_client =
ScramSha256::new(client_password.as_bytes(), ChannelBinding::unsupported());
let outcome = super::exchange(
&scram_secret,
sasl_client,
crate::config::TlsServerEndPoint::Undefined,
)
.await
.unwrap();
let outcome = super::exchange(&scram_secret, client_password.as_bytes())
.await
.unwrap();
match outcome {
crate::sasl::Outcome::Success(_) => {}

View File

@@ -2,13 +2,16 @@
use std::convert::Infallible;
use postgres_protocol::authentication::sasl::ScramSha256;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tokio::task::yield_now;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::ScramKey;
use crate::config;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
@@ -71,40 +74,62 @@ impl<'a> Exchange<'a> {
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
let hmac = Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let mut prev = hmac
.clone()
.chain_update(salt)
.chain_update(1u32.to_be_bytes())
.finalize()
.into_bytes();
let mut hi = prev;
for i in 1..iterations {
prev = hmac.clone().chain_update(prev).finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(prev) {
*hi ^= prev;
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
yield_now().await
}
}
hi.into()
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey {
let salted_password = pbkdf2(password, salt, iterations).await;
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes")
.chain_update(name)
.finalize();
<[u8; 32]>::from(key.into_bytes())
};
make_key(b"Client Key").into()
}
pub async fn exchange(
secret: &ServerSecret,
mut client: ScramSha256,
tls_server_end_point: config::TlsServerEndPoint,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
use sasl::Step::*;
let salt = base64::decode(&secret.salt_base64)?;
let client_key = derive_client_key(password, &salt, secret.iterations).await;
let init = SaslInitial {
nonce: rand::random,
};
let client_first = std::str::from_utf8(client.message())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let sent = match init.transition(secret, &tls_server_end_point, client_first)? {
Continue(sent, server_first) => {
client.update(server_first.as_bytes()).await?;
sent
}
Success(x, _) => match x {},
Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
};
let client_final = std::str::from_utf8(client.message())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let keys = match sent.transition(secret, &tls_server_end_point, client_final)? {
Success(keys, server_final) => {
client.finish(server_final.as_bytes())?;
keys
}
Continue(x, _) => match x {},
Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
};
Ok(sasl::Outcome::Success(keys))
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
} else {
Ok(sasl::Outcome::Success(client_key))
}
}
impl SaslInitial {
@@ -185,7 +210,7 @@ impl SaslSentInner {
.derive_client_key(&client_final_message.proof);
// Auth fails either if keys don't match or it's pre-determined to fail.
if client_key.sha256() != secret.stored_key || secret.doomed {
if secret.is_password_invalid(&client_key).into() {
return Ok(sasl::Step::Failure("password doesn't match"));
}

View File

@@ -1,17 +1,31 @@
//! Tools for client/server/stored key management.
use subtle::ConstantTimeEq;
/// Faithfully taken from PostgreSQL.
pub const SCRAM_KEY_LEN: usize = 32;
/// One of the keys derived from the user's password.
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Clone, Default, PartialEq, Eq, Debug)]
#[derive(Clone, Default, Eq, Debug)]
#[repr(transparent)]
pub struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
impl ConstantTimeEq for ScramKey {
fn ct_eq(&self, other: &Self) -> subtle::Choice {
self.bytes.ct_eq(&other.bytes)
}
}
impl ScramKey {
pub fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()

View File

@@ -206,6 +206,28 @@ mod tests {
}
}
#[test]
fn parse_client_first_message_with_invalid_gs2_authz() {
assert!(ClientFirstMessage::parse("n,authzid,n=user,r=nonce").is_none())
}
#[test]
fn parse_client_first_message_with_extra_params() {
let msg = ClientFirstMessage::parse("n,,n=user,r=nonce,a=foo,b=bar,c=baz").unwrap();
assert_eq!(msg.bare, "n=user,r=nonce,a=foo,b=bar,c=baz");
assert_eq!(msg.username, "user");
assert_eq!(msg.nonce, "nonce");
assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
}
#[test]
fn parse_client_first_message_with_extra_params_invalid() {
// must be of the form `<ascii letter>=<...>`
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,abc=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,1=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,a").is_none());
}
#[test]
fn parse_client_final_message() {
let input = [

View File

@@ -1,5 +1,7 @@
//! Tools for SCRAM server secret management.
use subtle::{Choice, ConstantTimeEq};
use super::base64_decode_array;
use super::key::ScramKey;
@@ -40,6 +42,11 @@ impl ServerSecret {
Some(secret)
}
pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice {
// constant time to not leak partial key match
client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8)
}
/// To avoid revealing information to an attacker, we use a
/// mocked server secret even if the user doesn't exist.
/// See `auth-scram.c : mock_scram_secret` for details.

View File

@@ -21,11 +21,12 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;
use tracing::instrument::Instrumented;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use hyper::{
server::conn::{AddrIncoming, AddrStream},
Body, Method, Request, Response,
@@ -47,7 +48,7 @@ pub async fn task_main(
ws_listener: TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
@@ -237,7 +238,7 @@ async fn request_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
// used to cancel in-flight HTTP requests. not used to cancel websockets

View File

@@ -1,5 +1,5 @@
use crate::{
cancellation::CancellationHandler,
cancellation::CancellationHandlerMain,
config::ProxyConfig,
context::RequestMonitoring,
error::{io_error, ReportableError},
@@ -134,7 +134,7 @@ pub async fn serve_websocket(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
websocket: HyperWebsocket,
cancellation_handler: Arc<CancellationHandler>,
cancellation_handler: Arc<CancellationHandlerMain>,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {

View File

@@ -1,5 +1,5 @@
[toolchain]
channel = "1.76.0"
channel = "1.77.0"
profile = "default"
# The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy.
# https://rust-lang.github.io/rustup/concepts/profiles.html

View File

@@ -225,6 +225,7 @@ async fn write_segment(
assert!(from <= to);
assert!(to <= wal_seg_size);
#[allow(clippy::suspicious_open_options)]
let mut file = OpenOptions::new()
.create(true)
.write(true)

View File

@@ -20,7 +20,7 @@ use std::io::Write as _;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{info_span, Instrument};
use utils::http::endpoint::{request_span, ChannelWriter};
use utils::http::endpoint::{prometheus_metrics_handler, request_span, ChannelWriter};
use crate::debug_dump::TimelineDigestRequest;
use crate::receive_wal::WalReceiverState;
@@ -515,6 +515,7 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
router
.data(Arc::new(conf))
.data(auth)
.get("/metrics", |r| request_span(r, prometheus_metrics_handler))
.get("/v1/status", |r| request_span(r, status_handler))
.put("/v1/failpoints", |r| {
request_span(r, move |r| async {

View File

@@ -221,6 +221,7 @@ impl PhysicalStorage {
// half initialized segment, first bake it under tmp filename and
// then rename.
let tmp_path = self.timeline_dir.join("waltmp");
#[allow(clippy::suspicious_open_options)]
let mut file = OpenOptions::new()
.create(true)
.write(true)

View File

@@ -244,6 +244,7 @@ impl SimulationApi {
mutex: 0,
mineLastElectedTerm: 0,
backpressureThrottlingTime: pg_atomic_uint64 { value: 0 },
currentClusterSize: pg_atomic_uint64 { value: 0 },
shard_ps_feedback: [empty_feedback; 128],
num_shards: 0,
min_ps_feedback: empty_feedback,

View File

@@ -1155,13 +1155,17 @@ class NeonEnv:
After this method returns, there should be no child processes running.
"""
self.endpoints.stop_all()
# Stop storage controller before pageservers: we don't want it to spuriously
# detect a pageserver "failure" during test teardown
self.storage_controller.stop(immediate=immediate)
for sk in self.safekeepers:
sk.stop(immediate=immediate)
for pageserver in self.pageservers:
if ps_assert_metric_no_errors:
pageserver.assert_no_metric_errors()
pageserver.stop(immediate=immediate)
self.storage_controller.stop(immediate=immediate)
self.broker.stop(immediate=immediate)
@property
@@ -1617,8 +1621,6 @@ class NeonCli(AbstractNeonCli):
if timeline_id is not None:
cmd.extend(["--timeline-id", str(timeline_id)])
log.info(f"create_timeline: {cmd}")
res = self.raw_cli(cmd)
res.check_returncode()
@@ -1650,8 +1652,6 @@ class NeonCli(AbstractNeonCli):
if ancestor_start_lsn is not None:
cmd.extend(["--ancestor-start-lsn", str(ancestor_start_lsn)])
log.info(f"create_branch: {cmd}")
res = self.raw_cli(cmd)
res.check_returncode()
@@ -2154,6 +2154,18 @@ class NeonStorageController(MetricsGetter):
shards: list[dict[str, Any]] = body["shards"]
return shards
def tenant_describe(self, tenant_id: TenantId):
"""
:return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr: str, "listen_http_port: int}
"""
response = self.request(
"GET",
f"{self.env.storage_controller_api}/control/v1/tenant/{tenant_id}",
headers=self.headers(TokenScope.ADMIN),
)
response.raise_for_status()
return response.json()
def tenant_shard_split(
self, tenant_id: TenantId, shard_count: int, shard_stripe_size: Optional[int] = None
) -> list[TenantShardId]:

View File

@@ -96,6 +96,8 @@ DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS = [
".*Call to node.*management API.*failed.*ReceiveBody.*",
# Many tests will start up with a node offline
".*startup_reconcile: Could not scan node.*",
# Tests run in dev mode
".*Starting in dev mode.*",
]

View File

@@ -62,9 +62,7 @@ def wait_for_upload(
)
time.sleep(1)
raise Exception(
"timed out while waiting for remote_consistent_lsn to reach {}, was {}".format(
lsn, current_lsn
)
f"timed out while waiting for {tenant}/{timeline} remote_consistent_lsn to reach {lsn}, was {current_lsn}"
)

View File

@@ -116,7 +116,7 @@ def test_backpressure_received_lsn_lag(neon_env_builder: NeonEnvBuilder):
# Configure failpoint to slow down walreceiver ingest
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
pscur.execute("failpoints walreceiver-after-ingest=sleep(20)")
pscur.execute("failpoints walreceiver-after-ingest-blocking=sleep(20)")
# FIXME
# Wait for the check thread to start

View File

@@ -267,9 +267,10 @@ def test_forward_compatibility(
def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, repo_dir: Path):
ep = env.endpoints.create_start("main")
connstr = ep.connstr()
pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version)
connstr = ep.connstr()
pg_bin.run_capture(
["pg_dumpall", f"--dbname={connstr}", f"--file={test_output_dir / 'dump.sql'}"]
)
@@ -286,6 +287,9 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r
timeline_id = env.initial_timeline
pg_version = env.pg_version
# Stop endpoint while we recreate timeline
ep.stop()
try:
pageserver_http.timeline_preserve_initdb_archive(tenant_id, timeline_id)
except PageserverApiException as e:
@@ -310,6 +314,9 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r
existing_initdb_timeline_id=timeline_id,
)
# Timeline exists again: restart the endpoint
ep.start()
pg_bin.run_capture(
["pg_dumpall", f"--dbname={connstr}", f"--file={test_output_dir / 'dump-from-wal.sql'}"]
)

View File

@@ -84,3 +84,21 @@ def test_hot_standby(neon_simple_env: NeonEnv):
# clean up
if slow_down_send:
sk_http.configure_failpoints(("sk-send-wal-replica-sleep", "off"))
def test_2_replicas_start(neon_simple_env: NeonEnv):
env = neon_simple_env
with env.endpoints.create_start(
branch_name="main",
endpoint_id="primary",
) as primary:
time.sleep(1)
with env.endpoints.new_replica_start(
origin=primary, endpoint_id="secondary1"
) as secondary1:
with env.endpoints.new_replica_start(
origin=primary, endpoint_id="secondary2"
) as secondary2:
wait_replica_caughtup(primary, secondary1)
wait_replica_caughtup(primary, secondary2)

View File

@@ -1,4 +1,6 @@
import gzip
import json
import os
import time
from dataclasses import dataclass
from pathlib import Path
@@ -10,7 +12,11 @@ from fixtures.neon_fixtures import (
NeonEnvBuilder,
wait_for_last_flush_lsn,
)
from fixtures.remote_storage import RemoteStorageKind
from fixtures.remote_storage import (
LocalFsStorage,
RemoteStorageKind,
remote_storage_to_toml_inline_table,
)
from fixtures.types import TenantId, TimelineId
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
@@ -40,6 +46,9 @@ def test_metric_collection(
uploads.put((events, is_last == "true"))
return Response(status=200)
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
assert neon_env_builder.pageserver_remote_storage is not None
# Require collecting metrics frequently, since we change
# the timeline and want something to be logged about it.
#
@@ -48,12 +57,11 @@ def test_metric_collection(
neon_env_builder.pageserver_config_override = f"""
metric_collection_interval="1s"
metric_collection_endpoint="{metric_collection_endpoint}"
metric_collection_bucket={remote_storage_to_toml_inline_table(neon_env_builder.pageserver_remote_storage)}
cached_metric_collection_interval="0s"
synthetic_size_calculation_interval="3s"
"""
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
log.info(f"test_metric_collection endpoint is {metric_collection_endpoint}")
# mock http server that returns OK for the metrics
@@ -70,6 +78,7 @@ def test_metric_collection(
# we have a fast rate of calculation, these can happen at shutdown
".*synthetic_size_worker:calculate_synthetic_size.*:gather_size_inputs.*: failed to calculate logical size at .*: cancelled.*",
".*synthetic_size_worker: failed to calculate synthetic size for tenant .*: failed to calculate some logical_sizes",
".*metrics_collection: failed to upload to S3: Failed to upload data of length .* to storage path.*",
]
)
@@ -166,6 +175,20 @@ def test_metric_collection(
httpserver.check()
# Check that at least one bucket output object is present, and that all
# can be decompressed and decoded.
bucket_dumps = {}
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)
for dirpath, _dirs, files in os.walk(env.pageserver_remote_storage.root):
for file in files:
file_path = os.path.join(dirpath, file)
log.info(file_path)
if file.endswith(".gz"):
bucket_dumps[file_path] = json.load(gzip.open(file_path))
assert len(bucket_dumps) >= 1
assert all("events" in data for data in bucket_dumps.values())
def test_metric_collection_cleans_up_tempfile(
httpserver: HTTPServer,

View File

@@ -1,5 +1,4 @@
import asyncio
import time
from typing import Tuple
import pytest
@@ -10,7 +9,7 @@ from fixtures.neon_fixtures import (
tenant_get_shards,
)
from fixtures.pageserver.http import PageserverHttpClient
from fixtures.pageserver.utils import wait_for_last_record_lsn
from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_for_upload
from fixtures.types import Lsn, TenantId, TimelineId
from fixtures.utils import wait_until
@@ -61,6 +60,15 @@ def wait_until_pageserver_is_caught_up(
assert waited >= last_flush_lsn
def wait_until_pageserver_has_uploaded(
env: NeonEnv, last_flush_lsns: list[Tuple[TenantId, TimelineId, Lsn]]
):
for tenant, timeline, last_flush_lsn in last_flush_lsns:
shards = tenant_get_shards(env, tenant)
for tenant_shard_id, pageserver in shards:
wait_for_upload(pageserver.http_client(), tenant_shard_id, timeline, last_flush_lsn)
def wait_for_wal_ingest_metric(pageserver_http: PageserverHttpClient) -> float:
def query():
value = pageserver_http.get_metric_value("pageserver_wal_ingest_records_received_total")
@@ -86,25 +94,50 @@ def test_pageserver_small_inmemory_layers(
The workload creates a number of timelines and writes some data to each,
but not enough to trigger flushes via the `checkpoint_distance` config.
"""
def get_dirty_bytes():
v = (
env.pageserver.http_client().get_metric_value("pageserver_timeline_ephemeral_bytes")
or 0
)
log.info(f"dirty_bytes: {v}")
return v
def assert_dirty_bytes(v):
assert get_dirty_bytes() == v
env = neon_env_builder.init_configs()
env.start()
last_flush_lsns = asyncio.run(workload(env, TIMELINE_COUNT, ENTRIES_PER_TIMELINE))
wait_until_pageserver_is_caught_up(env, last_flush_lsns)
# We didn't write enough data to trigger a size-based checkpoint
assert get_dirty_bytes() > 0
ps_http_client = env.pageserver.http_client()
total_wal_ingested_before_restart = wait_for_wal_ingest_metric(ps_http_client)
log.info("Sleeping for checkpoint timeout ...")
time.sleep(CHECKPOINT_TIMEOUT_SECONDS + 5)
# Within ~ the checkpoint interval, all the ephemeral layers should be frozen and flushed,
# such that there are zero bytes of ephemeral layer left on the pageserver
log.info("Waiting for background checkpoints...")
wait_until(CHECKPOINT_TIMEOUT_SECONDS * 2, 1, lambda: assert_dirty_bytes(0)) # type: ignore
# Zero ephemeral layer bytes does not imply that all the frozen layers were uploaded: they
# must be uploaded to remain visible to the pageserver after restart.
wait_until_pageserver_has_uploaded(env, last_flush_lsns)
env.pageserver.restart(immediate=immediate_shutdown)
wait_until_pageserver_is_caught_up(env, last_flush_lsns)
# Catching up with WAL ingest should have resulted in zero bytes of ephemeral layers, since
# we froze, flushed and uploaded everything before restarting. There can be no more WAL writes
# because we shut down compute endpoints before flushing.
assert get_dirty_bytes() == 0
total_wal_ingested_after_restart = wait_for_wal_ingest_metric(ps_http_client)
log.info(f"WAL ingested before restart: {total_wal_ingested_before_restart}")
log.info(f"WAL ingested after restart: {total_wal_ingested_after_restart}")
leeway = total_wal_ingested_before_restart * 5 / 100
assert total_wal_ingested_after_restart <= leeway
assert total_wal_ingested_after_restart == 0

Some files were not shown because too many files have changed in this diff Show More