Compare commits


26 Commits

Author SHA1 Message Date
Conrad Ludgate
2cca1b3e4e fix 2024-08-21 18:44:57 +01:00
Conrad Ludgate
471b3b300d fix pin 2024-08-21 16:29:52 +01:00
Conrad Ludgate
fbd4b91169 asyncreadready 2024-08-21 16:16:49 +01:00
Conrad Ludgate
8cc45ad9bd asrawfd things 2024-08-21 15:28:25 +01:00
Conrad Ludgate
aabbd55187 add ktls handling 2024-08-21 14:42:41 +01:00
Conrad Ludgate
987a859352 start integrating ktls 2024-08-21 14:11:58 +01:00
Conrad Ludgate
e171fd805b add ktls dep 2024-08-21 13:51:02 +01:00
Conrad Ludgate
1e4702b26a update rustls 2024-08-21 13:47:19 +01:00
Joonas Koivunen
3b8016488e test: test_timeline_ancestor_detach_errors rare allowed_error (#8782)
Add another allowed_error for this rare case.

Fixes: #8773
2024-08-21 12:51:08 +01:00
Joonas Koivunen
477246f42c storcon: handle heartbeater shutdown gracefully (#8767)
If a heartbeat happens during shutdown, the task is already
cancelled and will not be sending responses.

Fixes: #8766
2024-08-21 12:28:27 +01:00
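A minimal sketch of the shape of this fix, with made-up channel types rather than the actual storage controller code: during shutdown, a heartbeat response channel whose requester is already cancelled is treated as an expected outcome and ignored.

```rust
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;

// Sketch only: the real heartbeater uses storcon-specific request/response types.
async fn heartbeat_loop(mut rx: mpsc::Receiver<oneshot::Sender<()>>, cancel: CancellationToken) {
    loop {
        tokio::select! {
            _ = cancel.cancelled() => break,
            req = rx.recv() => {
                let Some(reply) = req else { break };
                // The requester may already have been cancelled during shutdown;
                // a failed send here is expected and not reported as an error.
                let _ = reply.send(());
            }
        }
    }
}
```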
Christian Schwarz
21b684718e pageserver: add counter for wait time on background loop semaphore (#8769)
## Problem

Compaction jobs and other background loops are concurrency-limited
through a global semaphore.

The current counters allow quantifying how _many_ tasks are waiting.
But there is no way to tell how _much_ delay is added by the semaphore.

So, add a counter that aggregates the wall-clock time, in seconds, spent
acquiring the semaphore.

The metrics can be used as follows:

* retroactively calculate average acquisition time in a given time range
* compare the degree of background loop backlog among pageservers

The metric is insufficient to calculate:

* the run-up of ongoing acquisitions that haven't finished acquiring yet
  * not easily feasible, because ["Cancelling a call to acquire makes you
lose your place in the
queue"](https://docs.rs/tokio/latest/tokio/sync/struct.Semaphore.html#method.acquire)

## Summary of changes

* Refactor the metrics to follow the current best practice for typed
metrics in `metrics.rs`.
* Add the new counter.
2024-08-21 10:55:01 +00:00
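As a usage illustration (an assumption about how the counters would typically be queried, not something specified in this PR), the average acquisition time over a time window can be estimated from the counter deltas:

avg_wait_seconds ≈ Δ pageserver_background_loop_semaphore_wait_duration_seconds / Δ pageserver_background_loop_semaphore_wait_finish_count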
Peter Bendel
6d8572ded6 Benchmarking: need to checkout actions to download Neon artifacts (#8770)
## Problem

The database preparation workflow needs Neon artifacts but does not check out
the necessary download action.

We were lucky in a few runs, like this one:

https://github.com/neondatabase/neon/actions/runs/10413970941/job/28870668020

but this is flaky (a race condition), and it failed here:


https://github.com/neondatabase/neon/actions/runs/10446395644/job/28923749772#step:4:1



## Summary of changes

Check out the code (including actions) before invoking the download action.

Successful test run
https://github.com/neondatabase/neon/actions/runs/10469356296/job/28992200694
2024-08-21 08:08:49 +01:00
Alex Chi Z.
c8b9116a97 impr(pageserver): abort on fatal I/O writer error (#8777)
part of https://github.com/neondatabase/neon/issues/8140

The blob writer path now uses `maybe_fatal_err`

Signed-off-by: Alex Chi Z <chi@neon.tech>
2024-08-20 20:05:33 +01:00
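Below is a minimal sketch of the idea rather than the pageserver's actual `maybe_fatal_err` helper: wrap write results so that unrecoverable I/O errors abort the process instead of propagating further. The signature and the EIO/ENOSPC classification are assumptions for illustration.

```rust
use std::io;

// Sketch only: the real helper lives in the pageserver and may classify errors differently.
fn maybe_fatal_err<T>(res: io::Result<T>, context: &str) -> io::Result<T> {
    if let Err(e) = &res {
        // Treat Linux EIO (5) and ENOSPC (28) as unrecoverable on the writer path.
        if matches!(e.raw_os_error(), Some(5) | Some(28)) {
            eprintln!("fatal I/O error during {context}: {e}");
            std::process::abort();
        }
    }
    res
}
```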
John Spray
beefc7a810 pageserver: add metric pageserver_secondary_heatmap_total_size (#8768)
## Problem

We don't have a convenient way for a human to ask "how far along are the
secondary downloads for this tenant?".

This is useful when driving migrations of tenants to the storage
controller, as we first create a secondary location and want to see it
warm up before we cut over. That can already be done via storcon_cli,
but we would like a way that doesn't require direct API access.

## Summary of changes

Add a metric that reports the total size of layers in the heatmap: this
may be used in conjunction with the existing
`pageserver_secondary_resident_physical_size` to estimate "warmth" of
the secondary location.
2024-08-20 19:47:42 +01:00
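As a usage illustration (an operator-side assumption, not something defined in this PR), secondary "warmth" can be approximated as the ratio

pageserver_secondary_resident_physical_size / pageserver_secondary_heatmap_total_size

which tends towards 1 as the secondary finishes downloading the layers listed in its heatmap.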
Vlad Lazar
fa0750a37e storcon: add peer jwt token (#8764)
## Problem

Storage controllers did not have the right token to speak to their peers
for leadership transitions.

## Summary of changes

Accept a peer jwt token for the storage controller.

Epic: https://github.com/neondatabase/cloud/issues/14701
2024-08-20 15:25:21 +01:00
Conrad Ludgate
0170611a97 proxy: small changes (#8752)
## Problem

#8736 is getting too big; splitting off some simple changes here.

## Summary of changes

The local proxy won't always be using TLS, so make it optional. The local proxy
won't be using WebSockets for now, so make that optional too. Remove a dead config var.
2024-08-20 14:16:27 +01:00
Vlad Lazar
1c96957e85 storcon: run db migrations after step down sequence (#8756)
## Problem

Previously, we would run db migrations before doing the step-down
sequence. This meant that the current leader would have to deal with
the schema changes, which is generally not safe.

## Summary of changes

Push the step-down procedure earlier in start-up and
do db migrations right after it (but before we load up the in-memory
state from the db).

Epic: https://github.com/neondatabase/cloud/issues/14701
2024-08-20 14:00:36 +01:00
John Spray
02a28c01ca Revert "safekeeper: check for non-consecutive writes in safekeeper.rs" (#8771)
Reverts neondatabase/neon#8640

This broke `test_last_log_term_switch` via a merge race of some kind.
2024-08-20 11:34:53 +00:00
Alexander Bayandin
c96593b473 Make Postgres 16 default version (#8745)
## Problem

The default Postgres version is set to 15 in code, while we use 16 in
most of the other places (and Postgres 17 is coming).

## Summary of changes
- Run `benchmarks` job with Postgres 16 (instead of Postgres 14)
- Set `DEFAULT_PG_VERSION` to 16 in all places
- Remove deprecated `--pg-version` pytest argument
- Update `test_metadata_bincode_serde_ensure_roundtrip` for Postgres 16
2024-08-20 10:46:58 +01:00
Christian Schwarz
ef57e73fbf task_mgr::spawn: require a TenantId (#8462)
… to dis-incentivize global tasks via task_mgr in the future

(As of https://github.com/neondatabase/neon/pull/8339 all remaining
task_mgr usage is tenant or timeline scoped.)
2024-08-20 08:26:44 +00:00
Arseny Sher
4c5a0fdc75 safekeeper: check for non-consecutive writes in safekeeper.rs
wal_storage.rs already checks this, but since this is a quite legitimate scenario,
check it at the safekeeper.rs (consensus) level as well.

ref https://github.com/neondatabase/neon/issues/8212
2024-08-20 07:12:56 +03:00
Arpad Müller
4b26783c94 scrubber: remove _generic postfix and two unused functions (#8761)
Removes the `_generic` postfix from the APIs that use `GenericRemoteStorage`,
as `remote_storage` is the "default" now, and adds an `_s3` postfix
to the remaining APIs that use the S3 SDK (only in tenant snapshot). Also
removes two unused functions: `list_objects_with_retries` and
`stream_tenants`.

Part of https://github.com/neondatabase/neon/issues/7547
2024-08-19 23:58:47 +02:00
Arpad Müller
6949b45e17 Update aws -> infra for repo rename (#8755)
See slack thread:
https://neondb.slack.com/archives/C039YKBRZB4/p1722501766006179
2024-08-19 17:44:10 +02:00
Arpad Müller
3b8ca477ab Migrate physical GC and scan_metadata to remote_storage (#8673)
Migrates most of the remaining parts of the scrubber to remote_storage:

* `pageserver_physical_gc`
* `scan_metadata` for pageservers (safekeepers were done in #8595)
* `download()` in `tenant_snapshot`. The main `tenant_snapshot` is not
migrated as it uses version history to be able to work in the face of
ongoing changes.
 
Part of #7547
2024-08-19 16:39:44 +02:00
Christian Schwarz
eb7241c798 l0_flush: remove support for mode page-cached (#8739)
It's been rolled out everywhere; no configs reference it.

All code that's made dead by the removal of the config option is removed
as part of this PR.

The `page_caching::PreWarmingWriter` in `::No` mode is equivalent to a
`size_tracking_writer`, so use that.

part of https://github.com/neondatabase/neon/issues/7418
2024-08-19 16:35:34 +02:00
Folke Behrens
f246aa3ca7 proxy: Fix some warnings by extended clippy checks (#8748)
* Missing blank lifetimes, which are now deprecated.
* Matching on unqualified enum variants that could act like variable bindings.
* Missing semicolons.
2024-08-19 10:33:46 +02:00
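For illustration, a small sketch of the first two lint classes in generic Rust (the names `Conn` and `State` are made up, not code from this PR):

```rust
struct Conn<'a> {
    peer: &'a str,
}

// elided-lifetimes-in-paths: the previously hidden lifetime is now spelled `<'_>`.
fn log_peer(conn: &Conn<'_>) {
    println!("peer = {}", conn.peer);
}

enum State {
    Idle,
    Busy,
}

fn is_idle(state: &State) -> bool {
    // Qualifying the variant avoids a pattern that could otherwise be parsed as
    // a catch-all variable binding when the variant name is not in scope.
    matches!(state, State::Idle)
}
```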
108 changed files with 1642 additions and 1448 deletions

View File

@@ -43,7 +43,7 @@ inputs:
pg_version:
description: 'Postgres version to use for tests'
required: false
default: 'v14'
default: 'v16'
benchmark_durations:
description: 'benchmark durations JSON'
required: false

View File

@@ -48,6 +48,8 @@ jobs:
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
- uses: actions/checkout@v4
- name: Download Neon artifact
uses: ./.github/actions/download
with:

View File

@@ -280,6 +280,7 @@ jobs:
save_perf_report: ${{ github.ref_name == 'main' }}
extra_params: --splits 5 --group ${{ matrix.pytest_split_group }}
benchmark_durations: ${{ needs.get-benchmarks-durations.outputs.json }}
pg_version: v16
env:
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
@@ -985,10 +986,10 @@ jobs:
GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }}
run: |
if [[ "$GITHUB_REF_NAME" == "main" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false
gh workflow --repo neondatabase/infra run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false
gh workflow --repo neondatabase/azure run deploy.yml -f dockerTag=${{needs.tag.outputs.build-tag}}
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \
gh workflow --repo neondatabase/infra run deploy-dev.yml --ref main \
-f deployPgSniRouter=false \
-f deployProxy=false \
-f deployStorage=true \
@@ -998,14 +999,14 @@ jobs:
-f dockerTag=${{needs.tag.outputs.build-tag}} \
-f deployPreprodRegion=true
gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main \
gh workflow --repo neondatabase/infra run deploy-prod.yml --ref main \
-f deployStorage=true \
-f deployStorageBroker=true \
-f deployStorageController=true \
-f branch=main \
-f dockerTag=${{needs.tag.outputs.build-tag}}
elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \
gh workflow --repo neondatabase/infra run deploy-dev.yml --ref main \
-f deployPgSniRouter=true \
-f deployProxy=true \
-f deployStorage=false \
@@ -1015,7 +1016,7 @@ jobs:
-f dockerTag=${{needs.tag.outputs.build-tag}} \
-f deployPreprodRegion=true
gh workflow --repo neondatabase/aws run deploy-proxy-prod.yml --ref main \
gh workflow --repo neondatabase/infra run deploy-proxy-prod.yml --ref main \
-f deployPgSniRouter=true \
-f deployProxy=true \
-f branch=main \

Cargo.lock (generated)
View File

@@ -316,6 +316,33 @@ dependencies = [
"zeroize",
]
[[package]]
name = "aws-lc-rs"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ae74d9bd0a7530e8afd1770739ad34b36838829d6ad61818f9230f683f5ad77"
dependencies = [
"aws-lc-sys",
"mirai-annotations",
"paste",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f0e249228c6ad2d240c2dc94b714d711629d52bad946075d8e9b2f5391f0703"
dependencies = [
"bindgen 0.69.4",
"cc",
"cmake",
"dunce",
"fs_extra",
"libc",
"paste",
]
[[package]]
name = "aws-runtime"
version = "1.2.1"
@@ -926,7 +953,30 @@ dependencies = [
"lazycell",
"log",
"peeking_take_while",
"prettyplease 0.2.6",
"prettyplease 0.2.17",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 2.0.52",
"which",
]
[[package]]
name = "bindgen"
version = "0.69.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0"
dependencies = [
"bitflags 2.4.1",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease 0.2.17",
"proc-macro2",
"quote",
"regex",
@@ -1056,6 +1106,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "cgroups-rs"
version = "0.3.3"
@@ -1164,6 +1220,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"
[[package]]
name = "cmake"
version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.0"
@@ -1492,7 +1557,7 @@ dependencies = [
"bitflags 1.3.2",
"crossterm_winapi",
"libc",
"mio",
"mio 0.8.11",
"parking_lot 0.12.1",
"signal-hook",
"signal-hook-mio",
@@ -1768,6 +1833,12 @@ dependencies = [
"syn 2.0.52",
]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "dyn-clone"
version = "1.0.14"
@@ -2069,6 +2140,12 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsevent-sys"
version = "4.1.0"
@@ -2402,9 +2479,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.3.3"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hex"
@@ -2922,6 +2999,33 @@ dependencies = [
"libc",
]
[[package]]
name = "ktls"
version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebe51e4a53d53b396707537bc8a5277798b720fb71f0d1b9c63eb53199a00fde"
dependencies = [
"futures-util",
"ktls-sys",
"libc",
"memoffset 0.9.1",
"nix 0.29.0",
"num_enum",
"pin-project-lite",
"rustls 0.23.12",
"smallvec",
"thiserror",
"tokio",
"tokio-rustls 0.26.0",
"tracing",
]
[[package]]
name = "ktls-sys"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "095b1fc8d841c3df8c3f2db78b7425cb2ec424568a282cb589a880b99d256e84"
[[package]]
name = "lasso"
version = "0.7.2"
@@ -2960,9 +3064,9 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.150"
version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c"
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]]
name = "libloading"
@@ -3126,9 +3230,9 @@ dependencies = [
[[package]]
name = "memoffset"
version = "0.9.0"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
@@ -3204,6 +3308,24 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "mio"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec"
dependencies = [
"hermit-abi",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
[[package]]
name = "mirai-annotations"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1"
[[package]]
name = "multimap"
version = "0.8.3"
@@ -3244,7 +3366,20 @@ dependencies = [
"bitflags 2.4.1",
"cfg-if",
"libc",
"memoffset 0.9.0",
"memoffset 0.9.1",
]
[[package]]
name = "nix"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags 2.4.1",
"cfg-if",
"cfg_aliases",
"libc",
"memoffset 0.9.1",
]
[[package]]
@@ -3271,7 +3406,7 @@ dependencies = [
"kqueue",
"libc",
"log",
"mio",
"mio 0.8.11",
"walkdir",
"windows-sys 0.48.0",
]
@@ -3393,6 +3528,27 @@ dependencies = [
"libc",
]
[[package]]
name = "num_enum"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179"
dependencies = [
"num_enum_derive",
]
[[package]]
name = "num_enum_derive"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "oauth2"
version = "4.4.2"
@@ -4056,9 +4212,9 @@ dependencies = [
[[package]]
name = "pin-project-lite"
version = "0.2.13"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "pin-utils"
@@ -4183,14 +4339,14 @@ dependencies = [
"futures",
"once_cell",
"pq_proto",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-pemfile 2.1.1",
"serde",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-util",
"tracing",
"workspace_hack",
@@ -4214,7 +4370,7 @@ name = "postgres_ffi"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.65.1",
"byteorder",
"bytes",
"crc32c",
@@ -4282,9 +4438,9 @@ dependencies = [
[[package]]
name = "prettyplease"
version = "0.2.6"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b69d39aab54d069e7f2fe8cb970493e7834601ca2d8c65fd7bbd183578080d1"
checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7"
dependencies = [
"proc-macro2",
"syn 2.0.52",
@@ -4299,6 +4455,15 @@ dependencies = [
"elliptic-curve 0.13.8",
]
[[package]]
name = "proc-macro-crate"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284"
dependencies = [
"toml_edit 0.21.1",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.20+deprecated"
@@ -4457,6 +4622,7 @@ dependencies = [
"itertools 0.10.5",
"jose-jwa",
"jose-jwk",
"ktls",
"lasso",
"md5",
"measured",
@@ -4487,7 +4653,7 @@ dependencies = [
"rsa",
"rstest",
"rustc-hash",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-native-certs 0.7.0",
"rustls-pemfile 2.1.1",
"scopeguard",
@@ -4506,7 +4672,7 @@ dependencies = [
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-tungstenite",
"tokio-util",
"tower-service",
@@ -4672,12 +4838,13 @@ dependencies = [
[[package]]
name = "rcgen"
version = "0.12.1"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48406db8ac1f3cbc7dcdb56ec355343817958a356ff430259bb07baf7607e1e1"
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
dependencies = [
"pem",
"ring 0.17.6",
"rustls-pki-types",
"time",
"yasna",
]
@@ -5190,7 +5357,22 @@ dependencies = [
"log",
"ring 0.17.6",
"rustls-pki-types",
"rustls-webpki 0.102.2",
"rustls-webpki 0.102.6",
"subtle",
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
dependencies = [
"aws-lc-rs",
"log",
"once_cell",
"rustls-pki-types",
"rustls-webpki 0.102.6",
"subtle",
"zeroize",
]
@@ -5241,9 +5423,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
version = "1.3.1"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8"
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
[[package]]
name = "rustls-webpki"
@@ -5267,10 +5449,11 @@ dependencies = [
[[package]]
name = "rustls-webpki"
version = "0.102.2"
version = "0.102.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e"
dependencies = [
"aws-lc-rs",
"ring 0.17.6",
"rustls-pki-types",
"untrusted 0.9.0",
@@ -5711,9 +5894,9 @@ dependencies = [
[[package]]
name = "sha2-asm"
version = "0.6.3"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27ba7066011e3fb30d808b51affff34f0a66d3a03a58edd787c6e420e40e44e"
checksum = "b845214d6175804686b2bd482bcffe96651bb2d1200742b712003504a2dac1ab"
dependencies = [
"cc",
]
@@ -5750,7 +5933,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af"
dependencies = [
"libc",
"mio",
"mio 0.8.11",
"signal-hook",
]
@@ -5812,9 +5995,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.13.1"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "smol_str"
@@ -6006,7 +6189,7 @@ dependencies = [
"rand 0.8.5",
"remote_storage",
"reqwest 0.12.4",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-native-certs 0.7.0",
"serde",
"serde_json",
@@ -6016,7 +6199,7 @@ dependencies = [
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-util",
"tracing",
@@ -6228,18 +6411,18 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.57"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.57"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [
"proc-macro2",
"quote",
@@ -6366,20 +6549,19 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.37.0"
version = "1.39.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5"
dependencies = [
"backtrace",
"bytes",
"libc",
"mio",
"num_cpus",
"mio 1.0.2",
"pin-project-lite",
"signal-hook-registry",
"socket2 0.5.5",
"tokio-macros",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -6410,9 +6592,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.2.0"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
@@ -6444,16 +6626,15 @@ dependencies = [
[[package]]
name = "tokio-postgres-rustls"
version = "0.11.1"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ea13f22eda7127c827983bdaf0d7fff9df21c8817bab02815ac277a21143677"
checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab"
dependencies = [
"futures",
"ring 0.17.6",
"rustls 0.22.4",
"rustls 0.23.12",
"tokio",
"tokio-postgres",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"x509-certificate",
]
@@ -6478,6 +6659,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
"rustls 0.23.12",
"rustls-pki-types",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.14"
@@ -6579,6 +6771,17 @@ dependencies = [
"winnow 0.4.6",
]
[[package]]
name = "toml_edit"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1"
dependencies = [
"indexmap 2.0.1",
"toml_datetime",
"winnow 0.5.40",
]
[[package]]
name = "toml_edit"
version = "0.22.14"
@@ -6671,11 +6874,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tracing"
version = "0.1.37"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"cfg-if",
"log",
"pin-project-lite",
"tracing-attributes",
@@ -6695,9 +6897,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
version = "0.1.24"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
@@ -6706,9 +6908,9 @@ dependencies = [
[[package]]
name = "tracing-core"
version = "0.1.31"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
@@ -7109,7 +7311,7 @@ name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.65.1",
"postgres_ffi",
"utils",
"workspace_hack",
@@ -7563,6 +7765,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "winnow"
version = "0.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876"
dependencies = [
"memchr",
]
[[package]]
name = "winnow"
version = "0.6.13"
@@ -7652,6 +7863,8 @@ dependencies = [
"reqwest 0.11.19",
"reqwest 0.12.4",
"rustls 0.21.11",
"rustls-pki-types",
"rustls-webpki 0.102.6",
"scopeguard",
"serde",
"serde_json",

View File

@@ -139,7 +139,7 @@ reqwest-retry = "0.5"
routerify = "3"
rpds = "0.13"
rustc-hash = "1.1.0"
rustls = "0.22"
rustls = "0.23"
rustls-pemfile = "2"
rustls-split = "0.3"
scopeguard = "1.1"
@@ -171,8 +171,8 @@ tikv-jemalloc-ctl = "0.5"
tokio = { version = "1.17", features = ["macros"] }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.11.0"
tokio-rustls = "0.25"
tokio-postgres-rustls = "0.12.0"
tokio-rustls = "0.26"
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
@@ -232,7 +232,7 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
criterion = "0.5.1"
rcgen = "0.12"
rcgen = "0.13"
rstest = "0.18"
camino-tempfile = "1.0.2"
tonic-build = "0.9"

View File

@@ -262,7 +262,7 @@ By default, this runs both debug and release modes, and all supported postgres v
testing locally, it is convenient to run just one set of permutations, like this:
```sh
DEFAULT_PG_VERSION=15 BUILD_TYPE=release ./scripts/pytest
DEFAULT_PG_VERSION=16 BUILD_TYPE=release ./scripts/pytest
```
## Flamegraphs

View File

@@ -54,7 +54,7 @@ const DEFAULT_PAGESERVER_ID: NodeId = NodeId(1);
const DEFAULT_BRANCH_NAME: &str = "main";
project_git_version!(GIT_VERSION);
const DEFAULT_PG_VERSION: &str = "15";
const DEFAULT_PG_VERSION: &str = "16";
const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/";

View File

@@ -27,7 +27,7 @@ use crate::pageserver::PageServerNode;
use crate::pageserver::PAGESERVER_REMOTE_STORAGE_DIR;
use crate::safekeeper::SafekeeperNode;
pub const DEFAULT_PG_VERSION: u32 = 15;
pub const DEFAULT_PG_VERSION: u32 = 16;
//
// This data structures represents neon_local CLI config

View File

@@ -217,7 +217,7 @@ impl StorageController {
Ok(exitcode.success())
}
/// Create our database if it doesn't exist, and run migrations.
/// Create our database if it doesn't exist
///
/// This function is equivalent to the `diesel setup` command in the diesel CLI. We implement
/// the same steps by hand to avoid imposing a dependency on installing diesel-cli for developers
@@ -382,7 +382,6 @@ impl StorageController {
)
.await?;
// Run migrations on every startup, in case something changed.
self.setup_database(postgres_port).await?;
}
@@ -454,6 +453,11 @@ impl StorageController {
let jwt_token =
encode_from_key_file(&claims, private_key).expect("failed to generate jwt token");
args.push(format!("--jwt-token={jwt_token}"));
let peer_claims = Claims::new(None, Scope::Admin);
let peer_jwt_token = encode_from_key_file(&peer_claims, private_key)
.expect("failed to generate jwt token");
args.push(format!("--peer-jwt-token={peer_jwt_token}"));
}
if let Some(public_key) = &self.public_key {

View File

@@ -14,7 +14,7 @@ picked tenant (which requested on-demand activation) for around 30 seconds
during the restart at 2024-04-03 16:37 UTC.
Note that lots of shutdowns on loaded pageservers do not finish within the
[10 second systemd enforced timeout](https://github.com/neondatabase/aws/blob/0a5280b383e43c063d43cbf87fa026543f6d6ad4/.github/ansible/systemd/pageserver.service#L16). This means we are shutting down without flushing ephemeral layers
[10 second systemd enforced timeout](https://github.com/neondatabase/infra/blob/0a5280b383e43c063d43cbf87fa026543f6d6ad4/.github/ansible/systemd/pageserver.service#L16). This means we are shutting down without flushing ephemeral layers
and have to reingest data in order to serve requests after restarting, potentially making first request latencies worse.
This problem is not yet very acutely felt in storage controller managed pageservers since

View File

@@ -383,6 +383,48 @@ impl RemoteStorage for AzureBlobStorage {
}
}
async fn head_object(
&self,
key: &RemotePath,
cancel: &CancellationToken,
) -> Result<ListingObject, DownloadError> {
let kind = RequestKind::Head;
let _permit = self.permit(kind, cancel).await?;
let started_at = start_measuring_requests(kind);
let blob_client = self.client.blob_client(self.relative_path_to_name(key));
let properties_future = blob_client.get_properties().into_future();
let properties_future = tokio::time::timeout(self.timeout, properties_future);
let res = tokio::select! {
res = properties_future => res,
_ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
};
if let Ok(inner) = &res {
// do not incl. timeouts as errors in metrics but cancellations
let started_at = ScopeGuard::into_inner(started_at);
crate::metrics::BUCKET_METRICS
.req_seconds
.observe_elapsed(kind, inner, started_at);
}
let data = match res {
Ok(Ok(data)) => Ok(data),
Ok(Err(sdk)) => Err(to_download_error(sdk)),
Err(_timeout) => Err(DownloadError::Timeout),
}?;
let properties = data.blob.properties;
Ok(ListingObject {
key: key.to_owned(),
last_modified: SystemTime::from(properties.last_modified),
size: properties.content_length,
})
}
async fn upload(
&self,
from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,

View File

@@ -150,7 +150,7 @@ pub enum ListingMode {
NoDelimiter,
}
#[derive(PartialEq, Eq, Debug)]
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct ListingObject {
pub key: RemotePath,
pub last_modified: SystemTime,
@@ -215,6 +215,13 @@ pub trait RemoteStorage: Send + Sync + 'static {
Ok(combined)
}
/// Obtain metadata information about an object.
async fn head_object(
&self,
key: &RemotePath,
cancel: &CancellationToken,
) -> Result<ListingObject, DownloadError>;
/// Streams the local file contents into remote into the remote storage entry.
///
/// If the operation fails because of timeout or cancellation, the root cause of the error will be
@@ -363,6 +370,20 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
}
}
// See [`RemoteStorage::head_object`].
pub async fn head_object(
&self,
key: &RemotePath,
cancel: &CancellationToken,
) -> Result<ListingObject, DownloadError> {
match self {
Self::LocalFs(s) => s.head_object(key, cancel).await,
Self::AwsS3(s) => s.head_object(key, cancel).await,
Self::AzureBlob(s) => s.head_object(key, cancel).await,
Self::Unreliable(s) => s.head_object(key, cancel).await,
}
}
/// See [`RemoteStorage::upload`]
pub async fn upload(
&self,
@@ -598,6 +619,7 @@ impl ConcurrencyLimiter {
RequestKind::Delete => &self.write,
RequestKind::Copy => &self.write,
RequestKind::TimeTravel => &self.write,
RequestKind::Head => &self.read,
}
}
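For context, a hedged usage sketch of the new `head_object` API added in this diff; the helper function and its name are made up for illustration, and it mirrors the 404 handling shown in the S3 implementation.

```rust
use remote_storage::{DownloadError, GenericRemoteStorage, RemotePath};
use tokio_util::sync::CancellationToken;

// Sketch only: probe for an object's size, treating NotFound as "absent"
// rather than as an error.
async fn object_size_if_present(
    storage: &GenericRemoteStorage,
    key: &RemotePath,
    cancel: &CancellationToken,
) -> Result<Option<u64>, DownloadError> {
    match storage.head_object(key, cancel).await {
        Ok(obj) => Ok(Some(obj.size)),
        Err(DownloadError::NotFound) => Ok(None),
        Err(e) => Err(e),
    }
}
```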

View File

@@ -445,6 +445,20 @@ impl RemoteStorage for LocalFs {
}
}
async fn head_object(
&self,
key: &RemotePath,
_cancel: &CancellationToken,
) -> Result<ListingObject, DownloadError> {
let target_file_path = key.with_base(&self.storage_root);
let metadata = file_metadata(&target_file_path).await?;
Ok(ListingObject {
key: key.clone(),
last_modified: metadata.modified()?,
size: metadata.len(),
})
}
async fn upload(
&self,
data: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync,

View File

@@ -13,6 +13,7 @@ pub(crate) enum RequestKind {
List = 3,
Copy = 4,
TimeTravel = 5,
Head = 6,
}
use scopeguard::ScopeGuard;
@@ -27,6 +28,7 @@ impl RequestKind {
List => "list_objects",
Copy => "copy_object",
TimeTravel => "time_travel_recover",
Head => "head_object",
}
}
const fn as_index(&self) -> usize {
@@ -34,7 +36,8 @@ impl RequestKind {
}
}
pub(crate) struct RequestTyped<C>([C; 6]);
const REQUEST_KIND_COUNT: usize = 7;
pub(crate) struct RequestTyped<C>([C; REQUEST_KIND_COUNT]);
impl<C> RequestTyped<C> {
pub(crate) fn get(&self, kind: RequestKind) -> &C {
@@ -43,8 +46,8 @@ impl<C> RequestTyped<C> {
fn build_with(mut f: impl FnMut(RequestKind) -> C) -> Self {
use RequestKind::*;
let mut it = [Get, Put, Delete, List, Copy, TimeTravel].into_iter();
let arr = std::array::from_fn::<C, 6, _>(|index| {
let mut it = [Get, Put, Delete, List, Copy, TimeTravel, Head].into_iter();
let arr = std::array::from_fn::<C, REQUEST_KIND_COUNT, _>(|index| {
let next = it.next().unwrap();
assert_eq!(index, next.as_index());
f(next)

View File

@@ -23,7 +23,7 @@ use aws_config::{
use aws_sdk_s3::{
config::{AsyncSleep, IdentityCache, Region, SharedAsyncSleep},
error::SdkError,
operation::get_object::GetObjectError,
operation::{get_object::GetObjectError, head_object::HeadObjectError},
types::{Delete, DeleteMarkerEntry, ObjectIdentifier, ObjectVersion, StorageClass},
Client,
};
@@ -604,6 +604,78 @@ impl RemoteStorage for S3Bucket {
}
}
async fn head_object(
&self,
key: &RemotePath,
cancel: &CancellationToken,
) -> Result<ListingObject, DownloadError> {
let kind = RequestKind::Head;
let _permit = self.permit(kind, cancel).await?;
let started_at = start_measuring_requests(kind);
let head_future = self
.client
.head_object()
.bucket(self.bucket_name())
.key(self.relative_path_to_s3_object(key))
.send();
let head_future = tokio::time::timeout(self.timeout, head_future);
let res = tokio::select! {
res = head_future => res,
_ = cancel.cancelled() => return Err(TimeoutOrCancel::Cancel.into()),
};
let res = res.map_err(|_e| DownloadError::Timeout)?;
// do not incl. timeouts as errors in metrics but cancellations
let started_at = ScopeGuard::into_inner(started_at);
crate::metrics::BUCKET_METRICS
.req_seconds
.observe_elapsed(kind, &res, started_at);
let data = match res {
Ok(object_output) => object_output,
Err(SdkError::ServiceError(e)) if matches!(e.err(), HeadObjectError::NotFound(_)) => {
// Count this in the AttemptOutcome::Ok bucket, because 404 is not
// an error: we expect to sometimes fetch an object and find it missing,
// e.g. when probing for timeline indices.
crate::metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
kind,
AttemptOutcome::Ok,
started_at,
);
return Err(DownloadError::NotFound);
}
Err(e) => {
crate::metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
kind,
AttemptOutcome::Err,
started_at,
);
return Err(DownloadError::Other(
anyhow::Error::new(e).context("s3 head object"),
));
}
};
let (Some(last_modified), Some(size)) = (data.last_modified, data.content_length) else {
return Err(DownloadError::Other(anyhow!(
"head_object doesn't contain last_modified or content_length"
)))?;
};
Ok(ListingObject {
key: key.to_owned(),
last_modified: SystemTime::try_from(last_modified).map_err(|e| {
DownloadError::Other(anyhow!("can't convert time '{last_modified}': {e}"))
})?,
size: size as u64,
})
}
async fn upload(
&self,
from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,

View File

@@ -30,6 +30,7 @@ pub struct UnreliableWrapper {
#[derive(Debug, Hash, Eq, PartialEq)]
enum RemoteOp {
ListPrefixes(Option<RemotePath>),
HeadObject(RemotePath),
Upload(RemotePath),
Download(RemotePath),
Delete(RemotePath),
@@ -137,6 +138,16 @@ impl RemoteStorage for UnreliableWrapper {
self.inner.list(prefix, mode, max_keys, cancel).await
}
async fn head_object(
&self,
key: &RemotePath,
cancel: &CancellationToken,
) -> Result<crate::ListingObject, DownloadError> {
self.attempt(RemoteOp::HeadObject(key.clone()))
.map_err(DownloadError::Other)?;
self.inner.head_object(key, cancel).await
}
async fn upload(
&self,
data: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,

View File

@@ -1,15 +1,10 @@
use std::{num::NonZeroUsize, sync::Arc};
use crate::tenant::ephemeral_file;
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize)]
#[serde(tag = "mode", rename_all = "kebab-case", deny_unknown_fields)]
pub enum L0FlushConfig {
PageCached,
#[serde(rename_all = "snake_case")]
Direct {
max_concurrency: NonZeroUsize,
},
Direct { max_concurrency: NonZeroUsize },
}
impl Default for L0FlushConfig {
@@ -25,14 +20,12 @@ impl Default for L0FlushConfig {
pub struct L0FlushGlobalState(Arc<Inner>);
pub enum Inner {
PageCached,
Direct { semaphore: tokio::sync::Semaphore },
}
impl L0FlushGlobalState {
pub fn new(config: L0FlushConfig) -> Self {
match config {
L0FlushConfig::PageCached => Self(Arc::new(Inner::PageCached)),
L0FlushConfig::Direct { max_concurrency } => {
let semaphore = tokio::sync::Semaphore::new(max_concurrency.get());
Self(Arc::new(Inner::Direct { semaphore }))
@@ -44,13 +37,3 @@ impl L0FlushGlobalState {
&self.0
}
}
impl L0FlushConfig {
pub(crate) fn prewarm_on_write(&self) -> ephemeral_file::PrewarmPageCacheOnWrite {
use L0FlushConfig::*;
match self {
PageCached => ephemeral_file::PrewarmPageCacheOnWrite::Yes,
Direct { .. } => ephemeral_file::PrewarmPageCacheOnWrite::No,
}
}
}

View File

@@ -49,7 +49,7 @@ use tracing::{info, info_span};
/// backwards-compatible changes to the metadata format.
pub const STORAGE_FORMAT_VERSION: u16 = 3;
pub const DEFAULT_PG_VERSION: u32 = 15;
pub const DEFAULT_PG_VERSION: u32 = 16;
// Magic constants used to identify different kinds of files
pub const IMAGE_FILE_MAGIC: u16 = 0x5A60;

View File

@@ -1803,6 +1803,15 @@ pub(crate) static SECONDARY_RESIDENT_PHYSICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::n
.expect("failed to define a metric")
});
pub(crate) static SECONDARY_HEATMAP_TOTAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_secondary_heatmap_total_size",
"The total size in bytes of all layers in the most recently downloaded heatmap.",
&["tenant_id", "shard_id"]
)
.expect("failed to define a metric")
});
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RemoteOpKind {
Upload,
@@ -1853,16 +1862,64 @@ pub(crate) static TENANT_TASK_EVENTS: Lazy<IntCounterVec> = Lazy::new(|| {
.expect("Failed to register tenant_task_events metric")
});
pub(crate) static BACKGROUND_LOOP_SEMAPHORE_WAIT_GAUGE: Lazy<IntCounterPairVec> = Lazy::new(|| {
register_int_counter_pair_vec!(
"pageserver_background_loop_semaphore_wait_start_count",
"Counter for background loop concurrency-limiting semaphore acquire calls started",
"pageserver_background_loop_semaphore_wait_finish_count",
"Counter for background loop concurrency-limiting semaphore acquire calls finished",
&["task"],
)
.unwrap()
});
pub struct BackgroundLoopSemaphoreMetrics {
counters: EnumMap<BackgroundLoopKind, IntCounterPair>,
durations: EnumMap<BackgroundLoopKind, Counter>,
}
pub(crate) static BACKGROUND_LOOP_SEMAPHORE: Lazy<BackgroundLoopSemaphoreMetrics> = Lazy::new(
|| {
let counters = register_int_counter_pair_vec!(
"pageserver_background_loop_semaphore_wait_start_count",
"Counter for background loop concurrency-limiting semaphore acquire calls started",
"pageserver_background_loop_semaphore_wait_finish_count",
"Counter for background loop concurrency-limiting semaphore acquire calls finished",
&["task"],
)
.unwrap();
let durations = register_counter_vec!(
"pageserver_background_loop_semaphore_wait_duration_seconds",
"Sum of wall clock time spent waiting on the background loop concurrency-limiting semaphore acquire calls",
&["task"],
)
.unwrap();
BackgroundLoopSemaphoreMetrics {
counters: enum_map::EnumMap::from_array(std::array::from_fn(|i| {
let kind = <BackgroundLoopKind as enum_map::Enum>::from_usize(i);
counters.with_label_values(&[kind.into()])
})),
durations: enum_map::EnumMap::from_array(std::array::from_fn(|i| {
let kind = <BackgroundLoopKind as enum_map::Enum>::from_usize(i);
durations.with_label_values(&[kind.into()])
})),
}
},
);
impl BackgroundLoopSemaphoreMetrics {
pub(crate) fn measure_acquisition(&self, task: BackgroundLoopKind) -> impl Drop + '_ {
struct Record<'a> {
metrics: &'a BackgroundLoopSemaphoreMetrics,
task: BackgroundLoopKind,
_counter_guard: metrics::IntCounterPairGuard,
start: Instant,
}
impl Drop for Record<'_> {
fn drop(&mut self) {
let elapsed = self.start.elapsed().as_secs_f64();
self.metrics.durations[self.task].inc_by(elapsed);
}
}
Record {
metrics: self,
task,
_counter_guard: self.counters[task].guard(),
start: Instant::now(),
}
}
}
pub(crate) static BACKGROUND_LOOP_PERIOD_OVERRUN_COUNT: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
@@ -2544,6 +2601,7 @@ use std::time::{Duration, Instant};
use crate::context::{PageContentKind, RequestContext};
use crate::task_mgr::TaskKind;
use crate::tenant::mgr::TenantSlot;
use crate::tenant::tasks::BackgroundLoopKind;
/// Maintain a per timeline gauge in addition to the global gauge.
pub(crate) struct PerTimelineRemotePhysicalSizeGauge {

View File

@@ -393,7 +393,7 @@ struct PageServerTask {
/// Tasks may optionally be launched for a particular tenant/timeline, enabling
/// later cancelling tasks for that tenant/timeline in [`shutdown_tasks`]
tenant_shard_id: Option<TenantShardId>,
tenant_shard_id: TenantShardId,
timeline_id: Option<TimelineId>,
mutable: Mutex<MutableTaskState>,
@@ -405,7 +405,7 @@ struct PageServerTask {
pub fn spawn<F>(
runtime: &tokio::runtime::Handle,
kind: TaskKind,
tenant_shard_id: Option<TenantShardId>,
tenant_shard_id: TenantShardId,
timeline_id: Option<TimelineId>,
name: &str,
future: F,
@@ -550,7 +550,7 @@ pub async fn shutdown_tasks(
let tasks = TASKS.lock().unwrap();
for task in tasks.values() {
if (kind.is_none() || Some(task.kind) == kind)
&& (tenant_shard_id.is_none() || task.tenant_shard_id == tenant_shard_id)
&& (tenant_shard_id.is_none() || Some(task.tenant_shard_id) == tenant_shard_id)
&& (timeline_id.is_none() || task.timeline_id == timeline_id)
{
task.cancel.cancel();
@@ -573,13 +573,8 @@ pub async fn shutdown_tasks(
};
if let Some(mut join_handle) = join_handle {
if log_all {
if tenant_shard_id.is_none() {
// there are quite few of these
info!(name = task.name, kind = ?task_kind, "stopping global task");
} else {
// warn to catch these in tests; there shouldn't be any
warn!(name = task.name, tenant_shard_id = ?tenant_shard_id, timeline_id = ?timeline_id, kind = ?task_kind, "stopping left-over");
}
// warn to catch these in tests; there shouldn't be any
warn!(name = task.name, tenant_shard_id = ?tenant_shard_id, timeline_id = ?timeline_id, kind = ?task_kind, "stopping left-over");
}
if tokio::time::timeout(std::time::Duration::from_secs(1), &mut join_handle)
.await

View File

@@ -798,7 +798,7 @@ impl Tenant {
task_mgr::spawn(
&tokio::runtime::Handle::current(),
TaskKind::Attach,
Some(tenant_shard_id),
tenant_shard_id,
None,
"attach tenant",
async move {

View File

@@ -21,7 +21,6 @@ pub struct EphemeralFile {
}
mod page_caching;
pub(crate) use page_caching::PrewarmOnWrite as PrewarmPageCacheOnWrite;
mod zero_padded_read_write;
impl EphemeralFile {
@@ -52,12 +51,10 @@ impl EphemeralFile {
)
.await?;
let prewarm = conf.l0_flush.prewarm_on_write();
Ok(EphemeralFile {
_tenant_shard_id: tenant_shard_id,
_timeline_id: timeline_id,
rw: page_caching::RW::new(file, prewarm, gate_guard),
rw: page_caching::RW::new(file, gate_guard),
})
}

View File

@@ -1,15 +1,15 @@
//! Wrapper around [`super::zero_padded_read_write::RW`] that uses the
//! [`crate::page_cache`] to serve reads that need to go to the underlying [`VirtualFile`].
//!
//! Subject to removal in <https://github.com/neondatabase/neon/pull/8537>
use crate::context::RequestContext;
use crate::page_cache::{self, PAGE_SZ};
use crate::tenant::block_io::BlockLease;
use crate::virtual_file::owned_buffers_io::io_buf_ext::FullSlice;
use crate::virtual_file::owned_buffers_io::util::size_tracking_writer;
use crate::virtual_file::VirtualFile;
use once_cell::sync::Lazy;
use std::io::{self, ErrorKind};
use std::ops::{Deref, Range};
use std::io::{self};
use tokio_epoll_uring::BoundedBuf;
use tracing::*;
@@ -18,33 +18,17 @@ use super::zero_padded_read_write;
/// See module-level comment.
pub struct RW {
page_cache_file_id: page_cache::FileId,
rw: super::zero_padded_read_write::RW<PreWarmingWriter>,
rw: super::zero_padded_read_write::RW<size_tracking_writer::Writer<VirtualFile>>,
/// Gate guard is held on as long as we need to do operations in the path (delete on drop).
_gate_guard: utils::sync::gate::GateGuard,
}
/// When we flush a block to the underlying [`crate::virtual_file::VirtualFile`],
/// should we pre-warm the [`crate::page_cache`] with the contents?
#[derive(Clone, Copy)]
pub enum PrewarmOnWrite {
Yes,
No,
}
impl RW {
pub fn new(
file: VirtualFile,
prewarm_on_write: PrewarmOnWrite,
_gate_guard: utils::sync::gate::GateGuard,
) -> Self {
pub fn new(file: VirtualFile, _gate_guard: utils::sync::gate::GateGuard) -> Self {
let page_cache_file_id = page_cache::next_file_id();
Self {
page_cache_file_id,
rw: super::zero_padded_read_write::RW::new(PreWarmingWriter::new(
page_cache_file_id,
file,
prewarm_on_write,
)),
rw: super::zero_padded_read_write::RW::new(size_tracking_writer::Writer::new(file)),
_gate_guard,
}
}
@@ -84,10 +68,10 @@ impl RW {
let vec = Vec::with_capacity(size);
// read from disk what we've already flushed
let writer = self.rw.as_writer();
let flushed_range = writer.written_range();
let mut vec = writer
.file
let file_size_tracking_writer = self.rw.as_writer();
let flushed_range = 0..usize::try_from(file_size_tracking_writer.bytes_written()).unwrap();
let mut vec = file_size_tracking_writer
.as_inner()
.read_exact_at(
vec.slice(0..(flushed_range.end - flushed_range.start)),
u64::try_from(flushed_range.start).unwrap(),
@@ -122,7 +106,7 @@ impl RW {
format!(
"ephemeral file: read immutable page #{}: {}: {:#}",
blknum,
self.rw.as_writer().file.path,
self.rw.as_writer().as_inner().path,
e,
),
)
@@ -132,7 +116,7 @@ impl RW {
}
page_cache::ReadBufResult::NotFound(write_guard) => {
let write_guard = writer
.file
.as_inner()
.read_exact_at_page(write_guard, blknum as u64 * PAGE_SZ as u64, ctx)
.await?;
let read_guard = write_guard.mark_valid();
@@ -154,137 +138,16 @@ impl Drop for RW {
// unlink the file
// we are clear to do this, because we have entered a gate
let res = std::fs::remove_file(&self.rw.as_writer().file.path);
let path = &self.rw.as_writer().as_inner().path;
let res = std::fs::remove_file(path);
if let Err(e) = res {
if e.kind() != std::io::ErrorKind::NotFound {
// just never log the not found errors, we cannot do anything for them; on detach
// the tenant directory is already gone.
//
// not found files might also be related to https://github.com/neondatabase/neon/issues/2442
error!(
"could not remove ephemeral file '{}': {}",
self.rw.as_writer().file.path,
e
);
error!("could not remove ephemeral file '{path}': {e}");
}
}
}
}
struct PreWarmingWriter {
prewarm_on_write: PrewarmOnWrite,
nwritten_blocks: u32,
page_cache_file_id: page_cache::FileId,
file: VirtualFile,
}
impl PreWarmingWriter {
fn new(
page_cache_file_id: page_cache::FileId,
file: VirtualFile,
prewarm_on_write: PrewarmOnWrite,
) -> Self {
Self {
prewarm_on_write,
nwritten_blocks: 0,
page_cache_file_id,
file,
}
}
/// Return the byte range within `file` that has been written though `write_all`.
///
/// The returned range would be invalidated by another `write_all`. To prevent that, we capture `&_`.
fn written_range(&self) -> (impl Deref<Target = Range<usize>> + '_) {
let nwritten_blocks = usize::try_from(self.nwritten_blocks).unwrap();
struct Wrapper(Range<usize>);
impl Deref for Wrapper {
type Target = Range<usize>;
fn deref(&self) -> &Range<usize> {
&self.0
}
}
Wrapper(0..nwritten_blocks * PAGE_SZ)
}
}
impl crate::virtual_file::owned_buffers_io::write::OwnedAsyncWriter for PreWarmingWriter {
async fn write_all<Buf: tokio_epoll_uring::IoBuf + Send>(
&mut self,
buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> std::io::Result<(usize, FullSlice<Buf>)> {
let buflen = buf.len();
assert_eq!(
buflen % PAGE_SZ,
0,
"{buflen} ; we know TAIL_SZ is a PAGE_SZ multiple, and write_buffered_borrowed is used"
);
// Do the IO.
let buf = match self.file.write_all(buf, ctx).await {
(buf, Ok(nwritten)) => {
assert_eq!(nwritten, buflen);
buf
}
(_, Err(e)) => {
return Err(std::io::Error::new(
ErrorKind::Other,
// order error before path because path is long and error is short
format!(
"ephemeral_file: write_blob: write-back tail self.nwritten_blocks={}, buflen={}, {:#}: {}",
self.nwritten_blocks, buflen, e, self.file.path,
),
));
}
};
let nblocks = buflen / PAGE_SZ;
let nblocks32 = u32::try_from(nblocks).unwrap();
if matches!(self.prewarm_on_write, PrewarmOnWrite::Yes) {
// Pre-warm page cache with the contents.
// At least in isolated bulk ingest benchmarks (test_bulk_insert.py), the pre-warming
// benefits the code that writes InMemoryLayer=>L0 layers.
let cache = page_cache::get();
static CTX: Lazy<RequestContext> = Lazy::new(|| {
RequestContext::new(
crate::task_mgr::TaskKind::EphemeralFilePreWarmPageCache,
crate::context::DownloadBehavior::Error,
)
});
for blknum_in_buffer in 0..nblocks {
let blk_in_buffer =
&buf[blknum_in_buffer * PAGE_SZ..(blknum_in_buffer + 1) * PAGE_SZ];
let blknum = self
.nwritten_blocks
.checked_add(blknum_in_buffer as u32)
.unwrap();
match cache
.read_immutable_buf(self.page_cache_file_id, blknum, &CTX)
.await
{
Err(e) => {
error!("ephemeral_file write_blob failed to get immutable buf to pre-warm page cache: {e:?}");
// fail gracefully, it's not the end of the world if we can't pre-warm the cache here
}
Ok(v) => match v {
page_cache::ReadBufResult::Found(_guard) => {
// This function takes &mut self, so, it shouldn't be possible to reach this point.
unreachable!("we just wrote block {blknum} to the VirtualFile, which is owned by Self, \
and this function takes &mut self, so, no concurrent read_blk is possible");
}
page_cache::ReadBufResult::NotFound(mut write_guard) => {
write_guard.copy_from_slice(blk_in_buffer);
let _ = write_guard.mark_valid();
}
},
}
}
}
self.nwritten_blocks = self.nwritten_blocks.checked_add(nblocks32).unwrap();
Ok((buflen, buf))
}
}

View File

@@ -565,7 +565,7 @@ mod tests {
);
let expected_bytes = vec![
/* TimelineMetadataHeader */
4, 37, 101, 34, 0, 70, 0, 4, // checksum, size, format_version (4 + 2 + 2)
74, 104, 158, 105, 0, 70, 0, 4, // checksum, size, format_version (4 + 2 + 2)
/* TimelineMetadataBodyV2 */
0, 0, 0, 0, 0, 0, 2, 0, // disk_consistent_lsn (8 bytes)
1, 0, 0, 0, 0, 0, 0, 1, 0, // prev_record_lsn (9 bytes)
@@ -574,7 +574,7 @@ mod tests {
0, 0, 0, 0, 0, 0, 0, 0, // ancestor_lsn (8 bytes)
0, 0, 0, 0, 0, 0, 0, 0, // latest_gc_cutoff_lsn (8 bytes)
0, 0, 0, 0, 0, 0, 0, 0, // initdb_lsn (8 bytes)
0, 0, 0, 15, // pg_version (4 bytes)
0, 0, 0, 16, // pg_version (4 bytes)
/* padding bytes */
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

View File

@@ -1728,7 +1728,7 @@ impl RemoteTimelineClient {
task_mgr::spawn(
&self.runtime,
TaskKind::RemoteUploadTask,
Some(self.tenant_shard_id),
self.tenant_shard_id,
Some(self.timeline_id),
"remote upload",
async move {

View File

@@ -8,6 +8,7 @@ use std::{sync::Arc, time::SystemTime};
use crate::{
context::RequestContext,
disk_usage_eviction_task::DiskUsageEvictionInfo,
metrics::SECONDARY_HEATMAP_TOTAL_SIZE,
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
};
@@ -105,6 +106,9 @@ pub(crate) struct SecondaryTenant {
// Sum of layer sizes on local disk
pub(super) resident_size_metric: UIntGauge,
// Sum of layer sizes in the most recently downloaded heatmap
pub(super) heatmap_total_size_metric: UIntGauge,
}
impl Drop for SecondaryTenant {
@@ -112,6 +116,7 @@ impl Drop for SecondaryTenant {
let tenant_id = self.tenant_shard_id.tenant_id.to_string();
let shard_id = format!("{}", self.tenant_shard_id.shard_slug());
let _ = SECONDARY_RESIDENT_PHYSICAL_SIZE.remove_label_values(&[&tenant_id, &shard_id]);
let _ = SECONDARY_HEATMAP_TOTAL_SIZE.remove_label_values(&[&tenant_id, &shard_id]);
}
}
@@ -128,6 +133,10 @@ impl SecondaryTenant {
.get_metric_with_label_values(&[&tenant_id, &shard_id])
.unwrap();
let heatmap_total_size_metric = SECONDARY_HEATMAP_TOTAL_SIZE
.get_metric_with_label_values(&[&tenant_id, &shard_id])
.unwrap();
Arc::new(Self {
tenant_shard_id,
// todo: shall we make this a descendent of the
@@ -145,6 +154,7 @@ impl SecondaryTenant {
progress: std::sync::Mutex::default(),
resident_size_metric,
heatmap_total_size_metric,
})
}

View File

@@ -829,6 +829,12 @@ impl<'a> TenantDownloader<'a> {
layers_downloaded: 0,
bytes_downloaded: 0,
};
// Also expose heatmap bytes_total as a metric
self.secondary_state
.heatmap_total_size_metric
.set(heatmap_stats.bytes);
// Accumulate list of things to delete while holding the detail lock, for execution after dropping the lock
let mut delete_layers = Vec::new();
let mut delete_timelines = Vec::new();

View File

@@ -13,7 +13,7 @@ use crate::tenant::ephemeral_file::EphemeralFile;
use crate::tenant::timeline::GetVectoredError;
use crate::tenant::PageReconstructError;
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
use crate::{l0_flush, page_cache, walrecord};
use crate::{l0_flush, page_cache};
use anyhow::{anyhow, Result};
use camino::Utf8PathBuf;
use pageserver_api::key::CompactKey;
@@ -249,9 +249,7 @@ impl InMemoryLayer {
/// debugging function to print out the contents of the layer
///
/// this is likely completly unused
pub async fn dump(&self, verbose: bool, ctx: &RequestContext) -> Result<()> {
let inner = self.inner.read().await;
pub async fn dump(&self, _verbose: bool, _ctx: &RequestContext) -> Result<()> {
let end_str = self.end_lsn_or_max();
println!(
@@ -259,39 +257,6 @@ impl InMemoryLayer {
self.timeline_id, self.start_lsn, end_str,
);
if !verbose {
return Ok(());
}
let cursor = inner.file.block_cursor();
let mut buf = Vec::new();
for (key, vec_map) in inner.index.iter() {
for (lsn, pos) in vec_map.as_slice() {
let mut desc = String::new();
cursor.read_blob_into_buf(*pos, &mut buf, ctx).await?;
let val = Value::des(&buf);
match val {
Ok(Value::Image(img)) => {
write!(&mut desc, " img {} bytes", img.len())?;
}
Ok(Value::WalRecord(rec)) => {
let wal_desc = walrecord::describe_wal_record(&rec).unwrap();
write!(
&mut desc,
" rec {} bytes will_init: {} {}",
buf.len(),
rec.will_init(),
wal_desc
)?;
}
Err(err) => {
write!(&mut desc, " DESERIALIZATION ERROR: {}", err)?;
}
}
println!(" key {} at {}: {}", key, lsn, desc);
}
}
Ok(())
}
@@ -536,7 +501,6 @@ impl InMemoryLayer {
use l0_flush::Inner;
let _concurrency_permit = match l0_flush_global_state {
Inner::PageCached => None,
Inner::Direct { semaphore, .. } => Some(semaphore.acquire().await),
};
@@ -568,34 +532,6 @@ impl InMemoryLayer {
.await?;
match l0_flush_global_state {
l0_flush::Inner::PageCached => {
let ctx = RequestContextBuilder::extend(ctx)
.page_content_kind(PageContentKind::InMemoryLayer)
.build();
let mut buf = Vec::new();
let cursor = inner.file.block_cursor();
for (key, vec_map) in inner.index.iter() {
// Write all page versions
for (lsn, pos) in vec_map.as_slice() {
cursor.read_blob_into_buf(*pos, &mut buf, &ctx).await?;
let will_init = Value::des(&buf)?.will_init();
let (tmp, res) = delta_layer_writer
.put_value_bytes(
Key::from_compact(*key),
*lsn,
buf.slice_len(),
will_init,
&ctx,
)
.await;
res?;
buf = tmp.into_raw_slice().into_inner();
}
}
}
l0_flush::Inner::Direct { .. } => {
let file_contents: Vec<u8> = inner.file.load_to_vec(ctx).await?;
assert_eq!(

View File

@@ -61,21 +61,12 @@ impl BackgroundLoopKind {
}
}
static PERMIT_GAUGES: once_cell::sync::Lazy<
enum_map::EnumMap<BackgroundLoopKind, metrics::IntCounterPair>,
> = once_cell::sync::Lazy::new(|| {
enum_map::EnumMap::from_array(std::array::from_fn(|i| {
let kind = <BackgroundLoopKind as enum_map::Enum>::from_usize(i);
crate::metrics::BACKGROUND_LOOP_SEMAPHORE_WAIT_GAUGE.with_label_values(&[kind.into()])
}))
});
/// Cancellation safe.
pub(crate) async fn concurrent_background_tasks_rate_limit_permit(
loop_kind: BackgroundLoopKind,
_ctx: &RequestContext,
) -> tokio::sync::SemaphorePermit<'static> {
let _guard = PERMIT_GAUGES[loop_kind].guard();
let _guard = crate::metrics::BACKGROUND_LOOP_SEMAPHORE.measure_acquisition(loop_kind);
pausable_failpoint!(
"initial-size-calculation-permit-pause",
@@ -98,7 +89,7 @@ pub fn start_background_loops(
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::Compaction,
Some(tenant_shard_id),
tenant_shard_id,
None,
&format!("compactor for tenant {tenant_shard_id}"),
{
@@ -121,7 +112,7 @@ pub fn start_background_loops(
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::GarbageCollector,
Some(tenant_shard_id),
tenant_shard_id,
None,
&format!("garbage collector for tenant {tenant_shard_id}"),
{
@@ -144,7 +135,7 @@ pub fn start_background_loops(
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::IngestHousekeeping,
Some(tenant_shard_id),
tenant_shard_id,
None,
&format!("ingest housekeeping for tenant {tenant_shard_id}"),
{

View File
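The first hunk in the file above replaces the per-kind `BACKGROUND_LOOP_SEMAPHORE_WAIT_GAUGE` lookup with a single `BACKGROUND_LOOP_SEMAPHORE.measure_acquisition(loop_kind)` guard, which can both track in-flight waiters and accumulate the wall-clock time spent waiting on the background-task semaphore. A minimal sketch of what such a guard might look like, assuming one pre-resolved counter/gauge pair per loop kind (the real metric types in `metrics.rs` are not shown in this diff):

```rust
use std::time::Instant;

/// Guard that measures how long acquiring the background-loop semaphore took.
/// `wait_seconds` and `waiting` are assumed to be pre-resolved, per-loop-kind metrics.
pub struct AcquisitionMeasurement<'a> {
    started: Instant,
    wait_seconds: &'a prometheus::Counter,
    waiting: &'a prometheus::IntGauge,
}

impl<'a> AcquisitionMeasurement<'a> {
    pub fn start(wait_seconds: &'a prometheus::Counter, waiting: &'a prometheus::IntGauge) -> Self {
        waiting.inc(); // one more task is queued behind the semaphore
        Self {
            started: Instant::now(),
            wait_seconds,
            waiting,
        }
    }
}

impl Drop for AcquisitionMeasurement<'_> {
    fn drop(&mut self) {
        // Dropped once `semaphore.acquire().await` has returned (or the future was
        // cancelled), so the elapsed time approximates the wait added by the semaphore.
        self.wait_seconds.inc_by(self.started.elapsed().as_secs_f64());
        self.waiting.dec();
    }
}
```

A caller would hold the guard across the `acquire().await`, which matches how the `let _guard = ...measure_acquisition(loop_kind)` line above is used.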

@@ -2281,7 +2281,7 @@ impl Timeline {
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::LayerFlushTask,
Some(self.tenant_shard_id),
self.tenant_shard_id,
Some(self.timeline_id),
"layer flush task",
async move {
@@ -2635,7 +2635,7 @@ impl Timeline {
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::InitialLogicalSizeCalculation,
Some(self.tenant_shard_id),
self.tenant_shard_id,
Some(self.timeline_id),
"initial size calculation",
// NB: don't log errors here, task_mgr will do that.
@@ -2803,7 +2803,7 @@ impl Timeline {
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::OndemandLogicalSizeCalculation,
Some(self.tenant_shard_id),
self.tenant_shard_id,
Some(self.timeline_id),
"ondemand logical size calculation",
async move {
@@ -5162,7 +5162,7 @@ impl Timeline {
let task_id = task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::DownloadAllRemoteLayers,
Some(self.tenant_shard_id),
self.tenant_shard_id,
Some(self.timeline_id),
"download all remote layers task",
async move {

View File

@@ -395,7 +395,7 @@ impl DeleteTimelineFlow {
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
TaskKind::TimelineDeletionWorker,
Some(tenant_shard_id),
tenant_shard_id,
Some(timeline_id),
"timeline_delete",
async move {

View File

@@ -60,7 +60,7 @@ impl Timeline {
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::Eviction,
Some(self.tenant_shard_id),
self.tenant_shard_id,
Some(self.timeline_id),
&format!(
"layer eviction for {}/{}",

View File

@@ -756,11 +756,23 @@ impl VirtualFile {
})
}
/// The function aborts the process if the error is fatal.
async fn write_at<B: IoBuf + Send>(
&self,
buf: FullSlice<B>,
offset: u64,
_ctx: &RequestContext, /* TODO: use for metrics: https://github.com/neondatabase/neon/issues/6107 */
) -> (FullSlice<B>, Result<usize, Error>) {
let (slice, result) = self.write_at_inner(buf, offset, _ctx).await;
let result = result.maybe_fatal_err("write_at");
(slice, result)
}
async fn write_at_inner<B: IoBuf + Send>(
&self,
buf: FullSlice<B>,
offset: u64,
_ctx: &RequestContext, /* TODO: use for metrics: https://github.com/neondatabase/neon/issues/6107 */
) -> (FullSlice<B>, Result<usize, Error>) {
let file_guard = match self.lock_file().await {
Ok(file_guard) => file_guard,

View File
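The hunk above splits `write_at` into a thin public wrapper plus `write_at_inner`, so every write result passes through `maybe_fatal_err`, which (per the new doc comment) aborts the process on fatal I/O errors. A minimal sketch of that wrapper shape, simplified to a plain `io::Result` rather than the buffer-returning tuple in the diff, and assuming a `maybe_fatal_err(context)` extension method like the one the hunk calls; the error kinds treated as fatal below are illustrative only:

```rust
use std::io::{Error, ErrorKind};

/// Assumed extension trait; the real pageserver helper may classify errors differently.
trait MaybeFatalIo<T> {
    fn maybe_fatal_err(self, context: &str) -> Result<T, Error>;
}

impl<T> MaybeFatalIo<T> for Result<T, Error> {
    fn maybe_fatal_err(self, context: &str) -> Result<T, Error> {
        if let Err(e) = &self {
            if matches!(e.kind(), ErrorKind::OutOfMemory | ErrorKind::Unsupported) {
                eprintln!("fatal I/O error in {context}: {e}");
                std::process::abort();
            }
        }
        self
    }
}

async fn write_at_inner(buf: &[u8], _offset: u64) -> Result<usize, Error> {
    // Stand-in for the real positional write against the file descriptor.
    Ok(buf.len())
}

/// Public wrapper mirroring the hunk: delegate to the inner write, then screen the error.
async fn write_at(buf: &[u8], offset: u64) -> Result<usize, Error> {
    write_at_inner(buf, offset).await.maybe_fatal_err("write_at")
}
```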

@@ -114,6 +114,9 @@ rsa = "0.9"
workspace_hack.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
ktls = "6"
[dev-dependencies]
camino-tempfile.workspace = true
fallible-iterator.workspace = true

View File

@@ -113,38 +113,36 @@ impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
AuthFailed(_) => self.to_string(),
BadAuthMethod(_) => self.to_string(),
MalformedPassword(_) => self.to_string(),
MissingEndpointName => self.to_string(),
Io(_) => "Internal error".to_string(),
IpAddressNotAllowed(_) => self.to_string(),
TooManyConnections => self.to_string(),
UserTimeout(_) => self.to_string(),
AuthErrorImpl::Link(e) => e.to_string_client(),
AuthErrorImpl::GetAuthInfo(e) => e.to_string_client(),
AuthErrorImpl::Sasl(e) => e.to_string_client(),
AuthErrorImpl::AuthFailed(_) => self.to_string(),
AuthErrorImpl::BadAuthMethod(_) => self.to_string(),
AuthErrorImpl::MalformedPassword(_) => self.to_string(),
AuthErrorImpl::MissingEndpointName => self.to_string(),
AuthErrorImpl::Io(_) => "Internal error".to_string(),
AuthErrorImpl::IpAddressNotAllowed(_) => self.to_string(),
AuthErrorImpl::TooManyConnections => self.to_string(),
AuthErrorImpl::UserTimeout(_) => self.to_string(),
}
}
}
impl ReportableError for AuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.get_error_kind(),
GetAuthInfo(e) => e.get_error_kind(),
Sasl(e) => e.get_error_kind(),
AuthFailed(_) => crate::error::ErrorKind::User,
BadAuthMethod(_) => crate::error::ErrorKind::User,
MalformedPassword(_) => crate::error::ErrorKind::User,
MissingEndpointName => crate::error::ErrorKind::User,
Io(_) => crate::error::ErrorKind::ClientDisconnect,
IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
TooManyConnections => crate::error::ErrorKind::RateLimit,
UserTimeout(_) => crate::error::ErrorKind::User,
AuthErrorImpl::Link(e) => e.get_error_kind(),
AuthErrorImpl::GetAuthInfo(e) => e.get_error_kind(),
AuthErrorImpl::Sasl(e) => e.get_error_kind(),
AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User,
AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User,
AuthErrorImpl::Io(_) => crate::error::ErrorKind::ClientDisconnect,
AuthErrorImpl::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit,
AuthErrorImpl::UserTimeout(_) => crate::error::ErrorKind::User,
}
}
}

View File
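This hunk, and several like it later in the compare, drop `use SomeEnum::*;` globs in favour of fully qualified (or `Self::`-qualified) variant paths in each `match` arm; presumably this satisfies a lint such as `clippy::enum_glob_use`, though the diff itself does not say which. A small self-contained illustration of the two styles:

```rust
enum Verdict {
    Allowed,
    Denied(String),
}

// Before: glob-importing the variants keeps the match terse but hides where the names come from.
fn describe_glob(v: &Verdict) -> String {
    use Verdict::*;
    match v {
        Allowed => "ok".to_string(),
        Denied(reason) => format!("denied: {reason}"),
    }
}

// After: fully qualified paths, as in the hunk above.
fn describe_qualified(v: &Verdict) -> String {
    match v {
        Verdict::Allowed => "ok".to_string(),
        Verdict::Denied(reason) => format!("denied: {reason}"),
    }
}

fn main() {
    let cases = [Verdict::Allowed, Verdict::Denied("no capacity".to_string())];
    for v in &cases {
        assert_eq!(describe_glob(v), describe_qualified(v));
    }
}
```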

@@ -4,6 +4,7 @@ pub mod jwt;
mod link;
use std::net::IpAddr;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use std::time::Duration;
@@ -23,6 +24,7 @@ use crate::context::RequestMonitoring;
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::handshake::KtlsAsyncReadReady;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::stream::Stream;
@@ -80,9 +82,8 @@ pub trait TestBackend: Send + Sync + 'static {
impl std::fmt::Display for BackendType<'_, (), ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
Console(api, _) => match &**api {
Self::Console(api, _) => match &**api {
ConsoleBackend::Console(endpoint) => {
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
}
@@ -93,7 +94,7 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
Self::Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
}
@@ -102,10 +103,9 @@ impl<T, D> BackendType<'_, T, D> {
/// Very similar to [`std::option::Option::as_ref`].
/// This helps us pass structured config to async tasks.
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
use BackendType::*;
match self {
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
Self::Console(c, x) => BackendType::Console(MaybeOwned::Borrowed(c), x),
Self::Link(c, x) => BackendType::Link(MaybeOwned::Borrowed(c), x),
}
}
}
@@ -115,10 +115,9 @@ impl<'a, T, D> BackendType<'a, T, D> {
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
use BackendType::*;
match self {
Console(c, x) => Console(c, f(x)),
Link(c, x) => Link(c, x),
Self::Console(c, x) => BackendType::Console(c, f(x)),
Self::Link(c, x) => BackendType::Link(c, x),
}
}
}
@@ -126,10 +125,9 @@ impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c, x) => Ok(Link(c, x)),
Self::Console(c, x) => x.map(|x| BackendType::Console(c, x)),
Self::Link(c, x) => Ok(BackendType::Link(c, x)),
}
}
}
@@ -278,7 +276,9 @@ async fn auth_quirks(
ctx: &RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -293,7 +293,9 @@ async fn auth_quirks(
ctx.set_endpoint_id(res.info.endpoint.clone());
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
_ => unreachable!("password hack should return a password"),
ComputeCredentialKeys::AuthKeys(_) => {
unreachable!("password hack should return a password")
}
};
(res.info, Some(password))
}
@@ -360,7 +362,9 @@ async fn authenticate_with_secret(
ctx: &RequestMonitoring,
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
@@ -400,21 +404,17 @@ async fn authenticate_with_secret(
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<EndpointId> {
use BackendType::*;
match self {
Console(_, user_info) => user_info.endpoint_id.clone(),
Link(_, _) => Some("link".into()),
Self::Console(_, user_info) => user_info.endpoint_id.clone(),
Self::Link(_, _) => Some("link".into()),
}
}
/// Get username from the credentials.
pub fn get_user(&self) -> &str {
use BackendType::*;
match self {
Console(_, user_info) => &user_info.user,
Link(_, _) => "link",
Self::Console(_, user_info) => &user_info.user,
Self::Link(_, _) => "link",
}
}
@@ -423,15 +423,15 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
pub async fn authenticate(
self,
ctx: &RequestMonitoring,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
use BackendType::*;
let res = match self {
Console(api, user_info) => {
Self::Console(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.endpoint(),
@@ -451,7 +451,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
BackendType::Console(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Link(url, _) => {
Self::Link(url, _) => {
info!("performing link authentication");
let info = link::authenticate(ctx, &url, client).await?;
@@ -470,10 +470,9 @@ impl BackendType<'_, ComputeUserInfo, &()> {
&self,
ctx: &RequestMonitoring,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Link(_, _) => Ok(Cached::new_uncached(None)),
Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::Link(_, _) => Ok(Cached::new_uncached(None)),
}
}
@@ -481,10 +480,9 @@ impl BackendType<'_, ComputeUserInfo, &()> {
&self,
ctx: &RequestMonitoring,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Self::Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}
@@ -495,18 +493,16 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
&self,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Link(_, info) => Ok(Cached::new_uncached(info.clone())),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
Self::Console(_, creds) => Some(&creds.keys),
Self::Link(_, _) => None,
}
}
}
@@ -517,18 +513,16 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
&self,
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
use BackendType::*;
match self {
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
}
}
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
BackendType::Console(_, creds) => Some(&creds.keys),
BackendType::Link(_, _) => None,
Self::Console(_, creds) => Some(&creds.keys),
Self::Link(_, _) => None,
}
}
}
@@ -556,7 +550,7 @@ mod tests {
CachedNodeInfo,
},
context::RequestMonitoring,
proxy::NeonOptions,
proxy::{tests::DummyClient, NeonOptions},
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
@@ -664,7 +658,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {
@@ -741,7 +735,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {
@@ -793,7 +787,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {

View File
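A recurring change in this file is the extra `AsRawFd + KtlsAsyncReadReady` bound on every generic client-stream parameter; `KtlsAsyncReadReady` is defined later in this compare (in the handshake module) as a platform-dependent trait alias, so non-Linux and test builds still compile with plain in-memory streams. A minimal sketch of the same cfg-gated trait-alias pattern, using a made-up capability trait (`PlatformReady` and `takes_stream` are hypothetical names; the `not(test)` half of the real cfg is omitted for brevity):

```rust
use tokio::io::AsyncRead;

// Stand-in for a platform-only capability such as `ktls::AsyncReadReady`.
#[cfg(target_os = "linux")]
pub trait PlatformReady {
    fn ready_hint(&self) -> bool;
}

// On Linux the alias demands the real capability...
#[cfg(target_os = "linux")]
pub trait MaybePlatformReady: PlatformReady {}
#[cfg(target_os = "linux")]
impl<T: PlatformReady> MaybePlatformReady for T {}

// ...elsewhere it is an empty marker blanket-implemented for any readable stream,
// so bounds like `S: AsyncRead + MaybePlatformReady` stay portable.
#[cfg(not(target_os = "linux"))]
pub trait MaybePlatformReady {}
#[cfg(not(target_os = "linux"))]
impl<T: AsyncRead> MaybePlatformReady for T {}

// Generic code writes one bound and compiles on every target.
pub fn takes_stream<S: AsyncRead + MaybePlatformReady + Unpin>(_stream: S) {}
```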

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use super::{ComputeCredentials, ComputeUserInfo};
use crate::{
auth::{self, backend::ComputeCredentialKeys, AuthFlow},
@@ -5,6 +7,7 @@ use crate::{
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
proxy::handshake::KtlsAsyncReadReady,
sasl,
stream::{PqStream, Stream},
};
@@ -14,7 +17,9 @@ use tracing::{info, warn};
pub(super) async fn authenticate(
ctx: &RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
@@ -7,6 +9,7 @@ use crate::{
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
proxy::handshake::KtlsAsyncReadReady,
sasl,
stream::{self, Stream},
};
@@ -20,7 +23,9 @@ use tracing::{info, warn};
pub async fn authenticate_cleartext(
ctx: &RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
@@ -62,7 +67,9 @@ pub async fn authenticate_cleartext(
pub async fn password_hack_no_authentication(
ctx: &RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
) -> auth::Result<ComputeCredentials> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);

View File

@@ -195,7 +195,7 @@ impl JwkCacheEntryLock {
let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
let header = serde_json::from_slice::<JWTHeader>(&header)
let header = serde_json::from_slice::<JWTHeader<'_>>(&header)
.context("Provided authentication token is not a valid JWT encoding")?;
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
@@ -340,7 +340,7 @@ impl JwkRenewalPermit<'_> {
}
}
async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit {
async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
match from.lookup.acquire().await {
Ok(permit) => {
permit.forget();
@@ -352,7 +352,7 @@ impl JwkRenewalPermit<'_> {
}
}
fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit> {
fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
match from.lookup.try_acquire() {
Ok(permit) => {
permit.forget();

View File
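The hunks above (and many similar one-liners in later files) only add explicit anonymous lifetimes such as `JWTHeader<'_>` and `JwkRenewalPermit<'_>`, presumably to satisfy the `elided_lifetimes_in_paths` lint; behaviour is unchanged. A tiny illustration of what the lint asks for:

```rust
#![warn(elided_lifetimes_in_paths)] // assumed lint; the actual lint configuration is not in this diff

struct Permit<'a> {
    slot: &'a str,
}

// Writing `-> Permit` would compile, but hides the borrow; the lint wants `Permit<'_>`.
fn acquire(pool: &str) -> Permit<'_> {
    Permit { slot: pool }
}

fn main() {
    let permit = acquire("background");
    println!("granted: {}", permit.slot);
}
```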

@@ -86,13 +86,14 @@ impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
common_names: Option<&HashSet<String>>,
endpoint_from_domain: Option<EndpointId>,
) -> Result<Self, ComputeUserInfoParseError> {
use ComputeUserInfoParseError::*;
// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
let get_param = |key| {
params
.get(key)
.ok_or(ComputeUserInfoParseError::MissingKey(key))
};
let user: RoleName = get_param("user")?.into();
// Project name might be passed via PG's command-line options.
@@ -109,24 +110,18 @@ impl ComputeUserInfoMaybeEndpoint {
})
.map(|name| name.into());
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
endpoint_sni(sni_str, cn)?
} else {
None
}
} else {
None
};
let is_sni = endpoint_from_domain.is_some();
let endpoint = match (endpoint_option, endpoint_from_domain) {
// Invariant: if we have both project name variants, they should match.
(Some(option), Some(domain)) if option != domain => {
Some(Err(InconsistentProjectNames { domain, option }))
Some(Err(ComputeUserInfoParseError::InconsistentProjectNames {
domain,
option,
}))
}
// Invariant: project name may not contain certain characters.
(a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
false => Err(MalformedProjectName(name)),
false => Err(ComputeUserInfoParseError::MalformedProjectName(name)),
true => Ok(name),
}),
}
@@ -138,7 +133,7 @@ impl ComputeUserInfoMaybeEndpoint {
let metrics = Metrics::get();
info!(%user, "credentials");
if sni.is_some() {
if is_sni {
info!("Connection with sni");
metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
} else if endpoint.is_some() {
@@ -186,7 +181,7 @@ impl<'de> serde::de::Deserialize<'de> for IpPattern {
impl<'de> serde::de::Visitor<'de> for StrVisitor {
type Value = IpPattern;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
}
@@ -250,7 +245,7 @@ mod tests {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -265,7 +260,7 @@ mod tests {
("foo", "bar"), // should be ignored
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -276,12 +271,8 @@ mod tests {
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("foo.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("foo".into()))?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
@@ -297,7 +288,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -312,7 +303,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -330,7 +321,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -345,7 +336,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -356,49 +347,21 @@ mod tests {
fn parse_projects_identical() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
let sni = Some("baz.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("baz".into()))?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
Ok(())
}
#[test]
fn parse_multi_common_names() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
Ok(())
}
#[test]
fn parse_projects_different() {
let options =
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
let sni = Some("second.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("second".into()))
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
@@ -409,24 +372,6 @@ mod tests {
}
}
#[test]
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
}
_ => panic!("bad error: {err:?}"),
}
}
#[test]
fn parse_neon_options() -> anyhow::Result<()> {
let options = StartupMessageParams::new([
@@ -434,11 +379,9 @@ mod tests {
("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
]);
let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("project".into()))?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),

View File
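The hunks above change `ComputeUserInfoMaybeEndpoint::parse` to take an already-resolved `Option<EndpointId>` instead of the raw SNI string plus the set of common names, so SNI-to-endpoint resolution now happens during the handshake, before `parse` is called. A minimal sketch of the new division of labour with simplified types; the real code uses the `endpoint_sni` helper, an `EndpointId` newtype and a structured `InconsistentProjectNames` error, none of which are reproduced here:

```rust
/// Hypothetical, simplified stand-in for the proxy's endpoint id type.
type EndpointId = String;

/// Resolve an endpoint from SNI against the configured common names,
/// roughly what used to happen inside `parse`.
fn endpoint_from_sni(sni: &str, common_names: &[&str]) -> Option<EndpointId> {
    let (subdomain, rest) = sni.split_once('.')?;
    common_names
        .iter()
        .any(|cn| *cn == rest)
        .then(|| subdomain.to_string())
}

/// The parser now only reconciles the options-provided project with the
/// pre-resolved SNI endpoint, as in the `(Some(option), Some(domain))` arm above.
fn parse(
    option_endpoint: Option<EndpointId>,
    endpoint_from_domain: Option<EndpointId>,
) -> Result<Option<EndpointId>, String> {
    match (option_endpoint, endpoint_from_domain) {
        (Some(o), Some(d)) if o != d => Err(format!("inconsistent project names: {o} vs {d}")),
        (a, b) => Ok(a.or(b)),
    }
}

fn main() {
    let ep = endpoint_from_sni("foo.localhost", &["localhost"]);
    assert_eq!(parse(None, ep).unwrap().as_deref(), Some("foo"));
}
```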

@@ -6,13 +6,14 @@ use crate::{
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
proxy::handshake::KtlsAsyncReadReady,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::{io, sync::Arc};
use std::{io, os::fd::AsRawFd, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -70,7 +71,7 @@ impl AuthMethod for CleartextPassword {
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, S, State> {
pub struct AuthFlow<'a, S: AsRawFd, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream<S>>,
/// State might contain ancillary data (see [`Self::begin`]).
@@ -79,7 +80,7 @@ pub struct AuthFlow<'a, S, State> {
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
impl<'a, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'a, S, Begin> {
/// Create a new wrapper for client authentication.
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
@@ -105,7 +106,9 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>
AuthFlow<'_, S, PasswordHack>
{
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
let msg = self.stream.read_password_message().await?;
@@ -124,7 +127,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>
AuthFlow<'_, S, CleartextPassword>
{
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
let msg = self.stream.read_password_message().await?;
@@ -149,7 +154,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;

View File

@@ -1,3 +1,4 @@
use std::os::fd::AsRawFd;
/// A stand-alone program that routes connections, e.g. from
/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
///
@@ -7,9 +8,9 @@ use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::proxy::handshake::KtlsAsyncReadReady;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;
@@ -20,6 +21,7 @@ use futures::TryFutureExt;
use proxy::stream::{PqStream, Stream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
use utils::{project_git_version, sentry_init::init_sentry};
@@ -72,7 +74,7 @@ async fn main() -> anyhow::Result<()> {
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
// Configure TLS
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
let tls_config = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -102,19 +104,14 @@ async fn main() -> anyhow::Result<()> {
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
(tls_config, tls_server_end_point)
Arc::new(
rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?,
)
}
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -129,7 +126,6 @@ async fn main() -> anyhow::Result<()> {
let main = tokio::spawn(task_main(
Arc::new(destination),
tls_config,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
));
@@ -151,7 +147,6 @@ async fn main() -> anyhow::Result<()> {
async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -183,7 +178,7 @@ async fn task_main(
proxy::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
handle_client(ctx, dest_suffix, tls_config, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -204,12 +199,11 @@ async fn task_main(
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>(
ctx: &RequestMonitoring,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
) -> anyhow::Result<Box<TlsStream<S>>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
let msg = stream.read_startup_packet().await?;
@@ -235,13 +229,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
Ok(Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
))
}
unexpected => {
info!(
@@ -259,15 +250,18 @@ async fn handle_client(
ctx: RequestMonitoring,
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
stream: impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
let sni = tls_stream
.get_ref()
.1
.server_name()
.ok_or(anyhow!("SNI missing"))?;
let dest: Vec<&str> = sni
.split_once('.')
.context("invalid SNI")?

View File
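With `TlsServerEndPoint` gone from the SNI router, the handshake now returns the `tokio_rustls` stream directly and the SNI hostname is read from the rustls connection itself via `get_ref().1.server_name()`. A minimal sketch of that lookup and of the routing described by the module's doc comment (`aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`); the exact destination formatting is inferred from that comment, not from code shown here:

```rust
use anyhow::{anyhow, Context};
use tokio::net::TcpStream;
use tokio_rustls::server::TlsStream;

/// Turn the first SNI label `{service}--{namespace}--{port}` into a routable address,
/// mirroring the comment in the `handle_client` hunk above.
fn destination_from_sni(
    tls_stream: &TlsStream<TcpStream>,
    dest_suffix: &str,
) -> anyhow::Result<String> {
    // `get_ref()` yields `(&TcpStream, &rustls::ServerConnection)`; `server_name()`
    // is the SNI the client sent during the handshake, if any.
    let sni = tls_stream
        .get_ref()
        .1
        .server_name()
        .ok_or_else(|| anyhow!("SNI missing"))?;

    let (label, _rest) = sni.split_once('.').context("invalid SNI")?;
    let parts: Vec<&str> = label.split("--").collect();
    match parts[..] {
        [service, namespace, port] => Ok(format!("{service}.{namespace}.{dest_suffix}:{port}")),
        _ => Err(anyhow!("invalid SNI format: {sni}")),
    }
}
```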

@@ -173,9 +173,6 @@ struct ProxyCliArgs {
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// disable ip check for http requests. If it is too time consuming, it could be turned off.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_ip_check_for_http: bool,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
#[clap(long)]
redis_notifications: Option<String>,
@@ -288,7 +285,7 @@ async fn main() -> anyhow::Result<()> {
};
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let config = build_config(&args).await?;
info!("Authentication backend: {}", config.auth_backend);
info!("Using region: {}", args.aws_region);
@@ -532,16 +529,14 @@ async fn main() -> anyhow::Result<()> {
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
async fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
cert_path,
args.certs_dir.as_ref(),
)?),
(Some(key_path), Some(cert_path)) => {
Some(config::configure_tls(key_path, cert_path, args.certs_dir.as_ref()).await?)
}
(None, None) => None,
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
};
@@ -661,6 +656,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?;
let http_config = HttpConfig {
accept_websockets: true,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,

View File

@@ -24,7 +24,7 @@ impl<C: Cache> Cache for &C {
type LookupInfo<Key> = C::LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
C::invalidate(self, info)
C::invalidate(self, info);
}
}

View File

@@ -58,7 +58,7 @@ impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
type LookupInfo<Key> = LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<K>) {
self.invalidate_raw(info)
self.invalidate_raw(info);
}
}

View File

@@ -44,11 +44,10 @@ pub enum ConnectionError {
impl UserFacingError for ConnectionError {
fn to_string_client(&self) -> String {
use ConnectionError::*;
match self {
// This helps us drop irrelevant library-specific prefixes.
// TODO: propagate severity level and other parameters.
Postgres(err) => match err.as_db_error() {
ConnectionError::Postgres(err) => match err.as_db_error() {
Some(err) => {
let msg = err.message();
@@ -62,8 +61,8 @@ impl UserFacingError for ConnectionError {
}
None => err.to_string(),
},
WakeComputeError(err) => err.to_string_client(),
TooManyConnectionAttempts(_) => {
ConnectionError::WakeComputeError(err) => err.to_string_client(),
ConnectionError::TooManyConnectionAttempts(_) => {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
_ => COULD_NOT_CONNECT.to_owned(),
@@ -366,16 +365,16 @@ static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
struct AcceptEverythingVerifier;
impl ServerCertVerifier for AcceptEverythingVerifier {
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
use rustls::SignatureScheme::*;
use rustls::SignatureScheme;
// The schemes for which `SignatureScheme::supported_in_tls13` returns true.
vec![
ECDSA_NISTP521_SHA512,
ECDSA_NISTP384_SHA384,
ECDSA_NISTP256_SHA256,
RSA_PSS_SHA512,
RSA_PSS_SHA384,
RSA_PSS_SHA256,
ED25519,
SignatureScheme::ECDSA_NISTP521_SHA512,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::ED25519,
]
}
fn verify_server_cert(

View File

@@ -10,7 +10,7 @@ use anyhow::{bail, ensure, Context, Ok};
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::{
crypto::ring::sign,
crypto::aws_lc_rs::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use sha2::{Digest, Sha256};
@@ -52,6 +52,7 @@ pub struct TlsConfig {
}
pub struct HttpConfig {
pub accept_websockets: bool,
pub pool_options: GlobalConnPoolOptions,
pub cancel_set: CancelSet,
pub client_conn_threshold: u64,
@@ -75,7 +76,7 @@ impl TlsConfig {
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Configure TLS for the main endpoint.
pub fn configure_tls(
pub async fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
@@ -109,13 +110,20 @@ pub fn configure_tls(
let cert_resolver = Arc::new(cert_resolver);
let provider = rustls::crypto::aws_lc_rs::default_provider();
#[cfg(target_os = "linux")]
let provider = {
let mut provider = provider;
let compat = ktls::CompatibleCiphers::new().await?;
provider.cipher_suites.retain(|s| compat.is_compatible(*s));
provider
};
// allow TLS 1.2 to be compatible with older client libraries
let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
let mut config = rustls::ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
@@ -155,7 +163,7 @@ pub enum TlsServerEndPoint {
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
@@ -278,7 +286,7 @@ impl CertResolver {
impl rustls::server::ResolvesServerCert for CertResolver {
fn resolve(
&self,
client_hello: rustls::server::ClientHello,
client_hello: rustls::server::ClientHello<'_>,
) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.resolve(client_hello.server_name()).map(|x| x.0)
}
@@ -559,7 +567,7 @@ impl RetryConfig {
match key {
"num_retries" => num_retries = Some(value.parse()?),
"base_retry_wait_duration" => {
base_retry_wait_duration = Some(humantime::parse_duration(value)?)
base_retry_wait_duration = Some(humantime::parse_duration(value)?);
}
"retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
unknown => bail!("unknown key: {unknown}"),

View File
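The `configure_tls` hunk above becomes `async` so it can probe which cipher suites the kernel's TLS offload supports (`ktls::CompatibleCiphers::new().await`) and strip incompatible ones from the `aws_lc_rs` provider before building the rustls `ServerConfig`. A minimal, Linux-only sketch of that filtering step, assuming the `ktls` 6.x API used in the hunk and eliding how the certificate resolver is built:

```rust
use std::sync::Arc;

/// Build a rustls server config whose cipher suites are all usable with kTLS.
#[cfg(target_os = "linux")]
async fn ktls_compatible_server_config(
    cert_resolver: Arc<dyn rustls::server::ResolvesServerCert>,
) -> anyhow::Result<rustls::ServerConfig> {
    let mut provider = rustls::crypto::aws_lc_rs::default_provider();

    // Ask the kernel which suites it can offload, and keep only those.
    let compat = ktls::CompatibleCiphers::new().await?;
    provider.cipher_suites.retain(|s| compat.is_compatible(*s));

    let mut config = rustls::ServerConfig::builder_with_provider(Arc::new(provider))
        .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
        .with_no_client_auth()
        .with_cert_resolver(cert_resolver);

    // Same ALPN as the hunk above.
    config.alpn_protocols = vec![b"postgresql".to_vec()];
    Ok(config)
}
```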

@@ -22,16 +22,15 @@ impl ConsoleError {
self.status
.as_ref()
.and_then(|s| s.details.error_info.as_ref())
.map(|e| e.reason)
.unwrap_or(Reason::Unknown)
.map_or(Reason::Unknown, |e| e.reason)
}
pub fn get_user_facing_message(&self) -> String {
use super::provider::errors::REQUEST_FAILED;
self.status
.as_ref()
.and_then(|s| s.details.user_facing_message.as_ref())
.map(|m| m.message.clone().into())
.unwrap_or_else(|| {
.map_or_else(|| {
// Ask @neondatabase/control-plane for review before adding more.
match self.http_status_code {
http::StatusCode::NOT_FOUND => {
@@ -48,19 +47,18 @@ impl ConsoleError {
}
_ => REQUEST_FAILED.to_owned(),
}
})
}, |m| m.message.clone().into())
}
}
impl Display for ConsoleError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let msg = self
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let msg: &str = self
.status
.as_ref()
.and_then(|s| s.details.user_facing_message.as_ref())
.map(|m| m.message.as_ref())
.unwrap_or_else(|| &self.error);
write!(f, "{}", msg)
.map_or_else(|| self.error.as_ref(), |m| m.message.as_ref());
write!(f, "{msg}")
}
}
@@ -286,7 +284,7 @@ pub struct DatabaseInfo {
// Manually implement debug to omit sensitive info.
impl fmt::Debug for DatabaseInfo {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
@@ -373,7 +371,7 @@ mod tests {
}
}
});
let _: KickSession = serde_json::from_str(&json.to_string())?;
let _: KickSession<'_> = serde_json::from_str(&json.to_string())?;
Ok(())
}

View File
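The hunks above are a small readability refactor: chains of `.map(f).unwrap_or(default)` / `.unwrap_or_else(g)` become a single `.map_or(default, f)` / `.map_or_else(g, f)`. Behaviour is unchanged; presumably this satisfies a lint such as `clippy::map_unwrap_or`. A tiny equivalence check:

```rust
fn main() {
    let some: Option<u32> = Some(41);
    let none: Option<u32> = None;

    // .map(f).unwrap_or(default) and .map_or(default, f) agree for the Some case...
    assert_eq!(some.map(|x| x + 1).unwrap_or(0), some.map_or(0, |x| x + 1));
    // ...and for the None case, where the default is used.
    assert_eq!(none.map(|x| x + 1).unwrap_or(0), none.map_or(0, |x| x + 1));
}
```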

@@ -93,7 +93,8 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
}
fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
let resp: KickSession<'_> =
serde_json::from_str(query).context("Failed to parse query as json")?;
let span = info_span!("event", session_id = resp.session_id);
let _enter = span.enter();

View File

@@ -26,7 +26,7 @@ use tracing::info;
pub mod errors {
use crate::{
console::messages::{self, ConsoleError, Reason},
error::{io_error, ReportableError, UserFacingError},
error::{io_error, ErrorKind, ReportableError, UserFacingError},
proxy::retry::CouldRetry,
};
use thiserror::Error;
@@ -51,21 +51,19 @@ pub mod errors {
impl ApiError {
/// Returns HTTP status code if it's the reason for failure.
pub fn get_reason(&self) -> messages::Reason {
use ApiError::*;
match self {
Console(e) => e.get_reason(),
_ => messages::Reason::Unknown,
ApiError::Console(e) => e.get_reason(),
ApiError::Transport(_) => messages::Reason::Unknown,
}
}
}
impl UserFacingError for ApiError {
fn to_string_client(&self) -> String {
use ApiError::*;
match self {
// To minimize risks, only select errors are forwarded to users.
Console(c) => c.get_user_facing_message(),
_ => REQUEST_FAILED.to_owned(),
ApiError::Console(c) => c.get_user_facing_message(),
ApiError::Transport(_) => REQUEST_FAILED.to_owned(),
}
}
}
@@ -73,57 +71,53 @@ pub mod errors {
impl ReportableError for ApiError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console(e) => {
use crate::error::ErrorKind::*;
match e.get_reason() {
Reason::RoleProtected => User,
Reason::ResourceNotFound => User,
Reason::ProjectNotFound => User,
Reason::EndpointNotFound => User,
Reason::BranchNotFound => User,
Reason::RateLimitExceeded => ServiceRateLimit,
Reason::NonDefaultBranchComputeTimeExceeded => User,
Reason::ActiveTimeQuotaExceeded => User,
Reason::ComputeTimeQuotaExceeded => User,
Reason::WrittenDataQuotaExceeded => User,
Reason::DataTransferQuotaExceeded => User,
Reason::LogicalSizeQuotaExceeded => User,
Reason::ConcurrencyLimitReached => ControlPlane,
Reason::LockAlreadyTaken => ControlPlane,
Reason::RunningOperations => ControlPlane,
Reason::Unknown => match &e {
ConsoleError {
http_status_code:
http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
..
} => crate::error::ErrorKind::User,
ConsoleError {
http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
error,
..
} if error.contains(
"compute time quota of non-primary branches is exceeded",
) =>
{
crate::error::ErrorKind::User
}
ConsoleError {
http_status_code: http::StatusCode::LOCKED,
error,
..
} if error.contains("quota exceeded")
|| error.contains("the limit for current plan reached") =>
{
crate::error::ErrorKind::User
}
ConsoleError {
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
..
} => crate::error::ErrorKind::ServiceRateLimit,
ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
},
}
}
ApiError::Console(e) => match e.get_reason() {
Reason::RoleProtected => ErrorKind::User,
Reason::ResourceNotFound => ErrorKind::User,
Reason::ProjectNotFound => ErrorKind::User,
Reason::EndpointNotFound => ErrorKind::User,
Reason::BranchNotFound => ErrorKind::User,
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User,
Reason::ActiveTimeQuotaExceeded => ErrorKind::User,
Reason::ComputeTimeQuotaExceeded => ErrorKind::User,
Reason::WrittenDataQuotaExceeded => ErrorKind::User,
Reason::DataTransferQuotaExceeded => ErrorKind::User,
Reason::LogicalSizeQuotaExceeded => ErrorKind::User,
Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
Reason::RunningOperations => ErrorKind::ControlPlane,
Reason::Unknown => match &e {
ConsoleError {
http_status_code:
http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
..
} => crate::error::ErrorKind::User,
ConsoleError {
http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
error,
..
} if error
.contains("compute time quota of non-primary branches is exceeded") =>
{
crate::error::ErrorKind::User
}
ConsoleError {
http_status_code: http::StatusCode::LOCKED,
error,
..
} if error.contains("quota exceeded")
|| error.contains("the limit for current plan reached") =>
{
crate::error::ErrorKind::User
}
ConsoleError {
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
..
} => crate::error::ErrorKind::ServiceRateLimit,
ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
},
},
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
@@ -170,12 +164,11 @@ pub mod errors {
impl UserFacingError for GetAuthInfoError {
fn to_string_client(&self) -> String {
use GetAuthInfoError::*;
match self {
// We absolutely should not leak any secrets!
BadSecret => REQUEST_FAILED.to_owned(),
Self::BadSecret => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
Self::ApiError(e) => e.to_string_client(),
}
}
}
@@ -183,8 +176,8 @@ pub mod errors {
impl ReportableError for GetAuthInfoError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
@@ -213,17 +206,16 @@ pub mod errors {
impl UserFacingError for WakeComputeError {
fn to_string_client(&self) -> String {
use WakeComputeError::*;
match self {
// We shouldn't show user the address even if it's broken.
// Besides, user is unlikely to care about this detail.
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
Self::BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
Self::ApiError(e) => e.to_string_client(),
TooManyConnections => self.to_string(),
Self::TooManyConnections => self.to_string(),
TooManyConnectionAttempts(_) => {
Self::TooManyConnectionAttempts(_) => {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
}
@@ -233,10 +225,10 @@ pub mod errors {
impl ReportableError for WakeComputeError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_kind(),
WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit,
WakeComputeError::TooManyConnectionAttempts(e) => e.get_error_kind(),
Self::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
Self::ApiError(e) => e.get_error_kind(),
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
Self::TooManyConnectionAttempts(e) => e.get_error_kind(),
}
}
}
@@ -244,10 +236,10 @@ pub mod errors {
impl CouldRetry for WakeComputeError {
fn could_retry(&self) -> bool {
match self {
WakeComputeError::BadComputeAddress(_) => false,
WakeComputeError::ApiError(e) => e.could_retry(),
WakeComputeError::TooManyConnections => false,
WakeComputeError::TooManyConnectionAttempts(_) => false,
Self::BadComputeAddress(_) => false,
Self::ApiError(e) => e.could_retry(),
Self::TooManyConnections => false,
Self::TooManyConnectionAttempts(_) => false,
}
}
}
@@ -366,13 +358,14 @@ impl Api for ConsoleBackend {
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
use ConsoleBackend::*;
match self {
Console(api) => api.get_role_secret(ctx, user_info).await,
Self::Console(api) => api.get_role_secret(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Postgres(api) => api.get_role_secret(ctx, user_info).await,
Self::Postgres(api) => api.get_role_secret(ctx, user_info).await,
#[cfg(test)]
Test(_) => unreachable!("this function should never be called in the test backend"),
Self::Test(_) => {
unreachable!("this function should never be called in the test backend")
}
}
}
@@ -381,13 +374,12 @@ impl Api for ConsoleBackend {
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
use ConsoleBackend::*;
match self {
Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Self::Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Self::Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
#[cfg(test)]
Test(api) => api.get_allowed_ips_and_secret(),
Self::Test(api) => api.get_allowed_ips_and_secret(),
}
}
@@ -396,14 +388,12 @@ impl Api for ConsoleBackend {
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
use ConsoleBackend::*;
match self {
Console(api) => api.wake_compute(ctx, user_info).await,
Self::Console(api) => api.wake_compute(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Postgres(api) => api.wake_compute(ctx, user_info).await,
Self::Postgres(api) => api.wake_compute(ctx, user_info).await,
#[cfg(test)]
Test(api) => api.wake_compute(),
Self::Test(api) => api.wake_compute(),
}
}
}
@@ -549,7 +539,7 @@ impl WakeComputePermit {
!self.permit.is_disabled()
}
pub fn release(self, outcome: Outcome) {
self.permit.release(outcome)
self.permit.release(outcome);
}
pub fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
match res {

View File

@@ -166,7 +166,7 @@ impl RequestMonitoring {
pub fn set_project(&self, x: MetricsAuxInfo) {
let mut this = self.0.try_lock().expect("should not deadlock");
if this.endpoint_id.is_none() {
this.set_endpoint_id(x.endpoint_id.as_str().into())
this.set_endpoint_id(x.endpoint_id.as_str().into());
}
this.branch = Some(x.branch_id);
this.project = Some(x.project_id);
@@ -260,7 +260,7 @@ impl RequestMonitoring {
.cold_start_info
}
pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause {
pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
LatencyTimerPause {
ctx: self,
start: tokio::time::Instant::now(),
@@ -273,7 +273,7 @@ impl RequestMonitoring {
.try_lock()
.expect("should not deadlock")
.latency_timer
.success()
.success();
}
}
@@ -328,7 +328,7 @@ impl RequestMonitoringInner {
fn has_private_peer_addr(&self) -> bool {
match self.peer_addr {
IpAddr::V4(ip) => ip.is_private(),
_ => false,
IpAddr::V6(_) => false,
}
}

View File

@@ -736,7 +736,7 @@ mod tests {
while let Some(r) = s.next().await {
tx.send(r).unwrap();
}
time::sleep(time::Duration::from_secs(70)).await
time::sleep(time::Duration::from_secs(70)).await;
}
});

View File

@@ -56,7 +56,7 @@ impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
type Value = InternedString<Id>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("a string")
}

View File

@@ -252,7 +252,7 @@ impl Drop for HttpEndpointPoolsGuard<'_> {
}
impl HttpEndpointPools {
pub fn guard(&self) -> HttpEndpointPoolsGuard {
pub fn guard(&self) -> HttpEndpointPoolsGuard<'_> {
self.http_pool_endpoints_registered_total.inc();
HttpEndpointPoolsGuard {
dec: &self.http_pool_endpoints_unregistered_total,

View File

@@ -3,6 +3,7 @@
use std::{
io,
net::SocketAddr,
os::fd::AsRawFd,
pin::Pin,
task::{Context, Poll},
};
@@ -20,6 +21,23 @@ pin_project! {
}
}
impl<S: AsRawFd> AsRawFd for ChainRW<S> {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
self.inner.as_raw_fd()
}
}
#[cfg(all(target_os = "linux", not(test)))]
impl<S: ktls::AsyncReadReady> ktls::AsyncReadReady for ChainRW<S> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if self.buf.is_empty() {
self.inner.poll_read_ready(cx)
} else {
Poll::Ready(Ok(()))
}
}
}
impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
#[inline]
fn poll_write(

View File

@@ -1,5 +1,5 @@
#[cfg(test)]
mod tests;
pub mod tests;
pub mod connect_compute;
mod copy_bidirectional;
@@ -9,6 +9,7 @@ pub mod retry;
pub mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use handshake::KtlsAsyncReadReady;
use crate::{
auth,
@@ -21,7 +22,7 @@ use crate::{
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
stream::PqStream,
EndpointCacheKey,
};
use futures::TryFutureExt;
@@ -30,6 +31,7 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::os::fd::AsRawFd;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -191,13 +193,6 @@ impl ClientMode {
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
@@ -238,7 +233,7 @@ impl ReportableError for ClientRequestError {
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -261,9 +256,9 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) =
let (mut stream, ep, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(stream, ep, params) => (stream, ep, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
@@ -275,15 +270,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, ep))
.transpose();
let user_info = match result {

View File

@@ -184,7 +184,7 @@ impl CopyBuffer {
}
Poll::Pending
}
res => res.map_err(ErrorDirection::Write),
res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write),
}
}

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use bytes::Buf;
use pq_proto::{
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
@@ -15,6 +17,7 @@ use crate::{
metrics::Metrics,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
EndpointId,
};
#[derive(Error, Debug)]
@@ -31,6 +34,10 @@ pub enum HandshakeError {
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
#[cfg(all(target_os = "linux", not(test)))]
#[error("{0}")]
KtlsUpgradeError(#[from] ktls::Error),
#[error("{0}")]
Io(#[from] std::io::Error),
@@ -43,6 +50,8 @@ impl ReportableError for HandshakeError {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
#[cfg(all(target_os = "linux", not(test)))]
HandshakeError::KtlsUpgradeError(_) => crate::error::ErrorKind::Service,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
@@ -57,22 +66,39 @@ impl ReportableError for HandshakeError {
}
}
pub enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
pub enum HandshakeData<S: AsRawFd> {
Startup(
PqStream<Stream<S>>,
Option<EndpointId>,
StartupMessageParams,
),
Cancel(CancelKeyData),
}
#[cfg(any(not(target_os = "linux"), test))]
pub trait KtlsAsyncReadReady {}
#[cfg(all(target_os = "linux", not(test)))]
pub trait KtlsAsyncReadReady: ktls::AsyncReadReady {}
#[cfg(any(not(target_os = "linux"), test))]
impl<K: AsyncRead> KtlsAsyncReadReady for K {}
#[cfg(all(target_os = "linux", not(test)))]
impl<K: ktls::AsyncReadReady> KtlsAsyncReadReady for K {}
/// Establish a (most probably secure) connection with the client.
/// For a better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with an owned `stream` here, as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
pub async fn handshake<S>(
ctx: &RequestMonitoring,
stream: S,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
) -> Result<HandshakeData<S>, HandshakeError>
where
S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady,
{
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -80,11 +106,11 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
let mut ep = None;
loop {
let msg = stream.read_startup_packet().await?;
use FeStartupPacket::*;
match msg {
SslRequest { direct } => match stream.get_ref() {
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
@@ -114,6 +140,9 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
));
};
#[cfg(all(target_os = "linux", not(test)))]
let raw = ktls::CorkStream::new(raw);
let mut read_buf = read_buf.reader();
let mut res = Ok(());
let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config())
@@ -139,18 +168,18 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let tls_stream = accept.await.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc()
Metrics::get().proxy.tls_handshake_failures.inc();
}
})?;
let conn_info = tls_stream.get_ref().1;
// try to parse the endpoint from SNI
let ep = conn_info
ep = conn_info
.server_name()
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
if let Some(ep) = ep {
ctx.set_endpoint_id(ep);
if let Some(ep) = &ep {
ctx.set_endpoint_id(ep.clone());
}
// check the ALPN, if exists, as required.
@@ -171,7 +200,10 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
#[cfg(any(not(target_os = "linux"), test))]
tls: Box::pin(tls_stream),
#[cfg(all(target_os = "linux", not(test)))]
tls: ktls::config_ktls_server(tls_stream).await?,
tls_server_end_point,
},
read_buf,
@@ -182,7 +214,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
_ => return Err(HandshakeError::ProtocolViolation),
},
GssEncRequest => match stream.get_ref() {
FeStartupPacket::GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
@@ -191,7 +223,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
_ => return Err(HandshakeError::ProtocolViolation),
},
StartupMessage { params, version }
FeStartupPacket::StartupMessage { params, version }
if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
{
// Check that the config has been consumed during upgrade
@@ -208,10 +240,10 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, ep, params));
}
// downgrade protocol version
StartupMessage { params, version }
FeStartupPacket::StartupMessage { params, version }
if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
{
warn!(?version, "unsupported minor version");
@@ -239,9 +271,9 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake; unsupported minor version requested"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, ep, params));
}
StartupMessage { version, .. } => {
FeStartupPacket::StartupMessage { version, .. } => {
warn!(
?version,
session_type = "normal",
@@ -249,7 +281,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
);
return Err(HandshakeError::ProtocolViolation);
}
CancelRequest(cancel_key_data) => {
FeStartupPacket::CancelRequest(cancel_key_data) => {
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data));
}

View File
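The `KtlsAsyncReadReady` bound introduced in the handshake above is a cfg-gated trait alias: on Linux (outside tests) it demands the real `ktls::AsyncReadReady` capability, while elsewhere it is a bound that any `AsyncRead` stream satisfies, so callers can keep a single generic signature. The following is a minimal, self-contained sketch of that pattern; the names are illustrative, it has no dependency on the real `ktls` or tokio APIs, and it drops the `cfg(test)` carve-out and the `AsyncRead` constraint for brevity.

// Sketch of the cfg-gated trait-alias pattern; names are illustrative.
#[cfg(target_os = "linux")]
pub trait ReadReady {
    fn read_ready(&self) -> bool;
}

// On Linux the alias demands the real capability...
#[cfg(target_os = "linux")]
pub trait MaybeReadReady: ReadReady {}
#[cfg(target_os = "linux")]
impl<T: ReadReady> MaybeReadReady for T {}

// ...elsewhere it is an empty marker satisfied by every type, so generic code
// can require `S: MaybeReadReady` unconditionally on all platforms.
#[cfg(not(target_os = "linux"))]
pub trait MaybeReadReady {}
#[cfg(not(target_os = "linux"))]
impl<T> MaybeReadReady for T {}

struct DummyStream;

#[cfg(target_os = "linux")]
impl ReadReady for DummyStream {
    fn read_ready(&self) -> bool {
        true
    }
}

// One signature; no per-platform duplication at the call sites.
fn do_handshake<S: MaybeReadReady>(_stream: S) {}

fn main() {
    do_handshake(DummyStream);
}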

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use crate::{
cancellation,
compute::PostgresConnection,
@@ -10,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use super::{copy_bidirectional::ErrorSource, handshake::KtlsAsyncReadReady};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
@@ -57,7 +59,7 @@ pub async fn proxy_pass(
Ok(())
}
pub struct ProxyPassthrough<P, S> {
pub struct ProxyPassthrough<P, S: AsRawFd> {
pub client: Stream<S>,
pub compute: PostgresConnection,
pub aux: MetricsAuxInfo,
@@ -67,7 +69,7 @@ pub struct ProxyPassthrough<P, S> {
pub cancel: cancellation::Session<P>,
}
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
impl<P, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> ProxyPassthrough<P, S> {
pub async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {

View File

@@ -2,6 +2,8 @@
mod mitm;
use std::pin::Pin;
use std::task::Poll;
use std::time::Duration;
use super::connect_compute::ConnectMechanism;
@@ -16,12 +18,14 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::stream::Stream;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
@@ -35,28 +39,73 @@ fn generate_certs(
pki_types::CertificateDer<'static>,
pki_types::PrivateKeyDer<'static>,
)> {
let ca = rcgen::Certificate::from_params({
let ca_key = rcgen::KeyPair::generate()?;
let cert_key = rcgen::KeyPair::generate()?;
let ca = {
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
params
})?;
params.self_signed(&ca_key)?
};
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
let cert = {
let mut params = rcgen::CertificateParams::new(vec![hostname.into()])?;
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;
params.signed_by(&cert_key, &ca, &ca_key)?
};
Ok((
pki_types::CertificateDer::from(ca.serialize_der()?),
pki_types::CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
pki_types::PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
ca.into(),
cert.into(),
pki_types::PrivateKeyDer::Pkcs8(cert_key.serialize_der().into()),
))
}
pub struct DummyClient(pub DuplexStream);
impl AsRawFd for DummyClient {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
unreachable!()
}
}
impl AsyncWrite for DummyClient {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
impl AsyncRead for DummyClient {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
struct ClientConfig<'a> {
config: rustls::ClientConfig,
hostname: &'a str,
@@ -121,7 +170,9 @@ fn generate_tls_config<'a>(
#[async_trait]
trait TestAuth: Sized {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
async fn authenticate<
S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady,
>(
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
@@ -150,7 +201,9 @@ impl Scram {
#[async_trait]
impl TestAuth for Scram {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
async fn authenticate<
S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady,
>(
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
@@ -170,14 +223,14 @@ impl TestAuth for Scram {
/// A dummy proxy impl which performs a handshake and reports auth success.
async fn dummy_proxy(
client: impl AsyncRead + AsyncWrite + Unpin + Send,
client: impl AsyncRead + AsyncWrite + Unpin + Send + AsRawFd,
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let mut stream =
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Startup(stream, ..) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
@@ -196,7 +249,11 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
NoAuth,
));
let client_err = tokio_postgres::Config::new()
.user("john_doe")
@@ -225,7 +282,11 @@ async fn handshake_tls() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
NoAuth,
));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
@@ -241,7 +302,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
async fn handshake_raw() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let proxy = tokio::spawn(dummy_proxy(DummyClient(client), None, NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
@@ -285,7 +346,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new(password).await?,
));
@@ -309,7 +370,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));
@@ -332,7 +393,11 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
Scram::mock(),
));
use rand::{distributions::Alphanumeric, Rng};
let password: String = rand::thread_rng()

View File

@@ -36,14 +36,14 @@ async fn proxy_mitm(
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
let (end_client, startup) = match handshake(
&RequestMonitoring::test(),
client1,
DummyClient(client1),
Some(&server_config1),
false,
)
.await
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(stream, _ep, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};
@@ -68,7 +68,7 @@ async fn proxy_mitm(
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
continue;
}
end_client.send(message).await.unwrap()
end_client.send(message).await.unwrap();
}
_ => break,
}
@@ -88,7 +88,7 @@ async fn proxy_mitm(
end_server.send(buf.freeze()).await.unwrap();
continue;
}
end_server.send(message).await.unwrap()
end_server.send(message).await.unwrap();
}
_ => break,
}
@@ -154,7 +154,7 @@ impl Encoder<Bytes> for PgFrame {
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));
@@ -237,7 +237,7 @@ async fn connect_failure(
) -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));

View File

@@ -237,7 +237,7 @@ impl Token {
}
pub fn release(mut self, outcome: Outcome) {
self.release_mut(Some(outcome))
self.release_mut(Some(outcome));
}
pub fn release_mut(&mut self, outcome: Option<Outcome>) {
@@ -249,7 +249,7 @@ impl Token {
impl Drop for Token {
fn drop(&mut self) {
self.release_mut(None)
self.release_mut(None);
}
}

View File

@@ -25,9 +25,8 @@ pub struct Aimd {
impl LimitAlgorithm for Aimd {
fn update(&self, old_limit: usize, sample: Sample) -> usize {
use Outcome::*;
match sample.outcome {
Success => {
Outcome::Success => {
let utilisation = sample.in_flight as f32 / old_limit as f32;
if utilisation > self.utilisation {
@@ -42,7 +41,7 @@ impl LimitAlgorithm for Aimd {
old_limit
}
}
Overload => {
Outcome::Overload => {
let limit = old_limit as f32 * self.dec;
// Floor instead of round, so the limit reduces even with small numbers.

View File
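The `Aimd` hunk above only shows the lines touched by the fully-qualified-variant change, so here is a consolidated, self-contained sketch of the additive-increase/multiplicative-decrease rule it implements. Field names and the exact increase and clamping details are illustrative, not the crate's `Aimd` definition; the visible hunk only confirms the utilisation check and the floored multiplicative decrease.

// Standalone sketch of the AIMD update rule; fields and constants are illustrative.
struct Aimd {
    inc: usize,       // additive increase on success
    dec: f32,         // multiplicative decrease factor on overload
    utilisation: f32, // only grow when at least this fraction of the limit is in use
    min: usize,
    max: usize,
}

enum Outcome {
    Success,
    Overload,
}

struct Sample {
    in_flight: usize,
    outcome: Outcome,
}

impl Aimd {
    fn update(&self, old_limit: usize, sample: Sample) -> usize {
        match sample.outcome {
            Outcome::Success => {
                let utilisation = sample.in_flight as f32 / old_limit as f32;
                if utilisation > self.utilisation {
                    (old_limit + self.inc).min(self.max)
                } else {
                    old_limit
                }
            }
            Outcome::Overload => {
                // Floor so the limit shrinks even for small values.
                let limit = (old_limit as f32 * self.dec).floor() as usize;
                limit.max(self.min)
            }
        }
    }
}

fn main() {
    let aimd = Aimd { inc: 10, dec: 0.5, utilisation: 0.8, min: 1, max: 100 };
    let mut limit = 20;
    limit = aimd.update(limit, Sample { in_flight: 19, outcome: Outcome::Success }); // 30
    limit = aimd.update(limit, Sample { in_flight: 30, outcome: Outcome::Overload }); // 15
    println!("limit = {limit}");
}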

@@ -98,7 +98,7 @@ impl ConnectionWithCredentialsProvider {
info!("Establishing a new connection...");
self.con = None;
if let Some(f) = self.refresh_token_task.take() {
f.abort()
f.abort();
}
let mut con = self
.get_client()

View File

@@ -108,7 +108,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
use Notification::*;
let payload: String = msg.get_payload()?;
tracing::debug!(?payload, "received a message payload");
@@ -124,7 +123,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
};
tracing::debug!(?msg, "received a message");
match msg {
Cancel(cancel_session) => {
Notification::Cancel(cancel_session) => {
tracing::Span::current().record(
"session_id",
tracing::field::display(cancel_session.session_id),
@@ -153,12 +152,12 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
_ => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, AllowedIpsUpdate { .. }) {
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::AllowedIpsUpdate);
} else if matches!(msg, PasswordUpdate { .. }) {
} else if matches!(msg, Notification::PasswordUpdate { .. }) {
Metrics::get()
.proxy
.redis_events_count
@@ -180,16 +179,16 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
use Notification::*;
match msg {
AllowedIpsUpdate { allowed_ips_update } => {
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
Notification::AllowedIpsUpdate { allowed_ips_update } => {
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
}
PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
password_update.project_id,
password_update.role_name,
),
Cancel(_) => unreachable!("cancel message should be handled separately"),
Notification::PasswordUpdate { password_update } => cache
.invalidate_role_secret_for_project(
password_update.project_id,
password_update.role_name,
),
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
}
}

View File

@@ -42,10 +42,9 @@ pub enum Error {
impl UserFacingError for Error {
fn to_string_client(&self) -> String {
use Error::*;
match self {
ChannelBindingFailed(m) => m.to_string(),
ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
Self::ChannelBindingFailed(m) => (*m).to_string(),
Self::ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
_ => "authentication protocol violation".to_string(),
}
}

View File

@@ -13,11 +13,10 @@ pub enum ChannelBinding<T> {
impl<T> ChannelBinding<T> {
pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
use ChannelBinding::*;
Ok(match self {
NotSupportedClient => NotSupportedClient,
NotSupportedServer => NotSupportedServer,
Required(x) => Required(f(x)?),
Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
Self::Required(x) => ChannelBinding::Required(f(x)?),
})
}
}
@@ -25,11 +24,10 @@ impl<T> ChannelBinding<T> {
impl<'a> ChannelBinding<&'a str> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
use ChannelBinding::*;
Some(match input {
"n" => NotSupportedClient,
"y" => NotSupportedServer,
other => Required(other.strip_prefix("p=")?),
"n" => Self::NotSupportedClient,
"y" => Self::NotSupportedServer,
other => Self::Required(other.strip_prefix("p=")?),
})
}
}
@@ -40,17 +38,16 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
&self,
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
) -> Result<std::borrow::Cow<'static, str>, E> {
use ChannelBinding::*;
Ok(match self {
NotSupportedClient => {
Self::NotSupportedClient => {
// base64::encode("n,,")
"biws".into()
}
NotSupportedServer => {
Self::NotSupportedServer => {
// base64::encode("y,,")
"eSws".into()
}
Required(mode) => {
Self::Required(mode) => {
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();

View File

@@ -42,10 +42,9 @@ pub(super) enum ServerMessage<T> {
impl<'a> ServerMessage<&'a str> {
pub(super) fn to_reply(&self) -> BeMessage<'a> {
use BeAuthenticationSaslMessage::*;
BeMessage::AuthenticationSasl(match self {
ServerMessage::Continue(s) => Continue(s.as_bytes()),
ServerMessage::Final(s) => Final(s.as_bytes()),
ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()),
ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()),
})
}
}

View File

@@ -137,12 +137,12 @@ mod tests {
#[tokio::test]
async fn round_trip() {
run_round_trip_test("pencil", "pencil").await
run_round_trip_test("pencil", "pencil").await;
}
#[tokio::test]
#[should_panic(expected = "password doesn't match")]
async fn failure() {
run_round_trip_test("pencil", "eraser").await
run_round_trip_test("pencil", "eraser").await;
}
}

View File

@@ -98,8 +98,6 @@ mod tests {
// q% of counts will be within p of the actual value
let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
dbg!(sketch.buckets.len());
// insert a bunch of entries in a random order
let mut ids2 = ids.clone();
while !ids2.is_empty() {

View File

@@ -210,23 +210,23 @@ impl sasl::Mechanism for Exchange<'_> {
type Output = super::ScramKey;
fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
use {sasl::Step::*, ExchangeState::*};
use {sasl::Step, ExchangeState};
match &self.state {
Initial(init) => {
ExchangeState::Initial(init) => {
match init.transition(self.secret, &self.tls_server_end_point, input)? {
Continue(sent, msg) => {
self.state = SaltSent(sent);
Ok(Continue(self, msg))
Step::Continue(sent, msg) => {
self.state = ExchangeState::SaltSent(sent);
Ok(Step::Continue(self, msg))
}
Success(x, _) => match x {},
Failure(msg) => Ok(Failure(msg)),
Step::Success(x, _) => match x {},
Step::Failure(msg) => Ok(Step::Failure(msg)),
}
}
SaltSent(sent) => {
ExchangeState::SaltSent(sent) => {
match sent.transition(self.secret, &self.tls_server_end_point, input)? {
Success(keys, msg) => Ok(Success(keys, msg)),
Continue(x, _) => match x {},
Failure(msg) => Ok(Failure(msg)),
Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
Step::Continue(x, _) => match x {},
Step::Failure(msg) => Ok(Step::Failure(msg)),
}
}
}

View File

@@ -59,7 +59,7 @@ impl<'a> ClientFirstMessage<'a> {
// https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
if !username.is_empty() {
tracing::warn!(username, "scram username provided, but is not expected")
tracing::warn!(username, "scram username provided, but is not expected");
// TODO(conrad):
// return None;
}
@@ -137,7 +137,7 @@ impl<'a> ClientFinalMessage<'a> {
/// Build a response to [`ClientFinalMessage`].
pub fn build_server_final_message(
&self,
signature_builder: SignatureBuilder,
signature_builder: SignatureBuilder<'_>,
server_key: &ScramKey,
) -> String {
let mut buf = String::from("v=");
@@ -212,7 +212,7 @@ mod tests {
#[test]
fn parse_client_first_message_with_invalid_gs2_authz() {
assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none())
assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none());
}
#[test]

View File

@@ -84,6 +84,6 @@ mod tests {
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 600000);
assert_eq!(hash, expected)
assert_eq!(hash, expected);
}
}

View File

@@ -270,7 +270,7 @@ fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
.inc(ThreadPoolWorkerId(index));
// skip for now
worker.push(job)
worker.push(job);
}
}
@@ -316,6 +316,6 @@ mod tests {
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
];
assert_eq!(actual, expected)
assert_eq!(actual, expected);
}
}

View File

@@ -10,6 +10,7 @@ mod json;
mod sql_over_http;
mod websocket;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
@@ -26,8 +27,9 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::timeout;
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use crate::cancellation::CancellationHandlerMain;
@@ -41,7 +43,7 @@ use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
@@ -86,18 +88,18 @@ pub async fn task_main(
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_config = match config.tls_config.as_ref() {
Some(config) => config,
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
Some(config) => {
let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config());
// prefer http2, but support http/1.1
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(tls_server_config) as Arc<_>
}
None => {
warn!("TLS config is missing, WebSocket Secure server will not be started");
return Ok(());
warn!("TLS config is missing");
Arc::new(NoTls) as Arc<_>
}
};
let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
// prefer http2, but support http/1.1
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait` to complete
@@ -120,7 +122,7 @@ pub async fn task_main(
tracing::trace!("attempting to cancel a random connection");
if let Some(token) = config.http_config.cancel_set.take() {
tracing::debug!("cancelling a random connection");
token.cancel()
token.cancel();
}
}
@@ -176,16 +178,53 @@ pub async fn task_main(
Ok(())
}
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + 'static> AsyncReadWrite for T {}
pub type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
#[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
}
#[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
#[cfg(all(target_os = "linux", not(test)))]
let conn = ktls::CorkStream::new(conn);
let tls = TlsAcceptor::from(self).accept(conn).await?;
#[cfg(all(target_os = "linux", not(test)))]
return ktls::config_ktls_server(tls)
.await
.map(|s| Box::pin(s) as _)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
#[cfg(any(not(target_os = "linux"), test))]
Ok(Box::pin(tls))
}
}
struct NoTls;
#[async_trait]
impl MaybeTlsAcceptor for NoTls {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(conn))
}
}
/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
async fn connection_startup(
config: &ProxyConfig,
tls_acceptor: TlsAcceptor,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
) -> Option<(TlsStream<ChainRW<TcpStream>>, IpAddr)> {
) -> Option<(AsyncRW, IpAddr)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
@@ -198,7 +237,7 @@ async fn connection_startup(
let peer_addr = peer.unwrap_or(peer_addr).ip();
let has_private_peer_addr = match peer_addr {
IpAddr::V4(ip) => ip.is_private(),
_ => false,
IpAddr::V6(_) => false,
};
info!(?session_id, %peer_addr, "accepted new TCP connection");
@@ -241,7 +280,7 @@ async fn connection_handler(
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conn: TlsStream<ChainRW<TcpStream>>,
conn: AsyncRW,
peer_addr: IpAddr,
session_id: uuid::Uuid,
) {
@@ -326,7 +365,9 @@ async fn request_handler(
.map(|s| s.to_string());
// Check if the request is a websocket upgrade request.
if framed_websockets::upgrade::is_upgrade_request(&request) {
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestMonitoring::new(
session_id,
peer_addr,

View File
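The serverless change above replaces the concrete `TlsStream<ChainRW<TcpStream>>` with a `MaybeTlsAcceptor` trait object that yields a type-erased `Pin<Box<dyn AsyncReadWrite>>`, so plaintext and (k)TLS connections flow through the same handler. Below is a synchronous, std-only analogue of that type-erasure pattern; the names are illustrative, and the real code is async and performs an actual rustls/ktls handshake.

use std::io::{self, Read, Write};
use std::sync::Arc;

// Type-erased connection: downstream code only needs Read + Write.
pub trait ReadWrite: Read + Write {}
impl<T: Read + Write> ReadWrite for T {}
pub type Conn = Box<dyn ReadWrite>;

// "Maybe TLS": one acceptor trait, with plaintext and TLS as separate impls.
pub trait MaybeTlsAcceptor {
    fn accept(self: Arc<Self>, conn: Conn) -> io::Result<Conn>;
}

struct NoTls;

impl MaybeTlsAcceptor for NoTls {
    fn accept(self: Arc<Self>, conn: Conn) -> io::Result<Conn> {
        // Plaintext: hand the stream straight back. A TLS acceptor would be a
        // second impl that performs the handshake and boxes the wrapped stream.
        Ok(conn)
    }
}

// The connection handler no longer cares whether TLS is configured.
fn handle(acceptor: Arc<dyn MaybeTlsAcceptor>, conn: Conn) -> io::Result<()> {
    let mut conn = acceptor.accept(conn)?;
    conn.write_all(b"hello")
}

fn main() -> io::Result<()> {
    let acceptor: Arc<dyn MaybeTlsAcceptor> = Arc::new(NoTls);
    handle(acceptor, Box::new(io::Cursor::new(Vec::<u8>::new())))
}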

@@ -390,7 +390,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
.write()
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn)
client = Some(entry.conn);
}
let endpoint_pool = Arc::downgrade(&endpoint_pool);
@@ -662,13 +662,13 @@ impl<C: ClientInnerExt> Discard<'_, C> {
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
@@ -758,6 +758,7 @@ mod tests {
async fn test_pool() {
let _ = env_logger::try_init();
let config = Box::leak(Box::new(crate::config::HttpConfig {
accept_websockets: false,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: 2,
gc_epoch: Duration::from_secs(1),

View File

@@ -147,7 +147,7 @@ impl UserFacingError for ConnInfoError {
fn get_conn_info(
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: &TlsConfig,
tls: Option<&TlsConfig>,
) -> Result<ConnInfo, ConnInfoError> {
// HTTP only uses cleartext (for now and likely always)
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
@@ -184,12 +184,22 @@ fn get_conn_info(
.ok_or(ConnInfoError::MissingPassword)?;
let password = urlencoding::decode_binary(password.as_bytes());
let hostname = connection_url
.host_str()
.ok_or(ConnInfoError::MissingHostname)?;
let endpoint =
endpoint_sni(hostname, &tls.common_names)?.ok_or(ConnInfoError::MalformedEndpoint)?;
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
endpoint_sni(hostname, &tls.common_names)?
.ok_or(ConnInfoError::MalformedEndpoint)?
} else {
hostname
.split_once(".")
.map_or(hostname, |(prefix, _)| prefix)
.into()
}
}
Some(url::Host::Ipv4(_)) | Some(url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname)
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
@@ -502,7 +512,7 @@ async fn handle_inner(
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref().unwrap())?;
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
info!(user = conn_info.user_info.user.as_str(), "credentials");
// Allow connection pooling only if explicitly requested

View File
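The `get_conn_info` change above makes the TLS config optional: with TLS present the endpoint still comes from SNI matching against the configured common names, otherwise it falls back to the first DNS label of the hostname. A small, std-only illustration of that fallback follows; the helper name is hypothetical.

// Hypothetical helper mirroring the no-TLS fallback: the endpoint is the first
// DNS label of the hostname taken from the connection URL.
fn endpoint_from_hostname(hostname: &str) -> &str {
    hostname
        .split_once('.')
        .map_or(hostname, |(prefix, _)| prefix)
}

fn main() {
    assert_eq!(
        endpoint_from_hostname("my-endpoint.example.com"),
        "my-endpoint"
    );
    // No dot at all: the whole hostname is used as the endpoint.
    assert_eq!(endpoint_from_hostname("localhost"), "localhost");
    println!("ok");
}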

@@ -16,6 +16,7 @@ use hyper1::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use std::os::fd::AsRawFd;
use std::{
pin::Pin,
sync::Arc,
@@ -45,6 +46,18 @@ impl<S> WebSocketRw<S> {
}
}
impl<S> AsRawFd for WebSocketRw<S> {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
unreachable!("ktls should not need to be used for websocket rw")
}
}
#[cfg(all(target_os = "linux", not(test)))]
impl<S> ktls::AsyncReadReady for WebSocketRw<S> {
fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!("ktls should not need to be used for websocket rw")
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
fn poll_write(
self: Pin<&mut Self>,

View File

@@ -1,11 +1,13 @@
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::proxy::handshake::KtlsAsyncReadReady;
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use std::os::fd::AsRawFd;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
@@ -172,34 +174,31 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
}
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
pub enum Stream<S: AsRawFd> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
tls: Box<TlsStream<S>>,
#[cfg(any(not(target_os = "linux"), test))]
tls: Pin<Box<TlsStream<S>>>,
#[cfg(all(target_os = "linux", not(test)))]
tls: ktls::KtlsStream<S>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
}
impl<S: Unpin> Unpin for Stream<S> {}
impl<S: Unpin + AsRawFd> Unpin for Stream<S> {}
impl<S> Stream<S> {
impl<S: AsRawFd> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
/// Return SNI hostname when it's available.
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
@@ -221,7 +220,7 @@ pub enum StreamUpgradeError {
Io(#[from] io::Error),
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(
self,
@@ -234,7 +233,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
.await
.inspect_err(|_| {
if record_handshake_error {
Metrics::get().proxy.tls_handshake_failures.inc()
Metrics::get().proxy.tls_handshake_failures.inc();
}
})?),
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
@@ -242,7 +241,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AsyncRead for Stream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
@@ -255,7 +254,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AsyncWrite for Stream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,

View File

@@ -12,7 +12,7 @@ impl ApiUrl {
}
/// See [`url::Url::path_segments_mut`].
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut {
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut<'_> {
// We've already verified that it works during construction.
self.0.path_segments_mut().expect("bad API url")
}

View File

@@ -36,7 +36,7 @@ impl<T> Default for Waiters<T> {
}
impl<T> Waiters<T> {
pub fn register(&self, key: String) -> Result<Waiter<T>, RegisterError> {
pub fn register(&self, key: String) -> Result<Waiter<'_, T>, RegisterError> {
let (tx, rx) = oneshot::channel();
self.0

View File

@@ -44,7 +44,7 @@ run the following commands from the top of the neon.git checkout
# test suite run
export TEST_OUTPUT="$TEST_OUTPUT"
DEFAULT_PG_VERSION=15 BUILD_TYPE=release ./scripts/pytest test_runner/performance/test_latency.py
DEFAULT_PG_VERSION=16 BUILD_TYPE=release ./scripts/pytest test_runner/performance/test_latency.py
# for interactive use
export NEON_REPO_DIR="$NEON_REPO_DIR"

View File

@@ -87,9 +87,12 @@ impl Heartbeater {
pageservers,
reply: sender,
})
.unwrap();
.map_err(|_| HeartbeaterError::Cancel)?;
receiver.await.unwrap()
receiver
.await
.map_err(|_| HeartbeaterError::Cancel)
.and_then(|x| x)
}
}

View File
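The heartbeater change above replaces `unwrap()` with graceful shutdown handling: if the task is already cancelled, both the send and the reply channel fail, each failure is mapped to a cancel error, and `.and_then(|x| x)` flattens the nested `Result`. A std-only sketch of that flattening idiom, using `std::sync::mpsc` instead of tokio's oneshot and illustrative type names:

use std::sync::mpsc;

#[derive(Debug)]
enum HeartbeatError {
    Cancel,
}

// The worker replies with its own Result; receiving can also fail if the
// worker has already shut down, giving Result<Result<_, _>, RecvError>.
fn request_heartbeat(
    rx: mpsc::Receiver<Result<u32, HeartbeatError>>,
) -> Result<u32, HeartbeatError> {
    rx.recv()
        .map_err(|_| HeartbeatError::Cancel) // channel closed => treat as cancelled
        .and_then(|inner| inner) // flatten Result<Result<u32, _>, _>
}

fn main() {
    // Worker already gone: the receive error becomes Cancel instead of a panic.
    let (tx, rx) = mpsc::channel();
    drop(tx);
    assert!(matches!(request_heartbeat(rx), Err(HeartbeatError::Cancel)));

    // Worker replies successfully: the inner Result is passed through.
    let (tx, rx) = mpsc::channel();
    tx.send(Ok(42)).unwrap();
    assert_eq!(request_heartbeat(rx).unwrap(), 42);
    println!("ok");
}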

@@ -0,0 +1,135 @@
use std::sync::Arc;
use hyper::Uri;
use tokio_util::sync::CancellationToken;
use crate::{
peer_client::{GlobalObservedState, PeerClient},
persistence::{ControllerPersistence, DatabaseError, DatabaseResult, Persistence},
service::Config,
};
/// Helper for storage controller leadership acquisition
pub(crate) struct Leadership {
persistence: Arc<Persistence>,
config: Config,
cancel: CancellationToken,
}
#[derive(thiserror::Error, Debug)]
pub(crate) enum Error {
#[error(transparent)]
Database(#[from] DatabaseError),
}
pub(crate) type Result<T> = std::result::Result<T, Error>;
impl Leadership {
pub(crate) fn new(
persistence: Arc<Persistence>,
config: Config,
cancel: CancellationToken,
) -> Self {
Self {
persistence,
config,
cancel,
}
}
/// Find the current leader in the database and request it to step down if required.
/// Should be called early in the start-up sequence.
///
/// Returns a tuple of two optionals: the current leader and its observed state
pub(crate) async fn step_down_current_leader(
&self,
) -> Result<(Option<ControllerPersistence>, Option<GlobalObservedState>)> {
let leader = self.current_leader().await?;
let leader_step_down_state = if let Some(ref leader) = leader {
if self.config.start_as_candidate {
self.request_step_down(leader).await
} else {
None
}
} else {
tracing::info!("No leader found to request step down from. Will build observed state.");
None
};
Ok((leader, leader_step_down_state))
}
/// Mark the current storage controller instance as the leader in the database
pub(crate) async fn become_leader(
&self,
current_leader: Option<ControllerPersistence>,
) -> Result<()> {
if let Some(address_for_peers) = &self.config.address_for_peers {
// TODO: `address-for-peers` can become a mandatory cli arg
// after we update the k8s setup
let proposed_leader = ControllerPersistence {
address: address_for_peers.to_string(),
started_at: chrono::Utc::now(),
};
self.persistence
.update_leader(current_leader, proposed_leader)
.await
.map_err(Error::Database)
} else {
tracing::info!("No address-for-peers provided. Skipping leader persistence.");
Ok(())
}
}
async fn current_leader(&self) -> DatabaseResult<Option<ControllerPersistence>> {
let res = self.persistence.get_leader().await;
if let Err(DatabaseError::Query(diesel::result::Error::DatabaseError(_kind, ref err))) = res
{
const REL_NOT_FOUND_MSG: &str = "relation \"controllers\" does not exist";
if err.message().trim() == REL_NOT_FOUND_MSG {
// Special case: if this is a brand new storage controller, migrations will not
// have run yet at this point, and hence the controllers table does not exist.
// Detect this case via the error string (diesel doesn't type it) and allow it.
tracing::info!("Detected first storage controller start-up. Allowing missing controllers table ...");
return Ok(None);
}
}
res
}
/// Request step down from the currently registered leader in the database
///
/// If such an entry is persisted, the success path returns the observed
/// state and details of the leader. Otherwise, None is returned indicating
/// there is no leader currently.
async fn request_step_down(
&self,
leader: &ControllerPersistence,
) -> Option<GlobalObservedState> {
tracing::info!("Sending step down request to {leader:?}");
let client = PeerClient::new(
Uri::try_from(leader.address.as_str()).expect("Failed to build leader URI"),
self.config.peer_jwt_token.clone(),
);
let state = client.step_down(&self.cancel).await;
match state {
Ok(state) => Some(state),
Err(err) => {
// TODO: Make leaders periodically update a timestamp field in the
// database and, if the leader is not reachable from the current instance,
// but inferred as alive from the timestamp, abort start-up. This avoids
// a potential scenario in which we have two controllers acting as leaders.
tracing::error!(
"Leader ({}) did not respond to step-down request: {}",
leader.address,
err
);
None
}
}
}
}

View File

@@ -8,6 +8,7 @@ mod drain_utils;
mod heartbeater;
pub mod http;
mod id_lock_map;
mod leadership;
pub mod metrics;
mod node;
mod pageserver_client;

View File

@@ -1,6 +1,5 @@
use anyhow::{anyhow, Context};
use clap::Parser;
use diesel::Connection;
use hyper::Uri;
use metrics::launch_timestamp::LaunchTimestamp;
use metrics::BuildInfo;
@@ -27,9 +26,6 @@ use utils::{project_build_tag, project_git_version, tcp_listener};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use diesel_migrations::{embed_migrations, EmbeddedMigrations};
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(arg_required_else_help(true))]
@@ -51,6 +47,9 @@ struct Cli {
#[arg(long)]
control_plane_jwt_token: Option<String>,
#[arg(long)]
peer_jwt_token: Option<String>,
/// URL to control plane compute notification endpoint
#[arg(long)]
compute_hook_url: Option<String>,
@@ -130,28 +129,28 @@ struct Secrets {
public_key: Option<JwtAuth>,
jwt_token: Option<String>,
control_plane_jwt_token: Option<String>,
peer_jwt_token: Option<String>,
}
impl Secrets {
const DATABASE_URL_ENV: &'static str = "DATABASE_URL";
const PAGESERVER_JWT_TOKEN_ENV: &'static str = "PAGESERVER_JWT_TOKEN";
const CONTROL_PLANE_JWT_TOKEN_ENV: &'static str = "CONTROL_PLANE_JWT_TOKEN";
const PEER_JWT_TOKEN_ENV: &'static str = "PEER_JWT_TOKEN";
const PUBLIC_KEY_ENV: &'static str = "PUBLIC_KEY";
/// Load secrets from, in order of preference:
/// - CLI args if database URL is provided on the CLI
/// - Environment variables if DATABASE_URL is set.
/// - AWS Secrets Manager secrets
async fn load(args: &Cli) -> anyhow::Result<Self> {
let Some(database_url) =
Self::load_secret(&args.database_url, Self::DATABASE_URL_ENV).await
let Some(database_url) = Self::load_secret(&args.database_url, Self::DATABASE_URL_ENV)
else {
anyhow::bail!(
"Database URL is not set (set `--database-url`, or `DATABASE_URL` environment)"
)
};
let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV).await {
let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV) {
Some(v) => Some(JwtAuth::from_key(v).context("Loading public key")?),
None => None,
};
@@ -159,18 +158,18 @@ impl Secrets {
let this = Self {
database_url,
public_key,
jwt_token: Self::load_secret(&args.jwt_token, Self::PAGESERVER_JWT_TOKEN_ENV).await,
jwt_token: Self::load_secret(&args.jwt_token, Self::PAGESERVER_JWT_TOKEN_ENV),
control_plane_jwt_token: Self::load_secret(
&args.control_plane_jwt_token,
Self::CONTROL_PLANE_JWT_TOKEN_ENV,
)
.await,
),
peer_jwt_token: Self::load_secret(&args.peer_jwt_token, Self::PEER_JWT_TOKEN_ENV),
};
Ok(this)
}
async fn load_secret(cli: &Option<String>, env_name: &str) -> Option<String> {
fn load_secret(cli: &Option<String>, env_name: &str) -> Option<String> {
if let Some(v) = cli {
Some(v.clone())
} else if let Ok(v) = std::env::var(env_name) {
@@ -181,20 +180,6 @@ impl Secrets {
}
}
/// Execute the diesel migrations that are built into this binary
async fn migration_run(database_url: &str) -> anyhow::Result<()> {
use diesel::PgConnection;
use diesel_migrations::{HarnessWithOutput, MigrationHarness};
let mut conn = PgConnection::establish(database_url)?;
HarnessWithOutput::write_to_stdout(&mut conn)
.run_pending_migrations(MIGRATIONS)
.map(|_| ())
.map_err(|e| anyhow::anyhow!(e))?;
Ok(())
}
fn main() -> anyhow::Result<()> {
logging::init(
LogFormat::Plain,
@@ -284,6 +269,7 @@ async fn async_main() -> anyhow::Result<()> {
let config = Config {
jwt_token: secrets.jwt_token,
control_plane_jwt_token: secrets.control_plane_jwt_token,
peer_jwt_token: secrets.peer_jwt_token,
compute_hook_url: args.compute_hook_url,
max_offline_interval: args
.max_offline_interval
@@ -304,13 +290,9 @@ async fn async_main() -> anyhow::Result<()> {
http_service_port: args.listen.port() as i32,
};
// After loading secrets & config, but before starting anything else, apply database migrations
// Validate that we can connect to the database
Persistence::await_connection(&secrets.database_url, args.db_connect_timeout.into()).await?;
migration_run(&secrets.database_url)
.await
.context("Running database migrations")?;
let persistence = Arc::new(Persistence::new(secrets.database_url));
let service = Service::spawn(config, persistence.clone()).await?;

View File

@@ -230,6 +230,7 @@ pub(crate) enum DatabaseErrorLabel {
Connection,
ConnectionPool,
Logical,
Migration,
}
impl DatabaseError {
@@ -239,6 +240,7 @@ impl DatabaseError {
Self::Connection(_) => DatabaseErrorLabel::Connection,
Self::ConnectionPool(_) => DatabaseErrorLabel::ConnectionPool,
Self::Logical(_) => DatabaseErrorLabel::Logical,
Self::Migration(_) => DatabaseErrorLabel::Migration,
}
}
}

View File

@@ -25,6 +25,9 @@ use crate::metrics::{
};
use crate::node::Node;
use diesel_migrations::{embed_migrations, EmbeddedMigrations};
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
/// ## What do we store?
///
/// The storage controller service does not store most of its state durably.
@@ -72,6 +75,8 @@ pub(crate) enum DatabaseError {
ConnectionPool(#[from] r2d2::Error),
#[error("Logical error: {0}")]
Logical(String),
#[error("Migration error: {0}")]
Migration(String),
}
#[derive(measured::FixedCardinalityLabel, Copy, Clone)]
@@ -167,6 +172,19 @@ impl Persistence {
}
}
/// Execute the diesel migrations that are built into this binary
pub(crate) async fn migration_run(&self) -> DatabaseResult<()> {
use diesel_migrations::{HarnessWithOutput, MigrationHarness};
self.with_conn(move |conn| -> DatabaseResult<()> {
HarnessWithOutput::write_to_stdout(conn)
.run_pending_migrations(MIGRATIONS)
.map(|_| ())
.map_err(|e| DatabaseError::Migration(e.to_string()))
})
.await
}
/// Wraps `with_conn` in order to collect latency and error metrics
async fn with_measured_conn<F, R>(&self, op: DatabaseOperation, func: F) -> DatabaseResult<R>
where

View File

@@ -17,8 +17,9 @@ use crate::{
compute_hook::NotifyError,
drain_utils::{self, TenantShardDrain, TenantShardIterator},
id_lock_map::{trace_exclusive_lock, trace_shared_lock, IdLockMap, TracingExclusiveGuard},
leadership::Leadership,
metrics,
peer_client::{GlobalObservedState, PeerClient},
peer_client::GlobalObservedState,
persistence::{
AbortShardSplitStatus, ControllerPersistence, DatabaseResult, MetadataHealthPersistence,
TenantFilter,
@@ -287,6 +288,9 @@ pub struct Config {
// This JWT token will be used to authenticate this service to the control plane.
pub control_plane_jwt_token: Option<String>,
// This JWT token will be used to authenticate with other storage controller instances
pub peer_jwt_token: Option<String>,
/// Where the compute hook should send notifications of pageserver attachment locations
/// (this URL points to the control plane in prod). If this is None, the compute hook will
/// assume it is running in a test environment and try to update neon_local.
@@ -333,7 +337,7 @@ impl From<DatabaseError> for ApiError {
DatabaseError::Connection(_) | DatabaseError::ConnectionPool(_) => {
ApiError::ShuttingDown
}
DatabaseError::Logical(reason) => {
DatabaseError::Logical(reason) | DatabaseError::Migration(reason) => {
ApiError::InternalServerError(anyhow::anyhow!(reason))
}
}
@@ -606,22 +610,15 @@ impl Service {
// Before making any observable changes to the cluster, persist self
// as leader in database and memory.
if let Some(address_for_peers) = &self.config.address_for_peers {
// TODO: `address-for-peers` can become a mandatory cli arg
// after we update the k8s setup
let proposed_leader = ControllerPersistence {
address: address_for_peers.to_string(),
started_at: chrono::Utc::now(),
};
let leadership = Leadership::new(
self.persistence.clone(),
self.config.clone(),
self.cancel.child_token(),
);
if let Err(err) = self
.persistence
.update_leader(current_leader, proposed_leader)
.await
{
tracing::error!("Failed to persist self as leader: {err}. Aborting start-up ...");
std::process::exit(1);
}
if let Err(e) = leadership.become_leader(current_leader).await {
tracing::error!("Failed to persist self as leader: {e}. Aborting start-up ...");
std::process::exit(1);
}
self.inner.write().unwrap().become_leader();
@@ -1159,6 +1156,16 @@ impl Service {
let (result_tx, result_rx) = tokio::sync::mpsc::unbounded_channel();
let (abort_tx, abort_rx) = tokio::sync::mpsc::unbounded_channel();
let leadership_cancel = CancellationToken::new();
let leadership = Leadership::new(persistence.clone(), config.clone(), leadership_cancel);
let (leader, leader_step_down_state) = leadership.step_down_current_leader().await?;
// Apply the migrations **after** the current leader has stepped down
// (or we've given up waiting for it), but **before** reading from the
// database. The only exception is reading the current leader before
// migrating.
persistence.migration_run().await?;
tracing::info!("Loading nodes from database...");
let nodes = persistence
.list_nodes()
@@ -1376,32 +1383,6 @@ impl Service {
return;
};
let leadership_status = this.inner.read().unwrap().get_leadership_status();
let leader = match this.get_leader().await {
Ok(ok) => ok,
Err(err) => {
tracing::error!(
"Failed to query database for current leader: {err}. Aborting start-up ..."
);
std::process::exit(1);
}
};
let leader_step_down_state = match leadership_status {
LeadershipStatus::Candidate => {
if let Some(ref leader) = leader {
this.request_step_down(leader).await
} else {
tracing::info!(
"No leader found to request step down from. Will build observed state."
);
None
}
}
LeadershipStatus::Leader => None,
LeadershipStatus::SteppedDown => unreachable!(),
};
this.startup_reconcile(leader, leader_step_down_state, bg_compute_notify_result_tx)
.await;
@@ -6377,42 +6358,4 @@ impl Service {
global_observed
}
/// Request step down from the currently registered leader in the database
///
/// If such an entry is persisted, the success path returns the observed
/// state and details of the leader. Otherwise, None is returned indicating
/// there is no leader currently.
///
/// On failures to query the database or step down error responses the process is killed
/// and we rely on k8s to retry.
async fn request_step_down(
&self,
leader: &ControllerPersistence,
) -> Option<GlobalObservedState> {
tracing::info!("Sending step down request to {leader:?}");
// TODO: jwt token
let client = PeerClient::new(
Uri::try_from(leader.address.as_str()).expect("Failed to build leader URI"),
self.config.jwt_token.clone(),
);
let state = client.step_down(&self.cancel).await;
match state {
Ok(state) => Some(state),
Err(err) => {
// TODO: Make leaders periodically update a timestamp field in the
// database and, if the leader is not reachable from the current instance,
// but inferred as alive from the timestamp, abort start-up. This avoids
// a potential scenario in which we have two controllers acting as leaders.
tracing::error!(
"Leader ({}) did not respond to step-down request: {}",
leader.address,
err
);
None
}
}
}
}

View File

@@ -1,10 +1,10 @@
use std::collections::{HashMap, HashSet};
use anyhow::Context;
use aws_sdk_s3::Client;
use pageserver::tenant::layer_map::LayerMap;
use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata;
use pageserver_api::shard::ShardIndex;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::generation::Generation;
use utils::id::TimelineId;
@@ -16,7 +16,7 @@ use futures_util::StreamExt;
use pageserver::tenant::remote_timeline_client::{parse_remote_index_path, remote_layer_path};
use pageserver::tenant::storage_layer::LayerName;
use pageserver::tenant::IndexPart;
use remote_storage::RemotePath;
use remote_storage::{GenericRemoteStorage, ListingObject, RemotePath};
pub(crate) struct TimelineAnalysis {
/// Anomalies detected
@@ -48,13 +48,12 @@ impl TimelineAnalysis {
}
pub(crate) async fn branch_cleanup_and_check_errors(
s3_client: &Client,
target: &RootTarget,
remote_client: &GenericRemoteStorage,
id: &TenantShardTimelineId,
tenant_objects: &mut TenantObjectListing,
s3_active_branch: Option<&BranchData>,
console_branch: Option<BranchData>,
s3_data: Option<S3TimelineBlobData>,
s3_data: Option<RemoteTimelineBlobData>,
) -> TimelineAnalysis {
let mut result = TimelineAnalysis::new();
@@ -78,7 +77,9 @@ pub(crate) async fn branch_cleanup_and_check_errors(
match s3_data {
Some(s3_data) => {
result.garbage_keys.extend(s3_data.unknown_keys);
result
.garbage_keys
.extend(s3_data.unknown_keys.into_iter().map(|k| k.key.to_string()));
match s3_data.blob_data {
BlobDataParseResult::Parsed {
@@ -143,11 +144,8 @@ pub(crate) async fn branch_cleanup_and_check_errors(
// HEAD request used here to address a race condition when an index was uploaded concurrently
// with our scan. We check if the object is uploaded to S3 after taking the listing snapshot.
let response = s3_client
.head_object()
.bucket(target.bucket_name())
.key(path.get_path().as_str())
.send()
let response = remote_client
.head_object(&path, &CancellationToken::new())
.await;
if response.is_err() {
@@ -284,14 +282,14 @@ impl TenantObjectListing {
}
#[derive(Debug)]
pub(crate) struct S3TimelineBlobData {
pub(crate) struct RemoteTimelineBlobData {
pub(crate) blob_data: BlobDataParseResult,
// Index objects that were not used when loading `blob_data`, e.g. those from old generations
pub(crate) unused_index_keys: Vec<String>,
pub(crate) unused_index_keys: Vec<ListingObject>,
// Objects whose keys were not recognized at all, i.e. not layer files, not indices
pub(crate) unknown_keys: Vec<String>,
pub(crate) unknown_keys: Vec<ListingObject>,
}
#[derive(Debug)]
@@ -323,31 +321,37 @@ pub(crate) fn parse_layer_object_name(name: &str) -> Result<(LayerName, Generati
}
pub(crate) async fn list_timeline_blobs(
s3_client: &Client,
remote_client: &GenericRemoteStorage,
id: TenantShardTimelineId,
s3_root: &RootTarget,
) -> anyhow::Result<S3TimelineBlobData> {
root_target: &RootTarget,
) -> anyhow::Result<RemoteTimelineBlobData> {
let mut s3_layers = HashSet::new();
let mut errors = Vec::new();
let mut unknown_keys = Vec::new();
let mut timeline_dir_target = s3_root.timeline_root(&id);
let mut timeline_dir_target = root_target.timeline_root(&id);
timeline_dir_target.delimiter = String::new();
let mut index_part_keys: Vec<String> = Vec::new();
let mut index_part_keys: Vec<ListingObject> = Vec::new();
let mut initdb_archive: bool = false;
let mut stream = std::pin::pin!(stream_listing(s3_client, &timeline_dir_target));
while let Some(obj) = stream.next().await {
let obj = obj?;
let key = obj.key();
let prefix_str = &timeline_dir_target
.prefix_in_bucket
.strip_prefix("/")
.unwrap_or(&timeline_dir_target.prefix_in_bucket);
let blob_name = key.strip_prefix(&timeline_dir_target.prefix_in_bucket);
let mut stream = std::pin::pin!(stream_listing(remote_client, &timeline_dir_target));
while let Some(obj) = stream.next().await {
let (key, Some(obj)) = obj? else {
panic!("ListingObject not specified");
};
let blob_name = key.get_path().as_str().strip_prefix(prefix_str);
match blob_name {
Some(name) if name.starts_with("index_part.json") => {
tracing::debug!("Index key {key}");
index_part_keys.push(key.to_owned())
index_part_keys.push(obj)
}
Some("initdb.tar.zst") => {
tracing::debug!("initdb archive {key}");
@@ -358,7 +362,7 @@ pub(crate) async fn list_timeline_blobs(
}
Some(maybe_layer_name) => match parse_layer_object_name(maybe_layer_name) {
Ok((new_layer, gen)) => {
tracing::debug!("Parsed layer key: {} {:?}", new_layer, gen);
tracing::debug!("Parsed layer key: {new_layer} {gen:?}");
s3_layers.insert((new_layer, gen));
}
Err(e) => {
@@ -366,13 +370,13 @@ pub(crate) async fn list_timeline_blobs(
errors.push(
format!("S3 list response got an object with key {key} that is not a layer name: {e}"),
);
unknown_keys.push(key.to_string());
unknown_keys.push(obj);
}
},
None => {
tracing::warn!("Unknown key {}", key);
tracing::warn!("Unknown key {key}");
errors.push(format!("S3 list response got an object with odd key {key}"));
unknown_keys.push(key.to_string());
unknown_keys.push(obj);
}
}
}
@@ -381,7 +385,7 @@ pub(crate) async fn list_timeline_blobs(
tracing::debug!(
"Timeline is empty apart from initdb archive: expected post-deletion state."
);
return Ok(S3TimelineBlobData {
return Ok(RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Relic,
unused_index_keys: index_part_keys,
unknown_keys: Vec::new(),
@@ -395,13 +399,13 @@ pub(crate) async fn list_timeline_blobs(
// Stripping the index key to the last part, because RemotePath doesn't
// like absolute paths, and depending on prefix_in_bucket it's possible
// for the keys we read back to start with a slash.
let basename = key.rsplit_once('/').unwrap().1;
let basename = key.key.get_path().as_str().rsplit_once('/').unwrap().1;
parse_remote_index_path(RemotePath::from_string(basename).unwrap()).map(|g| (key, g))
})
.max_by_key(|i| i.1)
.map(|(k, g)| (k.clone(), g))
{
Some((key, gen)) => (Some(key), gen),
Some((key, gen)) => (Some::<ListingObject>(key.to_owned()), gen),
None => {
// Legacy/missing case: one or zero index parts, which did not have a generation
(index_part_keys.pop(), Generation::none())
@@ -416,17 +420,14 @@ pub(crate) async fn list_timeline_blobs(
}
if let Some(index_part_object_key) = index_part_object.as_ref() {
let index_part_bytes = download_object_with_retries(
s3_client,
&timeline_dir_target.bucket_name,
index_part_object_key,
)
.await
.context("index_part.json download")?;
let index_part_bytes =
download_object_with_retries(remote_client, &index_part_object_key.key)
.await
.context("index_part.json download")?;
match serde_json::from_slice(&index_part_bytes) {
Ok(index_part) => {
return Ok(S3TimelineBlobData {
return Ok(RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Parsed {
index_part: Box::new(index_part),
index_part_generation,
@@ -448,7 +449,7 @@ pub(crate) async fn list_timeline_blobs(
);
}
Ok(S3TimelineBlobData {
Ok(RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Incorrect { errors, s3_layers },
unused_index_keys: index_part_keys,
unknown_keys,
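
The checks above now go through `GenericRemoteStorage` instead of a raw `aws_sdk_s3::Client`: the HEAD re-check for a concurrently uploaded index becomes `head_object` with a `CancellationToken`, and listings hand back `ListingObject`s rather than bare key strings. A minimal sketch of the existence re-check, using only the `remote_storage` calls already present in this diff (the helper name and error handling are illustrative, not code from the branch):

```rust
use remote_storage::{GenericRemoteStorage, RemotePath};
use tokio_util::sync::CancellationToken;

// Re-check an object after taking a listing snapshot: an index uploaded
// concurrently with the scan may not have appeared in the listing.
async fn index_still_exists(
    remote_client: &GenericRemoteStorage,
    path: &RemotePath,
) -> bool {
    remote_client
        .head_object(path, &CancellationToken::new())
        .await
        .is_ok()
}
```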


@@ -6,7 +6,7 @@ use remote_storage::ListingMode;
use serde::{Deserialize, Serialize};
use crate::{
checks::parse_layer_object_name, init_remote_generic, metadata_stream::stream_tenants_generic,
checks::parse_layer_object_name, init_remote, metadata_stream::stream_tenants,
stream_objects_with_retries, BucketConfig, NodeKind,
};
@@ -50,9 +50,8 @@ pub async fn find_large_objects(
ignore_deltas: bool,
concurrency: usize,
) -> anyhow::Result<LargeObjectListing> {
let (remote_client, target) =
init_remote_generic(bucket_config.clone(), NodeKind::Pageserver).await?;
let tenants = pin!(stream_tenants_generic(&remote_client, &target));
let (remote_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;
let tenants = pin!(stream_tenants(&remote_client, &target));
let objects_stream = tenants.map_ok(|tenant_shard_id| {
let mut tenant_root = target.tenant_root(&tenant_shard_id);


@@ -19,8 +19,8 @@ use utils::id::TenantId;
use crate::{
cloud_admin_api::{CloudAdminApiClient, MaybeDeleted, ProjectData},
init_remote_generic, list_objects_with_retries_generic,
metadata_stream::{stream_tenant_timelines_generic, stream_tenants_generic},
init_remote, list_objects_with_retries,
metadata_stream::{stream_tenant_timelines, stream_tenants},
BucketConfig, ConsoleConfig, NodeKind, TenantShardTimelineId, TraversingDepth,
};
@@ -153,7 +153,7 @@ async fn find_garbage_inner(
node_kind: NodeKind,
) -> anyhow::Result<GarbageList> {
// Construct clients for S3 and for Console API
let (remote_client, target) = init_remote_generic(bucket_config.clone(), node_kind).await?;
let (remote_client, target) = init_remote(bucket_config.clone(), node_kind).await?;
let cloud_admin_api_client = Arc::new(CloudAdminApiClient::new(console_config));
// Build a set of console-known tenants, for quickly eliminating known-active tenants without having
@@ -179,7 +179,7 @@ async fn find_garbage_inner(
// Enumerate Tenants in S3, and check if each one exists in Console
tracing::info!("Finding all tenants in bucket {}...", bucket_config.bucket);
let tenants = stream_tenants_generic(&remote_client, &target);
let tenants = stream_tenants(&remote_client, &target);
let tenants_checked = tenants.map_ok(|t| {
let api_client = cloud_admin_api_client.clone();
let console_cache = console_cache.clone();
@@ -237,14 +237,13 @@ async fn find_garbage_inner(
// Special case: If it's missing in console, check for known bugs that would enable us to conclusively
// identify it as purge-able anyway
if console_result.is_none() {
let timelines =
stream_tenant_timelines_generic(&remote_client, &target, tenant_shard_id)
.await?
.collect::<Vec<_>>()
.await;
let timelines = stream_tenant_timelines(&remote_client, &target, tenant_shard_id)
.await?
.collect::<Vec<_>>()
.await;
if timelines.is_empty() {
// No timelines, but a heatmap: the deletion bug where we deleted everything but heatmaps
let tenant_objects = list_objects_with_retries_generic(
let tenant_objects = list_objects_with_retries(
&remote_client,
ListingMode::WithDelimiter,
&target.tenant_root(&tenant_shard_id),
@@ -265,7 +264,7 @@ async fn find_garbage_inner(
for timeline_r in timelines {
let timeline = timeline_r?;
let timeline_objects = list_objects_with_retries_generic(
let timeline_objects = list_objects_with_retries(
&remote_client,
ListingMode::WithDelimiter,
&target.timeline_root(&timeline),
@@ -331,8 +330,7 @@ async fn find_garbage_inner(
// Construct a stream of all timelines within active tenants
let active_tenants = tokio_stream::iter(active_tenants.iter().map(Ok));
let timelines =
active_tenants.map_ok(|t| stream_tenant_timelines_generic(&remote_client, &target, *t));
let timelines = active_tenants.map_ok(|t| stream_tenant_timelines(&remote_client, &target, *t));
let timelines = timelines.try_buffer_unordered(S3_CONCURRENCY);
let timelines = timelines.try_flatten();
@@ -507,7 +505,7 @@ pub async fn purge_garbage(
);
let (remote_client, _target) =
init_remote_generic(garbage_list.bucket_config.clone(), garbage_list.node_kind).await?;
init_remote(garbage_list.bucket_config.clone(), garbage_list.node_kind).await?;
assert_eq!(
&garbage_list.bucket_config.bucket,


@@ -15,7 +15,7 @@ use std::fmt::Display;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, Context};
use anyhow::Context;
use aws_config::retry::{RetryConfigBuilder, RetryMode};
use aws_sdk_s3::config::Region;
use aws_sdk_s3::error::DisplayErrorContext;
@@ -352,7 +352,7 @@ fn make_root_target(
}
}
async fn init_remote(
async fn init_remote_s3(
bucket_config: BucketConfig,
node_kind: NodeKind,
) -> anyhow::Result<(Arc<Client>, RootTarget)> {
@@ -369,7 +369,7 @@ async fn init_remote(
Ok((s3_client, s3_root))
}
async fn init_remote_generic(
async fn init_remote(
bucket_config: BucketConfig,
node_kind: NodeKind,
) -> anyhow::Result<(GenericRemoteStorage, RootTarget)> {
@@ -394,45 +394,10 @@ async fn init_remote_generic(
// We already pass the prefix to the remote client above
let prefix_in_root_target = String::new();
let s3_root = make_root_target(bucket_config.bucket, prefix_in_root_target, node_kind);
let root_target = make_root_target(bucket_config.bucket, prefix_in_root_target, node_kind);
let client = GenericRemoteStorage::from_config(&storage_config).await?;
Ok((client, s3_root))
}
async fn list_objects_with_retries(
s3_client: &Client,
s3_target: &S3Target,
continuation_token: Option<String>,
) -> anyhow::Result<aws_sdk_s3::operation::list_objects_v2::ListObjectsV2Output> {
for trial in 0..MAX_RETRIES {
match s3_client
.list_objects_v2()
.bucket(&s3_target.bucket_name)
.prefix(&s3_target.prefix_in_bucket)
.delimiter(&s3_target.delimiter)
.set_continuation_token(continuation_token.clone())
.send()
.await
{
Ok(response) => return Ok(response),
Err(e) => {
if trial == MAX_RETRIES - 1 {
return Err(e)
.with_context(|| format!("Failed to list objects {MAX_RETRIES} times"));
}
error!(
"list_objects_v2 query failed: bucket_name={}, prefix={}, delimiter={}, error={}",
s3_target.bucket_name,
s3_target.prefix_in_bucket,
s3_target.delimiter,
DisplayErrorContext(e),
);
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
Err(anyhow!("unreachable unless MAX_RETRIES==0"))
Ok((client, root_target))
}
/// Listing possibly large amounts of keys in a streaming fashion.
@@ -452,23 +417,26 @@ fn stream_objects_with_retries<'a>(
let mut list_stream =
storage_client.list_streaming(Some(&prefix), listing_mode, None, &cancel);
while let Some(res) = list_stream.next().await {
if let Err(err) = res {
let yield_err = if err.is_permanent() {
true
} else {
let backoff_time = 1 << trial.max(5);
tokio::time::sleep(Duration::from_secs(backoff_time)).await;
trial += 1;
trial == MAX_RETRIES - 1
};
if yield_err {
yield Err(err)
.with_context(|| format!("Failed to list objects {MAX_RETRIES} times"));
break;
match res {
Err(err) => {
let yield_err = if err.is_permanent() {
true
} else {
let backoff_time = 1 << trial.max(5);
tokio::time::sleep(Duration::from_secs(backoff_time)).await;
trial += 1;
trial == MAX_RETRIES - 1
};
if yield_err {
yield Err(err)
.with_context(|| format!("Failed to list objects {MAX_RETRIES} times"));
break;
}
}
Ok(res) => {
trial = 0;
yield Ok(res);
}
} else {
trial = 0;
yield res.map_err(anyhow::Error::from);
}
}
}
@@ -476,7 +444,7 @@ fn stream_objects_with_retries<'a>(
/// If you want to list a bounded amount of prefixes or keys. For larger numbers of keys/prefixes,
/// use [`stream_objects_with_retries`] instead.
async fn list_objects_with_retries_generic(
async fn list_objects_with_retries(
remote_client: &GenericRemoteStorage,
listing_mode: ListingMode,
s3_target: &S3Target,
@@ -514,40 +482,34 @@ async fn list_objects_with_retries_generic(
}
async fn download_object_with_retries(
s3_client: &Client,
bucket_name: &str,
key: &str,
remote_client: &GenericRemoteStorage,
key: &RemotePath,
) -> anyhow::Result<Vec<u8>> {
for _ in 0..MAX_RETRIES {
let mut body_buf = Vec::new();
let response_stream = match s3_client
.get_object()
.bucket(bucket_name)
.key(key)
.send()
.await
{
let cancel = CancellationToken::new();
for trial in 0..MAX_RETRIES {
let mut buf = Vec::new();
let download = match remote_client.download(key, &cancel).await {
Ok(response) => response,
Err(e) => {
error!("Failed to download object for key {key}: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
let backoff_time = 1 << trial.max(5);
tokio::time::sleep(Duration::from_secs(backoff_time)).await;
continue;
}
};
match response_stream
.body
.into_async_read()
.read_to_end(&mut body_buf)
match tokio_util::io::StreamReader::new(download.download_stream)
.read_to_end(&mut buf)
.await
{
Ok(bytes_read) => {
tracing::debug!("Downloaded {bytes_read} bytes for object {key}");
return Ok(body_buf);
return Ok(buf);
}
Err(e) => {
error!("Failed to stream object body for key {key}: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
let backoff_time = 1 << trial.max(5);
tokio::time::sleep(Duration::from_secs(backoff_time)).await;
}
}
}
@@ -555,7 +517,7 @@ async fn download_object_with_retries(
anyhow::bail!("Failed to download objects with key {key} {MAX_RETRIES} times")
}
async fn download_object_to_file(
async fn download_object_to_file_s3(
s3_client: &Client,
bucket_name: &str,
key: &str,
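
`download_object_with_retries` now wraps `GenericRemoteStorage::download` and reads the body through `tokio_util::io::StreamReader`, with exponential backoff between attempts. A sketch of a single attempt, assuming the `download_stream` body shape used above (names and error handling are illustrative; the retry loop is omitted):

```rust
use remote_storage::{GenericRemoteStorage, RemotePath};
use tokio::io::AsyncReadExt;
use tokio_util::{io::StreamReader, sync::CancellationToken};

// One download attempt via the generic client; StreamReader adapts the byte
// stream returned by `download` into an AsyncRead so it can be read to the end.
async fn download_once(
    remote_client: &GenericRemoteStorage,
    key: &RemotePath,
) -> anyhow::Result<Vec<u8>> {
    let cancel = CancellationToken::new();
    let download = remote_client.download(key, &cancel).await?;
    let mut buf = Vec::new();
    StreamReader::new(download.download_stream)
        .read_to_end(&mut buf)
        .await?;
    Ok(buf)
}
```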


@@ -2,7 +2,6 @@ use std::str::FromStr;
use anyhow::{anyhow, Context};
use async_stream::{stream, try_stream};
use aws_sdk_s3::{types::ObjectIdentifier, Client};
use futures::StreamExt;
use remote_storage::{GenericRemoteStorage, ListingMode, ListingObject, RemotePath};
use tokio_stream::Stream;
@@ -15,7 +14,7 @@ use pageserver_api::shard::TenantShardId;
use utils::id::{TenantId, TimelineId};
/// Given a remote storage and a target, output a stream of TenantIds discovered via listing prefixes
pub fn stream_tenants_generic<'a>(
pub fn stream_tenants<'a>(
remote_client: &'a GenericRemoteStorage,
target: &'a RootTarget,
) -> impl Stream<Item = anyhow::Result<TenantShardId>> + 'a {
@@ -36,92 +35,36 @@ pub fn stream_tenants_generic<'a>(
}
}
/// Given an S3 bucket, output a stream of TenantIds discovered via ListObjectsv2
pub fn stream_tenants<'a>(
s3_client: &'a Client,
target: &'a RootTarget,
) -> impl Stream<Item = anyhow::Result<TenantShardId>> + 'a {
try_stream! {
let mut continuation_token = None;
let tenants_target = target.tenants_root();
loop {
let fetch_response =
list_objects_with_retries(s3_client, &tenants_target, continuation_token.clone()).await?;
let new_entry_ids = fetch_response
.common_prefixes()
.iter()
.filter_map(|prefix| prefix.prefix())
.filter_map(|prefix| -> Option<&str> {
prefix
.strip_prefix(&tenants_target.prefix_in_bucket)?
.strip_suffix('/')
}).map(|entry_id_str| {
entry_id_str
.parse()
.with_context(|| format!("Incorrect entry id str: {entry_id_str}"))
});
for i in new_entry_ids {
yield i?;
}
match fetch_response.next_continuation_token {
Some(new_token) => continuation_token = Some(new_token),
None => break,
}
}
}
}
pub async fn stream_tenant_shards<'a>(
s3_client: &'a Client,
remote_client: &'a GenericRemoteStorage,
target: &'a RootTarget,
tenant_id: TenantId,
) -> anyhow::Result<impl Stream<Item = Result<TenantShardId, anyhow::Error>> + 'a> {
let mut tenant_shard_ids: Vec<Result<TenantShardId, anyhow::Error>> = Vec::new();
let mut continuation_token = None;
let shards_target = target.tenant_shards_prefix(&tenant_id);
loop {
tracing::info!("Listing in {}", shards_target.prefix_in_bucket);
let fetch_response =
list_objects_with_retries(s3_client, &shards_target, continuation_token.clone()).await;
let fetch_response = match fetch_response {
Err(e) => {
tenant_shard_ids.push(Err(e));
break;
}
Ok(r) => r,
};
let strip_prefix = target.tenants_root().prefix_in_bucket;
let prefix_str = &strip_prefix.strip_prefix("/").unwrap_or(&strip_prefix);
let new_entry_ids = fetch_response
.common_prefixes()
.iter()
.filter_map(|prefix| prefix.prefix())
.filter_map(|prefix| -> Option<&str> {
prefix
.strip_prefix(&target.tenants_root().prefix_in_bucket)?
.strip_suffix('/')
})
.map(|entry_id_str| {
let first_part = entry_id_str.split('/').next().unwrap();
tracing::info!("Listing shards in {}", shards_target.prefix_in_bucket);
let listing =
list_objects_with_retries(remote_client, ListingMode::WithDelimiter, &shards_target)
.await?;
first_part
.parse::<TenantShardId>()
.with_context(|| format!("Incorrect entry id str: {first_part}"))
});
let tenant_shard_ids = listing
.prefixes
.iter()
.map(|prefix| prefix.get_path().as_str())
.filter_map(|prefix| -> Option<&str> { prefix.strip_prefix(prefix_str) })
.map(|entry_id_str| {
let first_part = entry_id_str.split('/').next().unwrap();
for i in new_entry_ids {
tenant_shard_ids.push(i);
}
match fetch_response.next_continuation_token {
Some(new_token) => continuation_token = Some(new_token),
None => break,
}
}
first_part
.parse::<TenantShardId>()
.with_context(|| format!("Incorrect entry id str: {first_part}"))
})
.collect::<Vec<_>>();
tracing::debug!("Yielding {} shards for {tenant_id}", tenant_shard_ids.len());
Ok(stream! {
for i in tenant_shard_ids {
let id = i?;
@@ -130,69 +73,10 @@ pub async fn stream_tenant_shards<'a>(
})
}
/// Given a TenantShardId, output a stream of the timelines within that tenant, discovered
/// using ListObjectsv2. The listing is done before the stream is built, so that this
/// function can be used to generate concurrency on a stream using buffer_unordered.
pub async fn stream_tenant_timelines<'a>(
s3_client: &'a Client,
target: &'a RootTarget,
tenant: TenantShardId,
) -> anyhow::Result<impl Stream<Item = Result<TenantShardTimelineId, anyhow::Error>> + 'a> {
let mut timeline_ids: Vec<Result<TimelineId, anyhow::Error>> = Vec::new();
let mut continuation_token = None;
let timelines_target = target.timelines_root(&tenant);
loop {
tracing::debug!("Listing in {}", tenant);
let fetch_response =
list_objects_with_retries(s3_client, &timelines_target, continuation_token.clone())
.await;
let fetch_response = match fetch_response {
Err(e) => {
timeline_ids.push(Err(e));
break;
}
Ok(r) => r,
};
let new_entry_ids = fetch_response
.common_prefixes()
.iter()
.filter_map(|prefix| prefix.prefix())
.filter_map(|prefix| -> Option<&str> {
prefix
.strip_prefix(&timelines_target.prefix_in_bucket)?
.strip_suffix('/')
})
.map(|entry_id_str| {
entry_id_str
.parse::<TimelineId>()
.with_context(|| format!("Incorrect entry id str: {entry_id_str}"))
});
for i in new_entry_ids {
timeline_ids.push(i);
}
match fetch_response.next_continuation_token {
Some(new_token) => continuation_token = Some(new_token),
None => break,
}
}
tracing::debug!("Yielding for {}", tenant);
Ok(stream! {
for i in timeline_ids {
let id = i?;
yield Ok(TenantShardTimelineId::new(tenant, id));
}
})
}
/// Given a `TenantShardId`, output a stream of the timelines within that tenant, discovered
/// using a listing. The listing is done before the stream is built, so that this
/// function can be used to generate concurrency on a stream using buffer_unordered.
pub async fn stream_tenant_timelines_generic<'a>(
pub async fn stream_tenant_timelines<'a>(
remote_client: &'a GenericRemoteStorage,
target: &'a RootTarget,
tenant: TenantShardId,
@@ -200,6 +84,11 @@ pub async fn stream_tenant_timelines_generic<'a>(
let mut timeline_ids: Vec<Result<TimelineId, anyhow::Error>> = Vec::new();
let timelines_target = target.timelines_root(&tenant);
let prefix_str = &timelines_target
.prefix_in_bucket
.strip_prefix("/")
.unwrap_or(&timelines_target.prefix_in_bucket);
let mut objects_stream = std::pin::pin!(stream_objects_with_retries(
remote_client,
ListingMode::WithDelimiter,
@@ -220,11 +109,7 @@ pub async fn stream_tenant_timelines_generic<'a>(
.prefixes
.iter()
.filter_map(|prefix| -> Option<&str> {
prefix
.get_path()
.as_str()
.strip_prefix(&timelines_target.prefix_in_bucket)?
.strip_suffix('/')
prefix.get_path().as_str().strip_prefix(prefix_str)
})
.map(|entry_id_str| {
entry_id_str
@@ -237,7 +122,7 @@ pub async fn stream_tenant_timelines_generic<'a>(
}
}
tracing::debug!("Yielding for {}", tenant);
tracing::debug!("Yielding {} timelines for {}", timeline_ids.len(), tenant);
Ok(stream! {
for i in timeline_ids {
let id = i?;
@@ -247,37 +132,6 @@ pub async fn stream_tenant_timelines_generic<'a>(
}
pub(crate) fn stream_listing<'a>(
s3_client: &'a Client,
target: &'a S3Target,
) -> impl Stream<Item = anyhow::Result<ObjectIdentifier>> + 'a {
try_stream! {
let mut continuation_token = None;
loop {
let fetch_response =
list_objects_with_retries(s3_client, target, continuation_token.clone()).await?;
if target.delimiter.is_empty() {
for object_key in fetch_response.contents().iter().filter_map(|object| object.key())
{
let object_id = ObjectIdentifier::builder().key(object_key).build()?;
yield object_id;
}
} else {
for prefix in fetch_response.common_prefixes().iter().filter_map(|p| p.prefix()) {
let object_id = ObjectIdentifier::builder().key(prefix).build()?;
yield object_id;
}
}
match fetch_response.next_continuation_token {
Some(new_token) => continuation_token = Some(new_token),
None => break,
}
}
}
}
pub(crate) fn stream_listing_generic<'a>(
remote_client: &'a GenericRemoteStorage,
target: &'a S3Target,
) -> impl Stream<Item = anyhow::Result<(RemotePath, Option<ListingObject>)>> + 'a {
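
The S3-specific generators are removed and the `_generic` helpers take over their names. The callers later in this diff all consume them the same way: list tenants, then fan out per-tenant timeline listings with bounded concurrency. A sketch of that pattern, assuming it runs inside the scrubber crate so `stream_tenants`, `stream_tenant_timelines`, and `RootTarget` are in scope (the concurrency value and logging are illustrative):

```rust
use futures_util::{StreamExt, TryStreamExt};
use remote_storage::GenericRemoteStorage;

async fn walk_timelines(
    remote_client: &GenericRemoteStorage,
    target: &RootTarget,
) -> anyhow::Result<()> {
    const CONCURRENCY: usize = 8;
    // One listing per tenant, buffered so several tenants are listed at once.
    let timelines = stream_tenants(remote_client, target)
        .map_ok(|t| stream_tenant_timelines(remote_client, target, t))
        .try_buffered(CONCURRENCY)
        .try_flatten();
    let mut timelines = std::pin::pin!(timelines);
    while let Some(ttid) = timelines.next().await {
        let ttid = ttid?;
        tracing::info!("found timeline {ttid}");
    }
    Ok(())
}
```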


@@ -1,11 +1,10 @@
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::time::Duration;
use crate::checks::{list_timeline_blobs, BlobDataParseResult};
use crate::metadata_stream::{stream_tenant_timelines, stream_tenants};
use crate::{init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId};
use aws_sdk_s3::Client;
use futures_util::{StreamExt, TryStreamExt};
use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata;
use pageserver::tenant::remote_timeline_client::{parse_remote_index_path, remote_layer_path};
@@ -13,10 +12,11 @@ use pageserver::tenant::storage_layer::LayerName;
use pageserver::tenant::IndexPart;
use pageserver_api::controller_api::TenantDescribeResponse;
use pageserver_api::shard::{ShardIndex, TenantShardId};
use remote_storage::RemotePath;
use remote_storage::{GenericRemoteStorage, ListingObject, RemotePath};
use reqwest::Method;
use serde::Serialize;
use storage_controller_client::control_api;
use tokio_util::sync::CancellationToken;
use tracing::{info_span, Instrument};
use utils::generation::Generation;
use utils::id::{TenantId, TenantTimelineId};
@@ -240,38 +240,13 @@ impl TenantRefAccumulator {
}
}
async fn is_old_enough(
s3_client: &Client,
bucket_config: &BucketConfig,
min_age: &Duration,
key: &str,
summary: &mut GcSummary,
) -> bool {
fn is_old_enough(min_age: &Duration, key: &ListingObject, summary: &mut GcSummary) -> bool {
// Validation: we will only GC indices & layers after a time threshold (e.g. one week) so that during an incident
// it is easier to read old data for analysis, and easier to roll back shard splits without having to un-delete any objects.
let age: Duration = match s3_client
.head_object()
.bucket(&bucket_config.bucket)
.key(key)
.send()
.await
{
Ok(response) => match response.last_modified {
None => {
tracing::warn!("Missing last_modified");
summary.remote_storage_errors += 1;
return false;
}
Some(last_modified) => match SystemTime::try_from(last_modified).map(|t| t.elapsed()) {
Ok(Ok(e)) => e,
Err(_) | Ok(Err(_)) => {
tracing::warn!("Bad last_modified time: {last_modified:?}");
return false;
}
},
},
Err(e) => {
tracing::warn!("Failed to HEAD {key}: {e}");
let age = match key.last_modified.elapsed() {
Ok(e) => e,
Err(_) => {
tracing::warn!("Bad last_modified time: {:?}", key.last_modified);
summary.remote_storage_errors += 1;
return false;
}
@@ -289,17 +264,30 @@ async fn is_old_enough(
old_enough
}
/// Same as [`is_old_enough`], but doesn't require a [`ListingObject`] passed to it.
async fn check_is_old_enough(
remote_client: &GenericRemoteStorage,
key: &RemotePath,
min_age: &Duration,
summary: &mut GcSummary,
) -> Option<bool> {
let listing_object = remote_client
.head_object(key, &CancellationToken::new())
.await
.ok()?;
Some(is_old_enough(min_age, &listing_object, summary))
}
async fn maybe_delete_index(
s3_client: &Client,
bucket_config: &BucketConfig,
remote_client: &GenericRemoteStorage,
min_age: &Duration,
latest_gen: Generation,
key: &str,
obj: &ListingObject,
mode: GcMode,
summary: &mut GcSummary,
) {
// Validation: we will only delete things that parse cleanly
let basename = key.rsplit_once('/').unwrap().1;
let basename = obj.key.get_path().file_name().unwrap();
let candidate_generation =
match parse_remote_index_path(RemotePath::from_string(basename).unwrap()) {
Some(g) => g,
@@ -328,7 +316,7 @@ async fn maybe_delete_index(
return;
}
if !is_old_enough(s3_client, bucket_config, min_age, key, summary).await {
if !is_old_enough(min_age, obj, summary) {
return;
}
@@ -338,11 +326,8 @@ async fn maybe_delete_index(
}
// All validations passed: erase the object
match s3_client
.delete_object()
.bucket(&bucket_config.bucket)
.key(key)
.send()
match remote_client
.delete(&obj.key, &CancellationToken::new())
.await
{
Ok(_) => {
@@ -358,8 +343,7 @@ async fn maybe_delete_index(
#[allow(clippy::too_many_arguments)]
async fn gc_ancestor(
s3_client: &Client,
bucket_config: &BucketConfig,
remote_client: &GenericRemoteStorage,
root_target: &RootTarget,
min_age: &Duration,
ancestor: TenantShardId,
@@ -368,7 +352,7 @@ async fn gc_ancestor(
summary: &mut GcSummary,
) -> anyhow::Result<()> {
// Scan timelines in the ancestor
let timelines = stream_tenant_timelines(s3_client, root_target, ancestor).await?;
let timelines = stream_tenant_timelines(remote_client, root_target, ancestor).await?;
let mut timelines = std::pin::pin!(timelines);
// Build a list of keys to retain
@@ -376,7 +360,7 @@ async fn gc_ancestor(
while let Some(ttid) = timelines.next().await {
let ttid = ttid?;
let data = list_timeline_blobs(s3_client, ttid, root_target).await?;
let data = list_timeline_blobs(remote_client, ttid, root_target).await?;
let s3_layers = match data.blob_data {
BlobDataParseResult::Parsed {
@@ -427,7 +411,8 @@ async fn gc_ancestor(
// We apply a time threshold to GCing objects that are un-referenced: this preserves our ability
// to roll back a shard split if we have to, by avoiding deleting ancestor layers right away
if !is_old_enough(s3_client, bucket_config, min_age, &key, summary).await {
let path = RemotePath::from_string(key.strip_prefix("/").unwrap_or(&key)).unwrap();
if check_is_old_enough(remote_client, &path, min_age, summary).await != Some(true) {
continue;
}
@@ -437,13 +422,7 @@ async fn gc_ancestor(
}
// All validations passed: erase the object
match s3_client
.delete_object()
.bucket(&bucket_config.bucket)
.key(&key)
.send()
.await
{
match remote_client.delete(&path, &CancellationToken::new()).await {
Ok(_) => {
tracing::info!("Successfully deleted unreferenced ancestor layer {key}");
summary.ancestor_layers_deleted += 1;
@@ -477,10 +456,10 @@ pub async fn pageserver_physical_gc(
min_age: Duration,
mode: GcMode,
) -> anyhow::Result<GcSummary> {
let (s3_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;
let (remote_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;
let tenants = if tenant_shard_ids.is_empty() {
futures::future::Either::Left(stream_tenants(&s3_client, &target))
futures::future::Either::Left(stream_tenants(&remote_client, &target))
} else {
futures::future::Either::Right(futures::stream::iter(tenant_shard_ids.into_iter().map(Ok)))
};
@@ -493,14 +472,13 @@ pub async fn pageserver_physical_gc(
let accumulator = Arc::new(std::sync::Mutex::new(TenantRefAccumulator::default()));
// Generate a stream of TenantTimelineId
let timelines = tenants.map_ok(|t| stream_tenant_timelines(&s3_client, &target, t));
let timelines = tenants.map_ok(|t| stream_tenant_timelines(&remote_client, &target, t));
let timelines = timelines.try_buffered(CONCURRENCY);
let timelines = timelines.try_flatten();
// Generate a stream of S3TimelineBlobData
async fn gc_timeline(
s3_client: &Client,
bucket_config: &BucketConfig,
remote_client: &GenericRemoteStorage,
min_age: &Duration,
target: &RootTarget,
mode: GcMode,
@@ -508,7 +486,7 @@ pub async fn pageserver_physical_gc(
accumulator: &Arc<std::sync::Mutex<TenantRefAccumulator>>,
) -> anyhow::Result<GcSummary> {
let mut summary = GcSummary::default();
let data = list_timeline_blobs(s3_client, ttid, target).await?;
let data = list_timeline_blobs(remote_client, ttid, target).await?;
let (index_part, latest_gen, candidates) = match &data.blob_data {
BlobDataParseResult::Parsed {
@@ -533,17 +511,9 @@ pub async fn pageserver_physical_gc(
accumulator.lock().unwrap().update(ttid, index_part);
for key in candidates {
maybe_delete_index(
s3_client,
bucket_config,
min_age,
latest_gen,
&key,
mode,
&mut summary,
)
.instrument(info_span!("maybe_delete_index", %ttid, ?latest_gen, key))
.await;
maybe_delete_index(remote_client, min_age, latest_gen, &key, mode, &mut summary)
.instrument(info_span!("maybe_delete_index", %ttid, ?latest_gen, %key.key))
.await;
}
Ok(summary)
@@ -554,15 +524,7 @@ pub async fn pageserver_physical_gc(
// Drain futures for per-shard GC, populating accumulator as a side effect
{
let timelines = timelines.map_ok(|ttid| {
gc_timeline(
&s3_client,
bucket_config,
&min_age,
&target,
mode,
ttid,
&accumulator,
)
gc_timeline(&remote_client, &min_age, &target, mode, ttid, &accumulator)
});
let mut timelines = std::pin::pin!(timelines.try_buffered(CONCURRENCY));
@@ -586,8 +548,7 @@ pub async fn pageserver_physical_gc(
for ancestor_shard in ancestor_shards {
gc_ancestor(
&s3_client,
bucket_config,
&remote_client,
&target,
&min_age,
ancestor_shard,
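
The age gate no longer needs a HEAD request per candidate: listings now return `ListingObject`s carrying `last_modified`, so `is_old_enough` became a synchronous check, with `check_is_old_enough` issuing a `head_object` only for keys that arrive as bare paths. A sketch of the core check, using only the fields this diff relies on (threshold handling and logging are illustrative):

```rust
use std::time::Duration;
use remote_storage::ListingObject;

fn old_enough(min_age: Duration, obj: &ListingObject) -> bool {
    match obj.last_modified.elapsed() {
        Ok(age) => age >= min_age,
        Err(_) => {
            // last_modified lies in the future (clock skew); treat as too new.
            tracing::warn!("bad last_modified time: {:?}", obj.last_modified);
            false
        }
    }
}
```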


@@ -1,16 +1,16 @@
use std::collections::{HashMap, HashSet};
use crate::checks::{
branch_cleanup_and_check_errors, list_timeline_blobs, BlobDataParseResult, S3TimelineBlobData,
TenantObjectListing, TimelineAnalysis,
branch_cleanup_and_check_errors, list_timeline_blobs, BlobDataParseResult,
RemoteTimelineBlobData, TenantObjectListing, TimelineAnalysis,
};
use crate::metadata_stream::{stream_tenant_timelines, stream_tenants};
use crate::{init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId};
use aws_sdk_s3::Client;
use futures_util::{StreamExt, TryStreamExt};
use pageserver::tenant::remote_timeline_client::remote_layer_path;
use pageserver_api::controller_api::MetadataHealthUpdateRequest;
use pageserver_api::shard::TenantShardId;
use remote_storage::GenericRemoteStorage;
use serde::Serialize;
use utils::id::TenantId;
use utils::shard::ShardCount;
@@ -36,7 +36,7 @@ impl MetadataSummary {
Self::default()
}
fn update_data(&mut self, data: &S3TimelineBlobData) {
fn update_data(&mut self, data: &RemoteTimelineBlobData) {
self.timeline_shard_count += 1;
if let BlobDataParseResult::Parsed {
index_part,
@@ -120,10 +120,10 @@ pub async fn scan_pageserver_metadata(
bucket_config: BucketConfig,
tenant_ids: Vec<TenantShardId>,
) -> anyhow::Result<MetadataSummary> {
let (s3_client, target) = init_remote(bucket_config, NodeKind::Pageserver).await?;
let (remote_client, target) = init_remote(bucket_config, NodeKind::Pageserver).await?;
let tenants = if tenant_ids.is_empty() {
futures::future::Either::Left(stream_tenants(&s3_client, &target))
futures::future::Either::Left(stream_tenants(&remote_client, &target))
} else {
futures::future::Either::Right(futures::stream::iter(tenant_ids.into_iter().map(Ok)))
};
@@ -133,20 +133,20 @@ pub async fn scan_pageserver_metadata(
const CONCURRENCY: usize = 32;
// Generate a stream of TenantTimelineId
let timelines = tenants.map_ok(|t| stream_tenant_timelines(&s3_client, &target, t));
let timelines = tenants.map_ok(|t| stream_tenant_timelines(&remote_client, &target, t));
let timelines = timelines.try_buffered(CONCURRENCY);
let timelines = timelines.try_flatten();
// Generate a stream of S3TimelineBlobData
async fn report_on_timeline(
s3_client: &Client,
remote_client: &GenericRemoteStorage,
target: &RootTarget,
ttid: TenantShardTimelineId,
) -> anyhow::Result<(TenantShardTimelineId, S3TimelineBlobData)> {
let data = list_timeline_blobs(s3_client, ttid, target).await?;
) -> anyhow::Result<(TenantShardTimelineId, RemoteTimelineBlobData)> {
let data = list_timeline_blobs(remote_client, ttid, target).await?;
Ok((ttid, data))
}
let timelines = timelines.map_ok(|ttid| report_on_timeline(&s3_client, &target, ttid));
let timelines = timelines.map_ok(|ttid| report_on_timeline(&remote_client, &target, ttid));
let mut timelines = std::pin::pin!(timelines.try_buffered(CONCURRENCY));
// We must gather all the TenantShardTimelineId->S3TimelineBlobData for each tenant, because different
@@ -157,12 +157,11 @@ pub async fn scan_pageserver_metadata(
let mut tenant_timeline_results = Vec::new();
async fn analyze_tenant(
s3_client: &Client,
target: &RootTarget,
remote_client: &GenericRemoteStorage,
tenant_id: TenantId,
summary: &mut MetadataSummary,
mut tenant_objects: TenantObjectListing,
timelines: Vec<(TenantShardTimelineId, S3TimelineBlobData)>,
timelines: Vec<(TenantShardTimelineId, RemoteTimelineBlobData)>,
highest_shard_count: ShardCount,
) {
summary.tenant_count += 1;
@@ -191,8 +190,7 @@ pub async fn scan_pageserver_metadata(
// Apply checks to this timeline shard's metadata, and in the process update `tenant_objects`
// reference counts for layers across the tenant.
let analysis = branch_cleanup_and_check_errors(
s3_client,
target,
remote_client,
&ttid,
&mut tenant_objects,
None,
@@ -273,8 +271,7 @@ pub async fn scan_pageserver_metadata(
let tenant_objects = std::mem::take(&mut tenant_objects);
let timelines = std::mem::take(&mut tenant_timeline_results);
analyze_tenant(
&s3_client,
&target,
&remote_client,
prev_tenant_id,
&mut summary,
tenant_objects,
@@ -311,8 +308,7 @@ pub async fn scan_pageserver_metadata(
if !tenant_timeline_results.is_empty() {
analyze_tenant(
&s3_client,
&target,
&remote_client,
tenant_id.expect("Must be set if results are present"),
&mut summary,
tenant_objects,


@@ -14,9 +14,8 @@ use utils::{
};
use crate::{
cloud_admin_api::CloudAdminApiClient, init_remote_generic,
metadata_stream::stream_listing_generic, BucketConfig, ConsoleConfig, NodeKind, RootTarget,
TenantShardTimelineId,
cloud_admin_api::CloudAdminApiClient, init_remote, metadata_stream::stream_listing,
BucketConfig, ConsoleConfig, NodeKind, RootTarget, TenantShardTimelineId,
};
Generally we should ask safekeepers, but so far we use the default 16MB everywhere.
@@ -107,7 +106,7 @@ pub async fn scan_safekeeper_metadata(
let timelines = client.query(&query, &[]).await?;
info!("loaded {} timelines", timelines.len());
let (remote_client, target) = init_remote_generic(bucket_config, NodeKind::Safekeeper).await?;
let (remote_client, target) = init_remote(bucket_config, NodeKind::Safekeeper).await?;
let console_config = ConsoleConfig::from_env()?;
let cloud_admin_api_client = CloudAdminApiClient::new(console_config);
@@ -188,14 +187,19 @@ async fn check_timeline(
// we need files, so unset it.
timeline_dir_target.delimiter = String::new();
let mut stream = std::pin::pin!(stream_listing_generic(remote_client, &timeline_dir_target));
let prefix_str = &timeline_dir_target
.prefix_in_bucket
.strip_prefix("/")
.unwrap_or(&timeline_dir_target.prefix_in_bucket);
let mut stream = std::pin::pin!(stream_listing(remote_client, &timeline_dir_target));
while let Some(obj) = stream.next().await {
let (key, _obj) = obj?;
let seg_name = key
.get_path()
.as_str()
.strip_prefix(&timeline_dir_target.prefix_in_bucket)
.strip_prefix(prefix_str)
.expect("failed to extract segment name");
expected_segfiles.remove(seg_name);
}
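
The same prefix normalization recurs in several files of this diff: listings return paths without a leading slash, while some targets' `prefix_in_bucket` still carries one, so the prefix is trimmed before stripping it off each key. A small sketch of that logic in isolation (the function itself is illustrative, not part of the branch):

```rust
// Strip a possibly slash-prefixed bucket prefix from a listed key path.
fn blob_name<'a>(prefix_in_bucket: &str, key_path: &'a str) -> Option<&'a str> {
    let prefix = prefix_in_bucket
        .strip_prefix('/')
        .unwrap_or(prefix_in_bucket);
    key_path.strip_prefix(prefix)
}
```

For example, `blob_name("/tenants/", "tenants/abc/")` yields `Some("abc/")`.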


@@ -1,10 +1,11 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::checks::{list_timeline_blobs, BlobDataParseResult, S3TimelineBlobData};
use crate::checks::{list_timeline_blobs, BlobDataParseResult, RemoteTimelineBlobData};
use crate::metadata_stream::{stream_tenant_shards, stream_tenant_timelines};
use crate::{
download_object_to_file, init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId,
download_object_to_file_s3, init_remote, init_remote_s3, BucketConfig, NodeKind, RootTarget,
TenantShardTimelineId,
};
use anyhow::Context;
use async_stream::stream;
@@ -15,6 +16,7 @@ use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata;
use pageserver::tenant::storage_layer::LayerName;
use pageserver::tenant::IndexPart;
use pageserver_api::shard::TenantShardId;
use remote_storage::GenericRemoteStorage;
use utils::generation::Generation;
use utils::id::TenantId;
@@ -34,7 +36,8 @@ impl SnapshotDownloader {
output_path: Utf8PathBuf,
concurrency: usize,
) -> anyhow::Result<Self> {
let (s3_client, s3_root) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;
let (s3_client, s3_root) =
init_remote_s3(bucket_config.clone(), NodeKind::Pageserver).await?;
Ok(Self {
s3_client,
s3_root,
@@ -91,7 +94,7 @@ impl SnapshotDownloader {
let Some(version) = versions.versions.as_ref().and_then(|v| v.first()) else {
return Err(anyhow::anyhow!("No versions found for {remote_layer_path}"));
};
download_object_to_file(
download_object_to_file_s3(
&self.s3_client,
&self.bucket_config.bucket,
&remote_layer_path,
@@ -215,11 +218,11 @@ impl SnapshotDownloader {
}
pub async fn download(&self) -> anyhow::Result<()> {
let (s3_client, target) =
let (remote_client, target) =
init_remote(self.bucket_config.clone(), NodeKind::Pageserver).await?;
// Generate a stream of TenantShardId
let shards = stream_tenant_shards(&s3_client, &target, self.tenant_id).await?;
let shards = stream_tenant_shards(&remote_client, &target, self.tenant_id).await?;
let shards: Vec<TenantShardId> = shards.try_collect().await?;
// Only read from shards that have the highest count: avoids redundantly downloading
@@ -237,18 +240,19 @@ impl SnapshotDownloader {
for shard in shards.into_iter().filter(|s| s.shard_count == shard_count) {
// Generate a stream of TenantTimelineId
let timelines = stream_tenant_timelines(&s3_client, &self.s3_root, shard).await?;
let timelines = stream_tenant_timelines(&remote_client, &target, shard).await?;
// Generate a stream of S3TimelineBlobData
async fn load_timeline_index(
s3_client: &Client,
remote_client: &GenericRemoteStorage,
target: &RootTarget,
ttid: TenantShardTimelineId,
) -> anyhow::Result<(TenantShardTimelineId, S3TimelineBlobData)> {
let data = list_timeline_blobs(s3_client, ttid, target).await?;
) -> anyhow::Result<(TenantShardTimelineId, RemoteTimelineBlobData)> {
let data = list_timeline_blobs(remote_client, ttid, target).await?;
Ok((ttid, data))
}
let timelines = timelines.map_ok(|ttid| load_timeline_index(&s3_client, &target, ttid));
let timelines =
timelines.map_ok(|ttid| load_timeline_index(&remote_client, &target, ttid));
let mut timelines = std::pin::pin!(timelines.try_buffered(8));
while let Some(i) = timelines.next().await {
@@ -278,7 +282,7 @@ impl SnapshotDownloader {
for (ttid, layers) in ancestor_layers.into_iter() {
tracing::info!(
"Downloading {} layers from ancvestor timeline {ttid}...",
"Downloading {} layers from ancestor timeline {ttid}...",
layers.len()
);
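
The snapshot downloader is the one caller that keeps a raw `aws_sdk_s3::Client` next to the generic one: object-version listing and `download_object_to_file_s3` still go through the SDK, presumably because the generic API does not cover versioned GETs. A sketch of constructing both, assuming the renamed crate-internal `init_remote` / `init_remote_s3` and the crate's `BucketConfig`, `NodeKind`, and `RootTarget` types (the helper itself is hypothetical):

```rust
use std::sync::Arc;
use remote_storage::GenericRemoteStorage;

// Hypothetical helper: build both clients the downloader holds.
async fn make_clients(
    bucket_config: BucketConfig,
) -> anyhow::Result<(GenericRemoteStorage, Arc<aws_sdk_s3::Client>, RootTarget)> {
    let (remote_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;
    let (s3_client, _s3_root) = init_remote_s3(bucket_config, NodeKind::Pageserver).await?;
    Ok((remote_client, s3_client, target))
}
```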


@@ -71,8 +71,7 @@ a subdirectory for each version with naming convention `v{PG_VERSION}/`.
Inside that dir, a `bin/postgres` binary should be present.
`DEFAULT_PG_VERSION`: The version of Postgres to use,
This is used to construct full path to the postgres binaries.
Format is 2-digit major version nubmer, i.e. `DEFAULT_PG_VERSION="14"`. Alternatively,
you can use `--pg-version` argument.
Format is 2-digit major version nubmer, i.e. `DEFAULT_PG_VERSION=16`
`TEST_OUTPUT`: Set the directory where test state and test output files
should go.
`TEST_SHARED_FIXTURES`: Try to re-use a single pageserver for all the tests.


@@ -4643,6 +4643,7 @@ class StorageScrubber:
]
args = base_args + args
log.info(f"Invoking scrubber command {args} with env: {env}")
(output_path, stdout, status_code) = subprocess_capture(
self.log_dir,
args,

Some files were not shown because too many files have changed in this diff.