Compare commits

..

8 Commits

Author SHA1 Message Date
Conrad Ludgate 2cca1b3e4e fix 2024-08-21 18:44:57 +01:00
Conrad Ludgate 471b3b300d fix pin 2024-08-21 16:29:52 +01:00
Conrad Ludgate fbd4b91169 asyncreadready 2024-08-21 16:16:49 +01:00
Conrad Ludgate 8cc45ad9bd asrawfd things 2024-08-21 15:28:25 +01:00
Conrad Ludgate aabbd55187 add ktls handling 2024-08-21 14:42:41 +01:00
Conrad Ludgate 987a859352 start integrating ktls 2024-08-21 14:11:58 +01:00
Conrad Ludgate e171fd805b add ktls dep 2024-08-21 13:51:02 +01:00
Conrad Ludgate 1e4702b26a update rustls 2024-08-21 13:47:19 +01:00
142 changed files with 1983 additions and 3854 deletions

View File

@@ -23,30 +23,10 @@ platforms = [
]
[final-excludes]
workspace-members = [
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
# it is built primarily in the separate repo neondatabase/autoscaling and thus is excluded
# from depending on workspace-hack because most of the dependencies are not used.
"vm_monitor",
# All of these exist in libs and are not usually built independently.
# Putting workspace hack there adds a bottleneck for cargo builds.
"compute_api",
"consumption_metrics",
"desim",
"metrics",
"pageserver_api",
"postgres_backend",
"postgres_connection",
"postgres_ffi",
"pq_proto",
"remote_storage",
"safekeeper_api",
"tenant_size_model",
"tracing-utils",
"utils",
"wal_craft",
"walproposer",
]
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
# it is built primarily in the separate repo neondatabase/autoscaling and thus is excluded
# from depending on workspace-hack because most of the dependencies are not used.
workspace-members = ["vm_monitor"]
# Write out exact versions rather than a semver range. (Defaults to false.)
# exact-versions = true

View File

@@ -169,8 +169,10 @@ runs:
EXTRA_PARAMS="--durations-path $TEST_OUTPUT/benchmark_durations.json $EXTRA_PARAMS"
fi
if [[ $BUILD_TYPE == "debug" && $RUNNER_ARCH == 'X64' ]]; then
if [[ "${{ inputs.build_type }}" == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
elif [[ "${{ inputs.build_type }}" == "release" ]]; then
cov_prefix=()
else
cov_prefix=()
fi

View File

@@ -94,16 +94,11 @@ jobs:
# We run tests with additional features that are turned off by default (e.g. in release builds), see
# corresponding Cargo.toml files for their descriptions.
- name: Set env variables
env:
ARCH: ${{ inputs.arch }}
run: |
CARGO_FEATURES="--features testing"
if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run"
CARGO_FLAGS="--locked"
elif [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix=""
CARGO_FLAGS="--locked"
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=""
CARGO_FLAGS="--locked --release"
@@ -163,8 +158,6 @@ jobs:
# Do install *before* running rust tests because they might recompile the
# binaries with different features/flags.
- name: Install rust binaries
env:
ARCH: ${{ inputs.arch }}
run: |
# Install target binaries
mkdir -p /tmp/neon/bin/
@@ -179,7 +172,7 @@ jobs:
done
# Install test executables and write list of all binaries (for code coverage)
if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then
if [[ $BUILD_TYPE == "debug" ]]; then
# Keep bloated coverage data files away from the rest of the artifact
mkdir -p /tmp/coverage/
@@ -250,8 +243,8 @@ jobs:
uses: ./.github/actions/save-coverage-data
regress-tests:
# Don't run regression tests on debug arm64 builds
if: inputs.build-type != 'debug' || inputs.arch != 'arm64'
# Run test on x64 only
if: inputs.arch == 'x64'
needs: [ build-neon ]
runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', inputs.arch == 'arm64' && 'large-arm64' || 'large')) }}
container:

View File

@@ -198,7 +198,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ x64, arm64 ]
arch: [ x64 ]
# Do not build or run tests in debug for release branches
build-type: ${{ fromJson((startsWith(github.ref_name, 'release') && github.event_name == 'push') && '["release"]' || '["debug", "release"]') }}
include:

354
Cargo.lock (generated)
View File

@@ -316,6 +316,33 @@ dependencies = [
"zeroize",
]
[[package]]
name = "aws-lc-rs"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ae74d9bd0a7530e8afd1770739ad34b36838829d6ad61818f9230f683f5ad77"
dependencies = [
"aws-lc-sys",
"mirai-annotations",
"paste",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f0e249228c6ad2d240c2dc94b714d711629d52bad946075d8e9b2f5391f0703"
dependencies = [
"bindgen 0.69.4",
"cc",
"cmake",
"dunce",
"fs_extra",
"libc",
"paste",
]
[[package]]
name = "aws-runtime"
version = "1.2.1"
@@ -926,7 +953,30 @@ dependencies = [
"lazycell",
"log",
"peeking_take_while",
"prettyplease 0.2.6",
"prettyplease 0.2.17",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 2.0.52",
"which",
]
[[package]]
name = "bindgen"
version = "0.69.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0"
dependencies = [
"bitflags 2.4.1",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease 0.2.17",
"proc-macro2",
"quote",
"regex",
@@ -1056,6 +1106,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "cgroups-rs"
version = "0.3.3"
@@ -1164,6 +1220,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"
[[package]]
name = "cmake"
version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.0"
@@ -1208,6 +1273,7 @@ dependencies = [
"serde_json",
"serde_with",
"utils",
"workspace_hack",
]
[[package]]
@@ -1320,6 +1386,7 @@ dependencies = [
"serde",
"serde_with",
"utils",
"workspace_hack",
]
[[package]]
@@ -1490,7 +1557,7 @@ dependencies = [
"bitflags 1.3.2",
"crossterm_winapi",
"libc",
"mio",
"mio 0.8.11",
"parking_lot 0.12.1",
"signal-hook",
"signal-hook-mio",
@@ -1668,13 +1735,14 @@ dependencies = [
"smallvec",
"tracing",
"utils",
"workspace_hack",
]
[[package]]
name = "diesel"
version = "2.2.3"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65e13bab2796f412722112327f3e575601a3e9cdcbe426f0d30dbf43f3f5dc71"
checksum = "62d6dcd069e7b5fe49a302411f759d4cf1cf2c27fe798ef46fb8baefc053dd2b"
dependencies = [
"bitflags 2.4.1",
"byteorder",
@@ -1765,6 +1833,12 @@ dependencies = [
"syn 2.0.52",
]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "dyn-clone"
version = "1.0.14"
@@ -2066,6 +2140,12 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsevent-sys"
version = "4.1.0"
@@ -2399,9 +2479,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.3.3"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hex"
@@ -2919,6 +2999,33 @@ dependencies = [
"libc",
]
[[package]]
name = "ktls"
version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebe51e4a53d53b396707537bc8a5277798b720fb71f0d1b9c63eb53199a00fde"
dependencies = [
"futures-util",
"ktls-sys",
"libc",
"memoffset 0.9.1",
"nix 0.29.0",
"num_enum",
"pin-project-lite",
"rustls 0.23.12",
"smallvec",
"thiserror",
"tokio",
"tokio-rustls 0.26.0",
"tracing",
]
[[package]]
name = "ktls-sys"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "095b1fc8d841c3df8c3f2db78b7425cb2ec424568a282cb589a880b99d256e84"
[[package]]
name = "lasso"
version = "0.7.2"
@@ -2957,9 +3064,9 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.150"
version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c"
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]]
name = "libloading"
@@ -3123,9 +3230,9 @@ dependencies = [
[[package]]
name = "memoffset"
version = "0.9.0"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
@@ -3144,6 +3251,7 @@ dependencies = [
"rand 0.8.5",
"rand_distr",
"twox-hash",
"workspace_hack",
]
[[package]]
@@ -3200,6 +3308,24 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "mio"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec"
dependencies = [
"hermit-abi",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
[[package]]
name = "mirai-annotations"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1"
[[package]]
name = "multimap"
version = "0.8.3"
@@ -3240,7 +3366,20 @@ dependencies = [
"bitflags 2.4.1",
"cfg-if",
"libc",
"memoffset 0.9.0",
"memoffset 0.9.1",
]
[[package]]
name = "nix"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags 2.4.1",
"cfg-if",
"cfg_aliases",
"libc",
"memoffset 0.9.1",
]
[[package]]
@@ -3267,7 +3406,7 @@ dependencies = [
"kqueue",
"libc",
"log",
"mio",
"mio 0.8.11",
"walkdir",
"windows-sys 0.48.0",
]
@@ -3389,6 +3528,27 @@ dependencies = [
"libc",
]
[[package]]
name = "num_enum"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179"
dependencies = [
"num_enum_derive",
]
[[package]]
name = "num_enum_derive"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "oauth2"
version = "4.4.2"
@@ -3787,6 +3947,7 @@ dependencies = [
"strum_macros",
"thiserror",
"utils",
"workspace_hack",
]
[[package]]
@@ -4051,9 +4212,9 @@ dependencies = [
[[package]]
name = "pin-project-lite"
version = "0.2.13"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "pin-utils"
@@ -4178,16 +4339,17 @@ dependencies = [
"futures",
"once_cell",
"pq_proto",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-pemfile 2.1.1",
"serde",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-util",
"tracing",
"workspace_hack",
]
[[package]]
@@ -4200,6 +4362,7 @@ dependencies = [
"postgres",
"tokio-postgres",
"url",
"workspace_hack",
]
[[package]]
@@ -4207,7 +4370,7 @@ name = "postgres_ffi"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.65.1",
"byteorder",
"bytes",
"crc32c",
@@ -4222,6 +4385,7 @@ dependencies = [
"serde",
"thiserror",
"utils",
"workspace_hack",
]
[[package]]
@@ -4259,6 +4423,7 @@ dependencies = [
"thiserror",
"tokio",
"tracing",
"workspace_hack",
]
[[package]]
@@ -4273,9 +4438,9 @@ dependencies = [
[[package]]
name = "prettyplease"
version = "0.2.6"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b69d39aab54d069e7f2fe8cb970493e7834601ca2d8c65fd7bbd183578080d1"
checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7"
dependencies = [
"proc-macro2",
"syn 2.0.52",
@@ -4290,6 +4455,15 @@ dependencies = [
"elliptic-curve 0.13.8",
]
[[package]]
name = "proc-macro-crate"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284"
dependencies = [
"toml_edit 0.21.1",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.20+deprecated"
@@ -4448,6 +4622,7 @@ dependencies = [
"itertools 0.10.5",
"jose-jwa",
"jose-jwk",
"ktls",
"lasso",
"md5",
"measured",
@@ -4478,7 +4653,7 @@ dependencies = [
"rsa",
"rstest",
"rustc-hash",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-native-certs 0.7.0",
"rustls-pemfile 2.1.1",
"scopeguard",
@@ -4497,7 +4672,7 @@ dependencies = [
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-tungstenite",
"tokio-util",
"tower-service",
@@ -4663,12 +4838,13 @@ dependencies = [
[[package]]
name = "rcgen"
version = "0.12.1"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48406db8ac1f3cbc7dcdb56ec355343817958a356ff430259bb07baf7607e1e1"
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
dependencies = [
"pem",
"ring 0.17.6",
"rustls-pki-types",
"time",
"yasna",
]
@@ -4823,6 +4999,7 @@ dependencies = [
"toml_edit 0.19.10",
"tracing",
"utils",
"workspace_hack",
]
[[package]]
@@ -5180,7 +5357,22 @@ dependencies = [
"log",
"ring 0.17.6",
"rustls-pki-types",
"rustls-webpki 0.102.2",
"rustls-webpki 0.102.6",
"subtle",
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
dependencies = [
"aws-lc-rs",
"log",
"once_cell",
"rustls-pki-types",
"rustls-webpki 0.102.6",
"subtle",
"zeroize",
]
@@ -5231,9 +5423,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
version = "1.3.1"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8"
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
[[package]]
name = "rustls-webpki"
@@ -5257,10 +5449,11 @@ dependencies = [
[[package]]
name = "rustls-webpki"
version = "0.102.2"
version = "0.102.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e"
dependencies = [
"aws-lc-rs",
"ring 0.17.6",
"rustls-pki-types",
"untrusted 0.9.0",
@@ -5347,6 +5540,7 @@ dependencies = [
"serde",
"serde_with",
"utils",
"workspace_hack",
]
[[package]]
@@ -5590,12 +5784,11 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.125"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
@@ -5701,9 +5894,9 @@ dependencies = [
[[package]]
name = "sha2-asm"
version = "0.6.3"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27ba7066011e3fb30d808b51affff34f0a66d3a03a58edd787c6e420e40e44e"
checksum = "b845214d6175804686b2bd482bcffe96651bb2d1200742b712003504a2dac1ab"
dependencies = [
"cc",
]
@@ -5740,7 +5933,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af"
dependencies = [
"libc",
"mio",
"mio 0.8.11",
"signal-hook",
]
@@ -5802,9 +5995,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.13.1"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "smol_str"
@@ -5996,7 +6189,7 @@ dependencies = [
"rand 0.8.5",
"remote_storage",
"reqwest 0.12.4",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-native-certs 0.7.0",
"serde",
"serde_json",
@@ -6006,7 +6199,7 @@ dependencies = [
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-util",
"tracing",
@@ -6183,6 +6376,7 @@ dependencies = [
"anyhow",
"serde",
"serde_json",
"workspace_hack",
]
[[package]]
@@ -6217,18 +6411,18 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.57"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.57"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [
"proc-macro2",
"quote",
@@ -6355,20 +6549,19 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.37.0"
version = "1.39.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5"
dependencies = [
"backtrace",
"bytes",
"libc",
"mio",
"num_cpus",
"mio 1.0.2",
"pin-project-lite",
"signal-hook-registry",
"socket2 0.5.5",
"tokio-macros",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -6399,9 +6592,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.2.0"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
@@ -6433,16 +6626,15 @@ dependencies = [
[[package]]
name = "tokio-postgres-rustls"
version = "0.11.1"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ea13f22eda7127c827983bdaf0d7fff9df21c8817bab02815ac277a21143677"
checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab"
dependencies = [
"futures",
"ring 0.17.6",
"rustls 0.22.4",
"rustls 0.23.12",
"tokio",
"tokio-postgres",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"x509-certificate",
]
@@ -6467,6 +6659,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
"rustls 0.23.12",
"rustls-pki-types",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.14"
@@ -6568,6 +6771,17 @@ dependencies = [
"winnow 0.4.6",
]
[[package]]
name = "toml_edit"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1"
dependencies = [
"indexmap 2.0.1",
"toml_datetime",
"winnow 0.5.40",
]
[[package]]
name = "toml_edit"
version = "0.22.14"
@@ -6660,11 +6874,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tracing"
version = "0.1.37"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"cfg-if",
"log",
"pin-project-lite",
"tracing-attributes",
@@ -6684,9 +6897,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
version = "0.1.24"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
@@ -6695,9 +6908,9 @@ dependencies = [
[[package]]
name = "tracing-core"
version = "0.1.31"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
@@ -6783,6 +6996,7 @@ dependencies = [
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"workspace_hack",
]
[[package]]
@@ -7000,6 +7214,7 @@ dependencies = [
"url",
"uuid",
"walkdir",
"workspace_hack",
]
[[package]]
@@ -7078,6 +7293,7 @@ dependencies = [
"postgres_ffi",
"regex",
"utils",
"workspace_hack",
]
[[package]]
@@ -7095,9 +7311,10 @@ name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.65.1",
"postgres_ffi",
"utils",
"workspace_hack",
]
[[package]]
@@ -7548,6 +7765,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "winnow"
version = "0.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876"
dependencies = [
"memchr",
]
[[package]]
name = "winnow"
version = "0.6.13"
@@ -7637,6 +7863,8 @@ dependencies = [
"reqwest 0.11.19",
"reqwest 0.12.4",
"rustls 0.21.11",
"rustls-pki-types",
"rustls-webpki 0.102.6",
"scopeguard",
"serde",
"serde_json",
@@ -7654,6 +7882,8 @@ dependencies = [
"tokio",
"tokio-rustls 0.24.0",
"tokio-util",
"toml_datetime",
"toml_edit 0.19.10",
"tonic",
"tower",
"tracing",

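A note on the aws-lc-rs, aws-lc-sys, bindgen 0.69, cmake, dunce, fs_extra and mirai-annotations entries that appear in this lock file: rustls 0.23 uses the aws-lc-rs crypto provider by default, which is what drags them in. If the extra C/CMake build dependency is unwanted, a different provider can be installed at process start; a minimal sketch, assuming rustls's "ring" feature is enabled (nothing in these hunks shows which features the workspace actually selects):

// Minimal sketch, not code from this PR: installing ring as the process-default
// crypto provider instead of the aws-lc-rs default. Assumes the "ring" feature
// of rustls 0.23 is enabled.
fn install_ring_provider() {
    // ServerConfig::builder()/ClientConfig::builder() then use ring process-wide;
    // install_default() only fails if some provider was already installed.
    if rustls::crypto::ring::default_provider()
        .install_default()
        .is_err()
    {
        eprintln!("a process-level crypto provider was already installed");
    }
}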
View File

@@ -113,7 +113,7 @@ md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
memoffset = "0.8"
nix = { version = "0.27", features = ["dir", "fs", "process", "socket", "signal", "poll"] }
nix = { version = "0.27", features = ["fs", "process", "socket", "signal", "poll"] }
notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
@@ -139,7 +139,7 @@ reqwest-retry = "0.5"
routerify = "3"
rpds = "0.13"
rustc-hash = "1.1.0"
rustls = "0.22"
rustls = "0.23"
rustls-pemfile = "2"
rustls-split = "0.3"
scopeguard = "1.1"
@@ -171,8 +171,8 @@ tikv-jemalloc-ctl = "0.5"
tokio = { version = "1.17", features = ["macros"] }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.11.0"
tokio-rustls = "0.25"
tokio-postgres-rustls = "0.12.0"
tokio-rustls = "0.26"
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
@@ -232,7 +232,7 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
criterion = "0.5.1"
rcgen = "0.12"
rcgen = "0.13"
rstest = "0.18"
camino-tempfile = "1.0.2"
tonic-build = "0.9"
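For orientation, the rustls 0.23 / tokio-rustls 0.26 pair bumped here keeps the same acceptor-building shape that the crates touched below (proxy, postgres_backend, safekeeper) rely on. A minimal sketch, assuming already-parsed DER certificates and key rather than any loader used in this repo:

// Minimal sketch, not code from this PR: building a TLS acceptor against the
// bumped rustls 0.23 / tokio-rustls 0.26 APIs. `certs` and `key` are assumed to
// be DER values parsed elsewhere (e.g. with rustls-pemfile, also a dependency here).
use std::sync::Arc;

use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_rustls::TlsAcceptor;

fn make_acceptor(
    certs: Vec<CertificateDer<'static>>,
    key: PrivateKeyDer<'static>,
) -> Result<TlsAcceptor, rustls::Error> {
    // In rustls 0.23, builder() uses the process-default crypto provider
    // (aws-lc-rs unless another provider was installed, see the lock-file note above).
    let config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)?;
    Ok(TlsAcceptor::from(Arc::new(config)))
}

The kernel-TLS work in the "add ktls handling" / "start integrating ktls" commits starts from a session accepted this way; the specifics of that hand-off are not shown in these hunks.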

View File

@@ -126,7 +126,7 @@ make -j`sysctl -n hw.logicalcpu` -s
To run the `psql` client, install the `postgresql-client` package or modify `PATH` and `LD_LIBRARY_PATH` to include `pg_install/bin` and `pg_install/lib`, respectively.
To run the integration tests or Python scripts (not required to use the code), install
Python (3.9 or higher), and install the python3 packages using `./scripts/pysync` (requires [poetry>=1.8](https://python-poetry.org/)) in the project directory.
Python (3.9 or higher), and install the python3 packages using `./scripts/pysync` (requires [poetry>=1.3](https://python-poetry.org/)) in the project directory.
#### Running neon database

View File

@@ -147,9 +147,9 @@ enum Command {
#[arg(long)]
threshold: humantime::Duration,
},
// Migrate away from a set of specified pageservers by moving the primary attachments to pageservers
// Drain a set of specified pageservers by moving the primary attachments to pageservers
// outside of the specified set.
BulkMigrate {
Drain {
// Set of pageserver node ids to drain.
#[arg(long)]
nodes: Vec<NodeId>,
@@ -163,34 +163,6 @@ enum Command {
#[arg(long)]
dry_run: Option<bool>,
},
/// Start draining the specified pageserver.
/// The drain is complete when the scheduling policy returns to active.
StartDrain {
#[arg(long)]
node_id: NodeId,
},
/// Cancel draining the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
CancelDrain {
#[arg(long)]
node_id: NodeId,
#[arg(long)]
timeout: humantime::Duration,
},
/// Start filling the specified pageserver.
/// The drain is complete when the scheduling policy returns to active.
StartFill {
#[arg(long)]
node_id: NodeId,
},
/// Cancel filling the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
CancelFill {
#[arg(long)]
node_id: NodeId,
#[arg(long)]
timeout: humantime::Duration,
},
}
#[derive(Parser)]
@@ -277,34 +249,6 @@ impl FromStr for NodeAvailabilityArg {
}
}
async fn wait_for_scheduling_policy<F>(
client: Client,
node_id: NodeId,
timeout: Duration,
f: F,
) -> anyhow::Result<NodeSchedulingPolicy>
where
F: Fn(NodeSchedulingPolicy) -> bool,
{
let waiter = tokio::time::timeout(timeout, async move {
loop {
let node = client
.dispatch::<(), NodeDescribeResponse>(
Method::GET,
format!("control/v1/node/{node_id}"),
None,
)
.await?;
if f(node.scheduling) {
return Ok::<NodeSchedulingPolicy, mgmt_api::Error>(node.scheduling);
}
}
});
Ok(waiter.await??)
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
@@ -684,7 +628,7 @@ async fn main() -> anyhow::Result<()> {
})
.await?;
}
Command::BulkMigrate {
Command::Drain {
nodes,
concurrency,
max_shards,
@@ -713,7 +657,7 @@ async fn main() -> anyhow::Result<()> {
}
if nodes.len() != node_to_drain_descs.len() {
anyhow::bail!("Bulk migration requested away from node which doesn't exist.")
anyhow::bail!("Drain requested for node which doesn't exist.")
}
node_to_fill_descs.retain(|desc| {
@@ -725,7 +669,7 @@ async fn main() -> anyhow::Result<()> {
});
if node_to_fill_descs.is_empty() {
anyhow::bail!("There are no nodes to migrate to")
anyhow::bail!("There are no nodes to drain to")
}
// Set the node scheduling policy to draining for the nodes which
@@ -746,7 +690,7 @@ async fn main() -> anyhow::Result<()> {
.await?;
}
// Perform the migration: move each tenant shard scheduled on a node to
// Perform the drain: move each tenant shard scheduled on a node to
// be drained to a node which is being filled. A simple round robin
// strategy is used to pick the new node.
let tenants = storcon_client
@@ -759,13 +703,13 @@ async fn main() -> anyhow::Result<()> {
let mut selected_node_idx = 0;
struct MigrationMove {
struct DrainMove {
tenant_shard_id: TenantShardId,
from: NodeId,
to: NodeId,
}
let mut moves: Vec<MigrationMove> = Vec::new();
let mut moves: Vec<DrainMove> = Vec::new();
let shards = tenants
.into_iter()
@@ -795,7 +739,7 @@ async fn main() -> anyhow::Result<()> {
continue;
}
moves.push(MigrationMove {
moves.push(DrainMove {
tenant_shard_id: shard.tenant_shard_id,
from: shard
.node_attached
@@ -872,67 +816,6 @@ async fn main() -> anyhow::Result<()> {
failure
);
}
Command::StartDrain { node_id } => {
storcon_client
.dispatch::<(), ()>(
Method::PUT,
format!("control/v1/node/{node_id}/drain"),
None,
)
.await?;
println!("Drain started for {node_id}");
}
Command::CancelDrain { node_id, timeout } => {
storcon_client
.dispatch::<(), ()>(
Method::DELETE,
format!("control/v1/node/{node_id}/drain"),
None,
)
.await?;
println!("Waiting for node {node_id} to quiesce on scheduling policy ...");
let final_policy =
wait_for_scheduling_policy(storcon_client, node_id, *timeout, |sched| {
use NodeSchedulingPolicy::*;
matches!(sched, Active | PauseForRestart)
})
.await?;
println!(
"Drain was cancelled for node {node_id}. Schedulling policy is now {final_policy:?}"
);
}
Command::StartFill { node_id } => {
storcon_client
.dispatch::<(), ()>(Method::PUT, format!("control/v1/node/{node_id}/fill"), None)
.await?;
println!("Fill started for {node_id}");
}
Command::CancelFill { node_id, timeout } => {
storcon_client
.dispatch::<(), ()>(
Method::DELETE,
format!("control/v1/node/{node_id}/fill"),
None,
)
.await?;
println!("Waiting for node {node_id} to quiesce on scheduling policy ...");
let final_policy =
wait_for_scheduling_policy(storcon_client, node_id, *timeout, |sched| {
use NodeSchedulingPolicy::*;
matches!(sched, Active)
})
.await?;
println!(
"Fill was cancelled for node {node_id}. Schedulling policy is now {final_policy:?}"
);
}
}
Ok(())
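The drain path above picks destination nodes round-robin over the fill candidates (`selected_node_idx`). A stripped-down sketch of that selection, with plain integers standing in for the real TenantShardId/NodeId types:

// Stripped-down sketch, not the storcon_cli code: round-robin assignment of
// drained shards onto fill candidates. u32/u64 stand in for TenantShardId/NodeId.
struct DrainMove {
    tenant_shard: u32,
    from: u64,
    to: u64,
}

fn plan_moves(shards: &[(u32, u64)], fill_nodes: &[u64]) -> Vec<DrainMove> {
    // Assumes fill_nodes is non-empty (the CLI bails out earlier otherwise).
    let mut selected_node_idx = 0usize;
    shards
        .iter()
        .map(|&(tenant_shard, from)| {
            // Cycle through the fill candidates so drained shards spread evenly.
            let to = fill_nodes[selected_node_idx % fill_nodes.len()];
            selected_node_idx += 1;
            DrainMove { tenant_shard, from, to }
        })
        .collect()
}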

View File

@@ -21,21 +21,30 @@ _Example: 15.4 is the new minor version to upgrade to from 15.3._
1. Create a new branch based on the stable branch you are updating.
```shell
git checkout -b my-branch-15 REL_15_STABLE_neon
git checkout -b my-branch REL_15_STABLE_neon
```
1. Find the upstream release tags you're looking for. They are of the form `REL_X_Y`.
1. Tag the last commit on the stable branch you are updating.
1. Merge the upstream tag into the branch you created on the tag and resolve any conflicts.
```shell
git tag REL_15_3_neon
```
1. Push the new tag to the Neon Postgres repository.
```shell
git push origin REL_15_3_neon
```
1. Find the release tags you're looking for. They are of the form `REL_X_Y`.
1. Rebase the branch you created on the tag and resolve any conflicts.
```shell
git fetch upstream REL_15_4
git merge REL_15_4
git rebase REL_15_4
```
In the commit message of the merge commit, mention if there were
any non-trivial conflicts or other issues.
1. Run the Postgres test suite to make sure our commits have not affected
Postgres in a negative way.
@@ -48,7 +57,7 @@ Postgres in a negative way.
1. Push your branch to the Neon Postgres repository.
```shell
git push origin my-branch-15
git push origin my-branch
```
1. Clone the Neon repository if you have not done so already.
@@ -65,7 +74,7 @@ branch.
1. Update the Git submodule.
```shell
git submodule set-branch --branch my-branch-15 vendor/postgres-v15
git submodule set-branch --branch my-branch vendor/postgres-v15
git submodule update --remote vendor/postgres-v15
```
@@ -80,12 +89,14 @@ minor Postgres release.
1. Create a pull request, and wait for CI to go green.
1. Push the Postgres branches with the merge commits into the Neon Postgres repository.
1. Force push the rebased Postgres branches into the Neon Postgres repository.
```shell
git push origin my-branch-15:REL_15_STABLE_neon
git push --force origin my-branch:REL_15_STABLE_neon
```
It may require disabling various branch protections.
1. Update your Neon PR to point at the branches.
```shell

View File

@@ -14,3 +14,5 @@ regex.workspace = true
utils = { path = "../utils" }
remote_storage = { version = "0.1", path = "../remote_storage/" }
workspace_hack.workspace = true

View File

@@ -6,8 +6,10 @@ license = "Apache-2.0"
[dependencies]
anyhow.workspace = true
chrono = { workspace = true, features = ["serde"] }
chrono.workspace = true
rand.workspace = true
serde.workspace = true
serde_with.workspace = true
utils.workspace = true
workspace_hack.workspace = true

View File

@@ -14,3 +14,5 @@ parking_lot.workspace = true
hex.workspace = true
scopeguard.workspace = true
smallvec = { workspace = true, features = ["write"] }
workspace_hack.workspace = true

View File

@@ -12,6 +12,8 @@ chrono.workspace = true
twox-hash.workspace = true
measured.workspace = true
workspace_hack.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
procfs.workspace = true
measured-process.workspace = true

View File

@@ -21,9 +21,11 @@ hex.workspace = true
humantime.workspace = true
thiserror.workspace = true
humantime-serde.workspace = true
chrono = { workspace = true, features = ["serde"] }
chrono.workspace = true
itertools.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
bincode.workspace = true
rand.workspace = true

View File

@@ -8,7 +8,6 @@ use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use utils::id::{NodeId, TenantId};
use crate::models::PageserverUtilization;
use crate::{
models::{ShardParameters, TenantConfig},
shard::{ShardStripeSize, TenantShardId},
@@ -141,11 +140,23 @@ pub struct TenantShardMigrateRequest {
pub node_id: NodeId,
}
#[derive(Serialize, Clone, Debug)]
/// Utilisation score indicating how good a candidate a pageserver
/// is for scheduling the next tenant. See [`crate::models::PageserverUtilization`].
/// Lower values are better.
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Debug)]
pub struct UtilizationScore(pub u64);
impl UtilizationScore {
pub fn worst() -> Self {
UtilizationScore(u64::MAX)
}
}
#[derive(Serialize, Clone, Copy, Debug)]
#[serde(into = "NodeAvailabilityWrapper")]
pub enum NodeAvailability {
// Normal, happy state
Active(PageserverUtilization),
Active(UtilizationScore),
// Node is warming up, but we expect it to become available soon. Covers
// the time span between the re-attach response being composed on the storage controller
// and the first successful heartbeat after the processing of the re-attach response
@@ -184,9 +195,7 @@ impl From<NodeAvailabilityWrapper> for NodeAvailability {
match val {
// Assume the worst utilisation score to begin with. It will later be updated by
// the heartbeats.
NodeAvailabilityWrapper::Active => {
NodeAvailability::Active(PageserverUtilization::full())
}
NodeAvailabilityWrapper::Active => NodeAvailability::Active(UtilizationScore::worst()),
NodeAvailabilityWrapper::WarmingUp => NodeAvailability::WarmingUp(Instant::now()),
NodeAvailabilityWrapper::Offline => NodeAvailability::Offline,
}

View File

@@ -236,15 +236,6 @@ impl Key {
field5: u8::MAX,
field6: u32::MAX,
};
/// A key slightly smaller than [`Key::MAX`] for use in layer key ranges to avoid them being confused with L0 layers
pub const NON_L0_MAX: Key = Key {
field1: u8::MAX,
field2: u32::MAX,
field3: u32::MAX,
field4: u32::MAX,
field5: u8::MAX,
field6: u32::MAX - 1,
};
pub fn from_hex(s: &str) -> Result<Self> {
if s.len() != 36 {

View File

@@ -348,7 +348,7 @@ impl AuxFilePolicy {
/// If a tenant writes aux files without setting `switch_aux_policy`, this value will be used.
pub fn default_tenant_config() -> Self {
Self::V2
Self::V1
}
}
@@ -718,7 +718,6 @@ pub struct TimelineInfo {
pub pg_version: u32,
pub state: TimelineState,
pub is_archived: bool,
pub walreceiver_status: String,

View File

@@ -38,7 +38,7 @@ pub struct PageserverUtilization {
pub max_shard_count: u32,
/// Cached result of [`Self::score`]
pub utilization_score: Option<u64>,
pub utilization_score: u64,
/// When was this snapshot captured, pageserver local time.
///
@@ -50,8 +50,6 @@ fn unity_percent() -> Percent {
Percent::new(0).unwrap()
}
pub type RawScore = u64;
impl PageserverUtilization {
const UTILIZATION_FULL: u64 = 1000000;
@@ -64,7 +62,7 @@ impl PageserverUtilization {
/// - Negative values are forbidden
/// - Values over UTILIZATION_FULL indicate an overloaded node, which may show degraded performance due to
/// layer eviction.
pub fn score(&self) -> RawScore {
pub fn score(&self) -> u64 {
let disk_usable_capacity = ((self.disk_usage_bytes + self.free_space_bytes)
* self.disk_usable_pct.get() as u64)
/ 100;
@@ -76,30 +74,8 @@ impl PageserverUtilization {
std::cmp::max(disk_utilization_score, shard_utilization_score)
}
pub fn cached_score(&mut self) -> RawScore {
match self.utilization_score {
None => {
let s = self.score();
self.utilization_score = Some(s);
s
}
Some(s) => s,
}
}
/// If a node is currently hosting more work than it can comfortably handle. This does not indicate that
/// it will fail, but it is a strong signal that more work should not be added unless there is no alternative.
pub fn is_overloaded(score: RawScore) -> bool {
score >= Self::UTILIZATION_FULL
}
pub fn adjust_shard_count_max(&mut self, shard_count: u32) {
if self.shard_count < shard_count {
self.shard_count = shard_count;
// Dirty cache: this will be calculated next time someone retrieves the score
self.utilization_score = None;
}
pub fn refresh_score(&mut self) {
self.utilization_score = self.score();
}
/// A utilization structure that has a full utilization score: use this as a placeholder when
@@ -112,38 +88,7 @@ impl PageserverUtilization {
disk_usable_pct: Percent::new(100).unwrap(),
shard_count: 1,
max_shard_count: 1,
utilization_score: Some(Self::UTILIZATION_FULL),
captured_at: serde_system_time::SystemTime(SystemTime::now()),
}
}
}
/// Test helper
pub mod test_utilization {
use super::PageserverUtilization;
use std::time::SystemTime;
use utils::{
serde_percent::Percent,
serde_system_time::{self},
};
// Parameters of the imaginary node used for test utilization instances
const TEST_DISK_SIZE: u64 = 1024 * 1024 * 1024 * 1024;
const TEST_SHARDS_MAX: u32 = 1000;
/// Unit test helper. Unconditionally compiled because cfg(test) doesn't carry across crates. Do
/// not abuse this function from non-test code.
///
/// Emulates a node with a 1000 shard limit and a 1TB disk.
pub fn simple(shard_count: u32, disk_wanted_bytes: u64) -> PageserverUtilization {
PageserverUtilization {
disk_usage_bytes: disk_wanted_bytes,
free_space_bytes: TEST_DISK_SIZE - std::cmp::min(disk_wanted_bytes, TEST_DISK_SIZE),
disk_wanted_bytes,
disk_usable_pct: Percent::new(100).unwrap(),
shard_count,
max_shard_count: TEST_SHARDS_MAX,
utilization_score: None,
utilization_score: Self::UTILIZATION_FULL,
captured_at: serde_system_time::SystemTime(SystemTime::now()),
}
}
@@ -175,7 +120,7 @@ mod tests {
disk_usage_bytes: u64::MAX,
free_space_bytes: 0,
disk_wanted_bytes: u64::MAX,
utilization_score: Some(13),
utilization_score: 13,
disk_usable_pct: Percent::new(90).unwrap(),
shard_count: 100,
max_shard_count: 200,
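As a quick orientation on the scale used above: UTILIZATION_FULL is 1_000_000, so 0 means empty, 1_000_000 means full, and anything larger is overloaded (this matches the removed pageserver_utilization_score metric help text further down). A rough sketch of the shard component only, under the 1000-shard limit the removed test helper uses; the real score() also folds in the disk terms shown above:

// Rough sketch, not the real PageserverUtilization code: just the shard component
// of the score on the 0..=1_000_000 scale, plus the overload check from the diff.
const UTILIZATION_FULL: u64 = 1_000_000;

fn shard_utilization_score(shard_count: u32, max_shard_count: u32) -> u64 {
    (shard_count as u64 * UTILIZATION_FULL) / max_shard_count as u64
}

fn is_overloaded(score: u64) -> bool {
    score >= UTILIZATION_FULL
}

fn main() {
    // With the 1000-shard limit from the removed test_utilization helper:
    assert_eq!(shard_utilization_score(250, 1000), 250_000);
    assert!(!is_overloaded(shard_utilization_score(250, 1000)));
    assert!(is_overloaded(shard_utilization_score(1000, 1000)));
}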

View File

@@ -18,6 +18,7 @@ tokio-rustls.workspace = true
tracing.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true

View File

@@ -11,5 +11,7 @@ postgres.workspace = true
tokio-postgres.workspace = true
url.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true

View File

@@ -19,6 +19,8 @@ thiserror.workspace = true
serde.workspace = true
utils.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
env_logger.workspace = true
postgres.workspace = true

View File

@@ -14,6 +14,8 @@ postgres.workspace = true
postgres_ffi.workspace = true
camino-tempfile.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
regex.workspace = true
utils.workspace = true

View File

@@ -11,7 +11,9 @@ itertools.workspace = true
pin-project-lite.workspace = true
postgres-protocol.workspace = true
rand.workspace = true
tokio = { workspace = true, features = ["io-util"] }
tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true
serde.workspace = true
workspace_hack.workspace = true

View File

@@ -32,7 +32,7 @@ scopeguard.workspace = true
metrics.workspace = true
utils.workspace = true
pin-project-lite.workspace = true
workspace_hack.workspace = true
azure_core.workspace = true
azure_identity.workspace = true
azure_storage.workspace = true
@@ -46,4 +46,3 @@ sync_wrapper = { workspace = true, features = ["futures"] }
camino-tempfile.workspace = true
test-context.workspace = true
rand.workspace = true
tokio = { workspace = true, features = ["test-util"] }

View File

@@ -9,3 +9,5 @@ serde.workspace = true
serde_with.workspace = true
const_format.workspace = true
utils.workspace = true
workspace_hack.workspace = true

View File

@@ -9,3 +9,5 @@ license.workspace = true
anyhow.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace_hack.workspace = true

View File

@@ -14,3 +14,5 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tracing.workspace = true
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
workspace_hack.workspace = true

View File

@@ -39,7 +39,7 @@ thiserror.workspace = true
tokio.workspace = true
tokio-tar.workspace = true
tokio-util.workspace = true
toml_edit = { workspace = true, features = ["serde"] }
toml_edit.workspace = true
tracing.workspace = true
tracing-error.workspace = true
tracing-subscriber = { workspace = true, features = ["json", "registry"] }
@@ -54,6 +54,7 @@ walkdir.workspace = true
pq_proto.workspace = true
postgres_connection.workspace = true
metrics.workspace = true
workspace_hack.workspace = true
const_format.workspace = true
@@ -70,7 +71,6 @@ criterion.workspace = true
hex-literal.workspace = true
camino-tempfile.workspace = true
serde_assert.workspace = true
tokio = { workspace = true, features = ["test-util"] }
[[bench]]
name = "benchmarks"

View File

@@ -9,6 +9,8 @@ anyhow.workspace = true
utils.workspace = true
postgres_ffi.workspace = true
workspace_hack.workspace = true
[build-dependencies]
anyhow.workspace = true
bindgen.workspace = true

View File

@@ -95,7 +95,6 @@ fn main() -> anyhow::Result<()> {
.allowlist_var("ERROR")
.allowlist_var("FATAL")
.allowlist_var("PANIC")
.allowlist_var("PG_VERSION_NUM")
.allowlist_var("WPEVENT")
.allowlist_var("WL_LATCH_SET")
.allowlist_var("WL_SOCKET_READABLE")

View File

@@ -282,11 +282,7 @@ mod tests {
use std::cell::UnsafeCell;
use utils::id::TenantTimelineId;
use crate::{
api_bindings::Level,
bindings::{NeonWALReadResult, PG_VERSION_NUM},
walproposer::Wrapper,
};
use crate::{api_bindings::Level, bindings::NeonWALReadResult, walproposer::Wrapper};
use super::ApiImpl;
@@ -493,79 +489,41 @@ mod tests {
let (sender, receiver) = sync_channel(1);
// Message definitions are at walproposer.h
// xxx: it would be better to extract them from safekeeper crate and
// use serialization/deserialization here.
let greeting_tag = (b'g' as u64).to_ne_bytes();
let proto_version = 2_u32.to_ne_bytes();
let pg_version: [u8; 4] = PG_VERSION_NUM.to_ne_bytes();
let proposer_id = [0; 16];
let system_id = 0_u64.to_ne_bytes();
let tenant_id = ttid.tenant_id.as_arr();
let timeline_id = ttid.timeline_id.as_arr();
let pg_tli = 1_u32.to_ne_bytes();
let wal_seg_size = 16777216_u32.to_ne_bytes();
let proposer_greeting = [
greeting_tag.as_slice(),
proto_version.as_slice(),
pg_version.as_slice(),
proposer_id.as_slice(),
system_id.as_slice(),
tenant_id.as_slice(),
timeline_id.as_slice(),
pg_tli.as_slice(),
wal_seg_size.as_slice(),
]
.concat();
let voting_tag = (b'v' as u64).to_ne_bytes();
let vote_request_term = 3_u64.to_ne_bytes();
let proposer_id = [0; 16];
let vote_request = [
voting_tag.as_slice(),
vote_request_term.as_slice(),
proposer_id.as_slice(),
]
.concat();
let acceptor_greeting_term = 2_u64.to_ne_bytes();
let acceptor_greeting_node_id = 1_u64.to_ne_bytes();
let acceptor_greeting = [
greeting_tag.as_slice(),
acceptor_greeting_term.as_slice(),
acceptor_greeting_node_id.as_slice(),
]
.concat();
let vote_response_term = 3_u64.to_ne_bytes();
let vote_given = 1_u64.to_ne_bytes();
let flush_lsn = 0x539_u64.to_ne_bytes();
let truncate_lsn = 0x539_u64.to_ne_bytes();
let th_len = 1_u32.to_ne_bytes();
let th_term = 2_u64.to_ne_bytes();
let th_lsn = 0x539_u64.to_ne_bytes();
let timeline_start_lsn = 0x539_u64.to_ne_bytes();
let vote_response = [
voting_tag.as_slice(),
vote_response_term.as_slice(),
vote_given.as_slice(),
flush_lsn.as_slice(),
truncate_lsn.as_slice(),
th_len.as_slice(),
th_term.as_slice(),
th_lsn.as_slice(),
timeline_start_lsn.as_slice(),
]
.concat();
let my_impl: Box<dyn ApiImpl> = Box::new(MockImpl {
wait_events: Cell::new(WaitEventsData {
sk: std::ptr::null_mut(),
event_mask: 0,
}),
expected_messages: vec![proposer_greeting, vote_request],
expected_messages: vec![
// TODO: When updating Postgres versions, this test will cause
// problems. Postgres version in message needs updating.
//
// Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160003, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 })
vec![
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 76, 143, 54, 6, 60, 108, 110,
147, 188, 32, 214, 90, 130, 15, 61, 158, 76, 143, 54, 6, 60, 108, 110, 147,
188, 32, 214, 90, 130, 15, 61, 1, 0, 0, 0, 0, 0, 0, 1,
],
// VoteRequest(VoteRequest { term: 3 })
vec![
118, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0,
],
],
expected_ptr: AtomicUsize::new(0),
safekeeper_replies: vec![acceptor_greeting, vote_response],
safekeeper_replies: vec![
// Greeting(AcceptorGreeting { term: 2, node_id: NodeId(1) })
vec![
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
],
// VoteResponse(VoteResponse { term: 3, vote_given: 1, flush_lsn: 0/539, truncate_lsn: 0/539, term_history: [(2, 0/539)], timeline_start_lsn: 0/539 })
vec![
118, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 57,
5, 0, 0, 0, 0, 0, 0, 57, 5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
0, 57, 5, 0, 0, 0, 0, 0, 0, 57, 5, 0, 0, 0, 0, 0, 0,
],
],
replies_ptr: AtomicUsize::new(0),
sync_channel: sender,
shmem: UnsafeCell::new(crate::api_bindings::empty_shmem()),

View File

@@ -10,7 +10,6 @@ use pageserver::{
page_cache,
repository::Value,
task_mgr::TaskKind,
tenant::storage_layer::inmemory_layer::SerializedBatch,
tenant::storage_layer::InMemoryLayer,
virtual_file,
};
@@ -68,16 +67,12 @@ async fn ingest(
let layer =
InMemoryLayer::create(conf, timeline_id, tenant_shard_id, lsn, entered, &ctx).await?;
let data = Value::Image(Bytes::from(vec![0u8; put_size]));
let data_ser_size = data.serialized_size().unwrap() as usize;
let data = Value::Image(Bytes::from(vec![0u8; put_size])).ser()?;
let ctx = RequestContext::new(
pageserver::task_mgr::TaskKind::WalReceiverConnectionHandler,
pageserver::context::DownloadBehavior::Download,
);
const BATCH_SIZE: usize = 16;
let mut batch = Vec::new();
for i in 0..put_count {
lsn += put_size as u64;
@@ -100,17 +95,7 @@ async fn ingest(
}
}
batch.push((key.to_compact(), lsn, data_ser_size, data.clone()));
if batch.len() >= BATCH_SIZE {
let this_batch = std::mem::take(&mut batch);
let serialized = SerializedBatch::from_values(this_batch);
layer.put_batch(serialized, &ctx).await?;
}
}
if !batch.is_empty() {
let this_batch = std::mem::take(&mut batch);
let serialized = SerializedBatch::from_values(this_batch);
layer.put_batch(serialized, &ctx).await?;
layer.put_value(key.to_compact(), lsn, &data, &ctx).await?;
}
layer.freeze(lsn + 1).await;

View File

@@ -126,56 +126,10 @@ fn main() -> anyhow::Result<()> {
info!(?conf.virtual_file_direct_io, "starting with virtual_file Direct IO settings");
info!(?conf.compact_level0_phase1_value_access, "starting with setting for compact_level0_phase1_value_access");
// The tenants directory contains all the pageserver local disk state.
// Create if not exists and make sure all the contents are durable before proceeding.
// Ensuring durability eliminates a whole bug class where we come up after an unclean shutdown.
// After an unclean shutdown, we don't know if all the filesystem content we can read via syscalls is actually durable or not.
// Examples for that: OOM kill, systemd killing us during shutdown, self abort due to unrecoverable IO error.
let tenants_path = conf.tenants_path();
{
let open = || {
nix::dir::Dir::open(
tenants_path.as_std_path(),
nix::fcntl::OFlag::O_DIRECTORY | nix::fcntl::OFlag::O_RDONLY,
nix::sys::stat::Mode::empty(),
)
};
let dirfd = match open() {
Ok(dirfd) => dirfd,
Err(e) => match e {
nix::errno::Errno::ENOENT => {
utils::crashsafe::create_dir_all(&tenants_path).with_context(|| {
format!("Failed to create tenants root dir at '{tenants_path}'")
})?;
open().context("open tenants dir after creating it")?
}
e => anyhow::bail!(e),
},
};
let started = Instant::now();
// Linux guarantees durability for syncfs.
// POSIX doesn't have syncfs, and further does not actually guarantee durability of sync().
#[cfg(target_os = "linux")]
{
use std::os::fd::AsRawFd;
nix::unistd::syncfs(dirfd.as_raw_fd()).context("syncfs")?;
}
#[cfg(target_os = "macos")]
{
// macOS is not a production platform for Neon, don't even bother.
drop(dirfd);
}
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
{
compile_error!("Unsupported OS");
}
let elapsed = started.elapsed();
info!(
elapsed_ms = elapsed.as_millis(),
"made tenant directory contents durable"
);
if !tenants_path.exists() {
utils::crashsafe::create_dir_all(conf.tenants_path())
.with_context(|| format!("Failed to create tenants root dir at '{tenants_path}'"))?;
}
// Initialize up failpoints support

View File

@@ -318,24 +318,6 @@ impl From<crate::tenant::DeleteTimelineError> for ApiError {
}
}
impl From<crate::tenant::TimelineArchivalError> for ApiError {
fn from(value: crate::tenant::TimelineArchivalError) -> Self {
use crate::tenant::TimelineArchivalError::*;
match value {
NotFound => ApiError::NotFound(anyhow::anyhow!("timeline not found").into()),
Timeout => ApiError::Timeout("hit pageserver internal timeout".into()),
HasUnarchivedChildren(children) => ApiError::PreconditionFailed(
format!(
"Cannot archive timeline which has non-archived child timelines: {children:?}"
)
.into_boxed_str(),
),
a @ AlreadyInProgress => ApiError::Conflict(a.to_string()),
Other(e) => ApiError::InternalServerError(e),
}
}
}
impl From<crate::tenant::mgr::DeleteTimelineError> for ApiError {
fn from(value: crate::tenant::mgr::DeleteTimelineError) -> Self {
use crate::tenant::mgr::DeleteTimelineError::*;
@@ -423,8 +405,6 @@ async fn build_timeline_info_common(
let current_logical_size = timeline.get_current_logical_size(logical_size_task_priority, ctx);
let current_physical_size = Some(timeline.layer_size_sum().await);
let state = timeline.current_state();
// Report is_archived = false if the timeline is still loading
let is_archived = timeline.is_archived().unwrap_or(false);
let remote_consistent_lsn_projected = timeline
.get_remote_consistent_lsn_projected()
.unwrap_or(Lsn(0));
@@ -465,7 +445,6 @@ async fn build_timeline_info_common(
pg_version: timeline.pg_version,
state,
is_archived,
walreceiver_status,
@@ -707,7 +686,9 @@ async fn timeline_archival_config_handler(
tenant
.apply_timeline_archival_config(timeline_id, request_data.state)
.await?;
.await
.context("applying archival config")
.map_err(ApiError::InternalServerError)?;
Ok::<_, ApiError>(())
}
.instrument(info_span!("timeline_archival_config",
@@ -1725,6 +1706,11 @@ async fn timeline_compact_handler(
flags |= CompactFlags::ForceImageLayerCreation;
}
if Some(true) == parse_query_param::<_, bool>(&request, "enhanced_gc_bottom_most_compaction")? {
if !cfg!(feature = "testing") {
return Err(ApiError::InternalServerError(anyhow!(
"enhanced_gc_bottom_most_compaction is only available in testing mode"
)));
}
flags |= CompactFlags::EnhancedGcBottomMostCompaction;
}
let wait_until_uploaded =
@@ -2956,7 +2942,7 @@ pub fn make_router(
)
.put(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/compact",
|r| api_handler(r, timeline_compact_handler),
|r| testing_api_handler("run timeline compaction", r, timeline_compact_handler),
)
.put(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/checkpoint",

View File

@@ -88,8 +88,6 @@ pub async fn shutdown_pageserver(
) {
use std::time::Duration;
let started_at = std::time::Instant::now();
// If the orderly shutdown below takes too long, we still want to make
// sure that all walredo processes are killed and wait()ed on by us, not systemd.
//
@@ -243,10 +241,7 @@ pub async fn shutdown_pageserver(
walredo_extraordinary_shutdown_thread.join().unwrap();
info!("walredo_extraordinary_shutdown_thread done");
info!(
elapsed_ms = started_at.elapsed().as_millis(),
"Shut down successfully completed"
);
info!("Shut down successfully completed");
std::process::exit(exit_code);
}

View File

@@ -1803,14 +1803,6 @@ pub(crate) static SECONDARY_RESIDENT_PHYSICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::n
.expect("failed to define a metric")
});
pub(crate) static NODE_UTILIZATION_SCORE: Lazy<UIntGauge> = Lazy::new(|| {
register_uint_gauge!(
"pageserver_utilization_score",
"The utilization score we report to the storage controller for scheduling, where 0 is empty, 1000000 is full, and anything above is considered overloaded",
)
.expect("failed to define a metric")
});
pub(crate) static SECONDARY_HEATMAP_TOTAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_secondary_heatmap_total_size",

View File

@@ -15,11 +15,12 @@ use crate::{aux_file, repository::*};
use anyhow::{ensure, Context};
use bytes::{Buf, Bytes, BytesMut};
use enum_map::Enum;
use itertools::Itertools;
use pageserver_api::key::{
dbdir_key_range, rel_block_to_key, rel_dir_to_key, rel_key_range, rel_size_to_key,
relmap_file_key, repl_origin_key, repl_origin_key_range, slru_block_to_key, slru_dir_to_key,
slru_segment_key_range, slru_segment_size_to_key, twophase_file_key, twophase_key_range,
CompactKey, AUX_FILES_KEY, CHECKPOINT_KEY, CONTROLFILE_KEY, DBDIR_KEY, TWOPHASEDIR_KEY,
AUX_FILES_KEY, CHECKPOINT_KEY, CONTROLFILE_KEY, DBDIR_KEY, TWOPHASEDIR_KEY,
};
use pageserver_api::keyspace::SparseKeySpace;
use pageserver_api::models::AuxFilePolicy;
@@ -36,6 +37,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, info, trace, warn};
use utils::bin_ser::DeserializeError;
use utils::pausable_failpoint;
use utils::vec_map::{VecMap, VecMapOrdering};
use utils::{bin_ser::BeSer, lsn::Lsn};
/// Max delta records appended to the AUX_FILES_KEY (for aux v1). The write path will write a full image once this threshold is reached.
@@ -172,7 +174,6 @@ impl Timeline {
pending_deletions: Vec::new(),
pending_nblocks: 0,
pending_directory_entries: Vec::new(),
pending_bytes: 0,
lsn,
}
}
@@ -726,17 +727,7 @@ impl Timeline {
) -> Result<HashMap<String, Bytes>, PageReconstructError> {
let current_policy = self.last_aux_file_policy.load();
match current_policy {
Some(AuxFilePolicy::V1) => {
warn!("this timeline is using deprecated aux file policy V1 (policy=V1)");
self.list_aux_files_v1(lsn, ctx).await
}
None => {
let res = self.list_aux_files_v1(lsn, ctx).await?;
if !res.is_empty() {
warn!("this timeline is using deprecated aux file policy V1 (policy=None)");
}
Ok(res)
}
Some(AuxFilePolicy::V1) | None => self.list_aux_files_v1(lsn, ctx).await,
Some(AuxFilePolicy::V2) => self.list_aux_files_v2(lsn, ctx).await,
Some(AuxFilePolicy::CrossValidation) => {
let v1_result = self.list_aux_files_v1(lsn, ctx).await;
@@ -1031,33 +1022,21 @@ pub struct DatadirModification<'a> {
// The put-functions add the modifications here, and they are flushed to the
// underlying key-value store by the 'finish' function.
pending_lsns: Vec<Lsn>,
pending_updates: HashMap<Key, Vec<(Lsn, usize, Value)>>,
pending_updates: HashMap<Key, Vec<(Lsn, Value)>>,
pending_deletions: Vec<(Range<Key>, Lsn)>,
pending_nblocks: i64,
/// For special "directory" keys that store key-value maps, track the size of the map
/// if it was updated in this modification.
pending_directory_entries: Vec<(DirectoryKind, usize)>,
/// An **approximation** of how large our EphemeralFile write will be when committed.
pending_bytes: usize,
}
impl<'a> DatadirModification<'a> {
// When a DatadirModification is committed, we do a monolithic serialization of all its contents. WAL records can
// contain multiple pages, so the pageserver's record-based batch size isn't sufficient to bound this allocation: we
// additionally specify a limit on how much payload a DatadirModification may contain before it should be committed.
pub(crate) const MAX_PENDING_BYTES: usize = 8 * 1024 * 1024;
/// Get the current lsn
pub(crate) fn get_lsn(&self) -> Lsn {
self.lsn
}
pub(crate) fn approx_pending_bytes(&self) -> usize {
self.pending_bytes
}
/// Set the current lsn
pub(crate) fn set_lsn(&mut self, lsn: Lsn) -> anyhow::Result<()> {
ensure!(
@@ -1597,7 +1576,6 @@ impl<'a> DatadirModification<'a> {
if aux_files_key_v1.is_empty() {
None
} else {
warn!("this timeline is using deprecated aux file policy V1");
self.tline.do_switch_aux_policy(AuxFilePolicy::V1)?;
Some(AuxFilePolicy::V1)
}
@@ -1791,25 +1769,21 @@ impl<'a> DatadirModification<'a> {
// Flush relation and SLRU data blocks, keep metadata.
let mut retained_pending_updates = HashMap::<_, Vec<_>>::new();
for (key, values) in self.pending_updates.drain() {
let mut write_batch = Vec::new();
for (lsn, value_ser_size, value) in values {
for (lsn, value) in values {
if key.is_rel_block_key() || key.is_slru_block_key() {
// This bails out on first error without modifying pending_updates.
// That's Ok, cf this function's doc comment.
write_batch.push((key.to_compact(), lsn, value_ser_size, value));
writer.put(key, lsn, &value, ctx).await?;
} else {
retained_pending_updates.entry(key).or_default().push((
lsn,
value_ser_size,
value,
));
retained_pending_updates
.entry(key)
.or_default()
.push((lsn, value));
}
}
writer.put_batch(write_batch, ctx).await?;
}
self.pending_updates = retained_pending_updates;
self.pending_bytes = 0;
if pending_nblocks != 0 {
writer.update_current_logical_size(pending_nblocks * i64::from(BLCKSZ));
@@ -1835,20 +1809,17 @@ impl<'a> DatadirModification<'a> {
self.pending_nblocks = 0;
if !self.pending_updates.is_empty() {
// Ordering: the items in this batch do not need to be in any global order, but values for
// a particular Key must be in Lsn order relative to one another. InMemoryLayer relies on
// this to do efficient updates to its index.
let batch: Vec<(CompactKey, Lsn, usize, Value)> = self
.pending_updates
.drain()
.flat_map(|(key, values)| {
values.into_iter().map(move |(lsn, val_ser_size, value)| {
(key.to_compact(), lsn, val_ser_size, value)
})
})
.collect::<Vec<_>>();
// The put_batch call below expects the inputs to be sorted by Lsn,
// so we do that first.
let lsn_ordered_batch: VecMap<Lsn, (Key, Value)> = VecMap::from_iter(
self.pending_updates
.drain()
.map(|(key, vals)| vals.into_iter().map(move |(lsn, val)| (lsn, (key, val))))
.kmerge_by(|lhs, rhs| lhs.0 < rhs.0),
VecMapOrdering::GreaterOrEqual,
);
writer.put_batch(batch, ctx).await?;
writer.put_batch(lsn_ordered_batch, ctx).await?;
}
if !self.pending_deletions.is_empty() {
@@ -1873,8 +1844,6 @@ impl<'a> DatadirModification<'a> {
writer.update_directory_entries_count(kind, count as u64);
}
self.pending_bytes = 0;
Ok(())
}
@@ -1891,7 +1860,7 @@ impl<'a> DatadirModification<'a> {
// Note: we don't check pending_deletions. It is an error to request a
// value that has been removed, deletion only avoids leaking storage.
if let Some(values) = self.pending_updates.get(&key) {
if let Some((_, _, value)) = values.last() {
if let Some((_, value)) = values.last() {
return if let Value::Image(img) = value {
Ok(img.clone())
} else {
@@ -1919,17 +1888,13 @@ impl<'a> DatadirModification<'a> {
fn put(&mut self, key: Key, val: Value) {
let values = self.pending_updates.entry(key).or_default();
// Replace the previous value if it exists at the same lsn
if let Some((last_lsn, last_value_ser_size, last_value)) = values.last_mut() {
if let Some((last_lsn, last_value)) = values.last_mut() {
if *last_lsn == self.lsn {
*last_value_ser_size = val.serialized_size().unwrap() as usize;
*last_value = val;
return;
}
}
let val_serialized_size = val.serialized_size().unwrap() as usize;
self.pending_bytes += val_serialized_size;
values.push((self.lsn, val_serialized_size, val));
values.push((self.lsn, val));
}
fn delete(&mut self, key_range: Range<Key>) {
@@ -2059,7 +2024,7 @@ mod tests {
let (tenant, ctx) = harness.load().await;
let tline = tenant
.create_empty_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
.create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx)
.await?;
let tline = tline.raw_timeline().unwrap();
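Earlier in this file's diff, one variant of the flush path re-orders the drained per-key updates by LSN before handing them to put_batch ("The put_batch call below expects the inputs to be sorted by Lsn"). A minimal sketch of that k-merge step, using plain Vecs and u64 LSNs in place of the internal VecMap and Lsn types (illustrative names; the itertools crate, already imported in this file, is assumed):

use itertools::Itertools;

fn lsn_ordered(per_key: Vec<(&'static str, Vec<(u64, &'static str)>)>) -> Vec<(u64, (&'static str, &'static str))> {
    per_key
        .into_iter()
        .map(|(key, vals)| vals.into_iter().map(move |(lsn, val)| (lsn, (key, val))))
        .kmerge_by(|lhs, rhs| lhs.0 < rhs.0)
        .collect()
}

fn main() {
    let merged = lsn_ordered(vec![
        ("rel_a", vec![(0x10, "img@10"), (0x30, "wal@30")]),
        ("rel_b", vec![(0x20, "img@20")]),
    ]);
    // Each key's values keep their relative order; globally the batch is LSN-sorted.
    assert_eq!(merged.iter().map(|(lsn, _)| *lsn).collect::<Vec<_>>(), vec![0x10, 0x20, 0x30]);
}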

View File

@@ -501,38 +501,6 @@ impl Debug for DeleteTimelineError {
}
}
#[derive(thiserror::Error)]
pub enum TimelineArchivalError {
#[error("NotFound")]
NotFound,
#[error("Timeout")]
Timeout,
#[error("HasUnarchivedChildren")]
HasUnarchivedChildren(Vec<TimelineId>),
#[error("Timeline archival is already in progress")]
AlreadyInProgress,
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl Debug for TimelineArchivalError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotFound => write!(f, "NotFound"),
Self::Timeout => write!(f, "Timeout"),
Self::HasUnarchivedChildren(c) => {
f.debug_tuple("HasUnarchivedChildren").field(c).finish()
}
Self::AlreadyInProgress => f.debug_tuple("AlreadyInProgress").finish(),
Self::Other(e) => f.debug_tuple("Other").field(e).finish(),
}
}
}
pub enum SetStoppingError {
AlreadyStopping(completion::Barrier),
Broken,
@@ -1358,50 +1326,24 @@ impl Tenant {
&self,
timeline_id: TimelineId,
state: TimelineArchivalState,
) -> Result<(), TimelineArchivalError> {
info!("setting timeline archival config");
let timeline = {
let timelines = self.timelines.lock().unwrap();
let timeline = match timelines.get(&timeline_id) {
Some(t) => t,
None => return Err(TimelineArchivalError::NotFound),
};
// Ensure that there are no non-archived child timelines
let children: Vec<TimelineId> = timelines
.iter()
.filter_map(|(id, entry)| {
if entry.get_ancestor_timeline_id() != Some(timeline_id) {
return None;
}
if entry.is_archived() == Some(true) {
return None;
}
Some(*id)
})
.collect();
if !children.is_empty() && state == TimelineArchivalState::Archived {
return Err(TimelineArchivalError::HasUnarchivedChildren(children));
}
Arc::clone(timeline)
};
) -> anyhow::Result<()> {
let timeline = self
.get_timeline(timeline_id, false)
.context("Cannot apply timeline archival config to nonexistent timeline")?;
let upload_needed = timeline
.remote_client
.schedule_index_upload_for_timeline_archival_state(state)?;
if upload_needed {
info!("Uploading new state");
const MAX_WAIT: Duration = Duration::from_secs(10);
let Ok(v) =
tokio::time::timeout(MAX_WAIT, timeline.remote_client.wait_completion()).await
else {
tracing::warn!("reached timeout for waiting on upload queue");
return Err(TimelineArchivalError::Timeout);
bail!("reached timeout for upload queue flush");
};
v.map_err(|e| TimelineArchivalError::Other(anyhow::anyhow!(e)))?;
v?;
}
Ok(())
}
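One variant of the hunk above bounds the wait on the upload queue and maps a timeout to the typed TimelineArchivalError. A stripped-down sketch of that pattern (tokio, thiserror and anyhow assumed as dependencies; the future passed in stands in for remote_client.wait_completion()):

use std::time::Duration;

#[derive(Debug, thiserror::Error)]
enum ArchivalError {
    #[error("Timeout")]
    Timeout,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

async fn wait_for_upload<F>(wait_completion: F) -> Result<(), ArchivalError>
where
    F: std::future::Future<Output = anyhow::Result<()>>,
{
    const MAX_WAIT: Duration = Duration::from_secs(10);
    match tokio::time::timeout(MAX_WAIT, wait_completion).await {
        // Timer fired before the upload queue drained.
        Err(_elapsed) => Err(ArchivalError::Timeout),
        // Queue drained (or failed) within the deadline; propagate its result.
        Ok(res) => res.map_err(ArchivalError::Other),
    }
}

#[tokio::main]
async fn main() {
    assert!(wait_for_upload(async { Ok(()) }).await.is_ok());
}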
@@ -3799,21 +3741,13 @@ impl Tenant {
/// less than this (via eviction and on-demand downloads), but this function enables
/// the Tenant to advertise how much storage it would prefer to have to provide fast I/O
/// by keeping important things on local disk.
///
/// This is a heuristic, not a guarantee: tenants that are long-idle will actually use less
/// than they report here, due to layer eviction. Tenants with many active branches may
/// actually use more than they report here.
pub(crate) fn local_storage_wanted(&self) -> u64 {
let mut wanted = 0;
let timelines = self.timelines.lock().unwrap();
// Heuristic: we use the max() of the timelines' visible sizes, rather than the sum. This
// reflects the observation that on tenants with multiple large branches, typically only one
// of them is used actively enough to occupy space on disk.
timelines
.values()
.map(|t| t.metrics.visible_physical_size_gauge.get())
.max()
.unwrap_or(0)
for timeline in timelines.values() {
wanted += timeline.metrics.visible_physical_size_gauge.get();
}
wanted
}
}
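A toy illustration of the max-vs-sum heuristic described in the comment above, with made-up branch sizes:

fn local_storage_wanted(visible_sizes: &[u64]) -> u64 {
    visible_sizes.iter().copied().max().unwrap_or(0)
}

fn main() {
    // Hypothetical visible physical sizes (bytes) of three branches of one tenant.
    let branches = [8_000_000_000_u64, 7_500_000_000, 120_000_000];
    // Summing would advertise ~15.6 GB of wanted storage; taking the max advertises 8 GB,
    // matching the observation that usually only one large branch is active on disk.
    assert_eq!(local_storage_wanted(&branches), 8_000_000_000);
}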
@@ -5998,10 +5932,10 @@ mod tests {
.await
.unwrap();
// the default aux file policy to switch is v2 if not set by the admins
// the default aux file policy to switch is v1 if not set by the admins
assert_eq!(
harness.tenant_conf.switch_aux_file_policy,
AuxFilePolicy::default_tenant_config()
AuxFilePolicy::V1
);
let (tenant, ctx) = harness.load().await;
@@ -6045,8 +5979,8 @@ mod tests {
);
assert_eq!(
tline.last_aux_file_policy.load(),
Some(AuxFilePolicy::V2),
"aux file is written with switch_aux_file_policy unset (which is v2), so we should use v2 there"
Some(AuxFilePolicy::V1),
"aux file is written with switch_aux_file_policy unset (which is v1), so we should keep v1"
);
// we can read everything from the storage
@@ -6068,8 +6002,8 @@ mod tests {
assert_eq!(
tline.last_aux_file_policy.load(),
Some(AuxFilePolicy::V2),
"keep v2 storage format when new files are written"
Some(AuxFilePolicy::V1),
"keep v1 storage format when new files are written"
);
let files = tline.list_aux_files(lsn, &ctx).await.unwrap();
@@ -6085,7 +6019,7 @@ mod tests {
// child copies the last flag even if that is not on remote storage yet
assert_eq!(child.get_switch_aux_file_policy(), AuxFilePolicy::V2);
assert_eq!(child.last_aux_file_policy.load(), Some(AuxFilePolicy::V2));
assert_eq!(child.last_aux_file_policy.load(), Some(AuxFilePolicy::V1));
let files = child.list_aux_files(lsn, &ctx).await.unwrap();
assert_eq!(files.get("pg_logical/mappings/test1"), None);
@@ -7071,14 +7005,18 @@ mod tests {
vec![
// Image layer at GC horizon
PersistentLayerKey {
key_range: Key::MIN..Key::NON_L0_MAX,
key_range: {
let mut key = Key::MAX;
key.field6 -= 1;
Key::MIN..key
},
lsn_range: Lsn(0x30)..Lsn(0x31),
is_delta: false
},
// The delta layer covers the full range (with the layer key hack to avoid being recognized as L0)
// The delta layer that is cut in the middle
PersistentLayerKey {
key_range: Key::MIN..Key::NON_L0_MAX,
lsn_range: Lsn(0x30)..Lsn(0x48),
key_range: get_key(3)..get_key(4),
lsn_range: Lsn(0x30)..Lsn(0x41),
is_delta: true
},
// The delta3 layer that should not be picked for the compaction
@@ -8058,214 +7996,6 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn test_simple_bottom_most_compaction_with_retain_lsns_single_key() -> anyhow::Result<()>
{
let harness =
TenantHarness::create("test_simple_bottom_most_compaction_with_retain_lsns_single_key")
.await?;
let (tenant, ctx) = harness.load().await;
fn get_key(id: u32) -> Key {
// using aux key here b/c they are guaranteed to be inside `collect_keyspace`.
let mut key = Key::from_hex("620000000033333333444444445500000000").unwrap();
key.field6 = id;
key
}
let img_layer = (0..10)
.map(|id| (get_key(id), Bytes::from(format!("value {id}@0x10"))))
.collect_vec();
let delta1 = vec![
(
get_key(1),
Lsn(0x20),
Value::WalRecord(NeonWalRecord::wal_append("@0x20")),
),
(
get_key(1),
Lsn(0x28),
Value::WalRecord(NeonWalRecord::wal_append("@0x28")),
),
];
let delta2 = vec![
(
get_key(1),
Lsn(0x30),
Value::WalRecord(NeonWalRecord::wal_append("@0x30")),
),
(
get_key(1),
Lsn(0x38),
Value::WalRecord(NeonWalRecord::wal_append("@0x38")),
),
];
let delta3 = vec![
(
get_key(8),
Lsn(0x48),
Value::WalRecord(NeonWalRecord::wal_append("@0x48")),
),
(
get_key(9),
Lsn(0x48),
Value::WalRecord(NeonWalRecord::wal_append("@0x48")),
),
];
let tline = tenant
.create_test_timeline_with_layers(
TIMELINE_ID,
Lsn(0x10),
DEFAULT_PG_VERSION,
&ctx,
vec![
// delta1 and delta 2 only contain a single key but multiple updates
DeltaLayerTestDesc::new_with_inferred_key_range(Lsn(0x10)..Lsn(0x30), delta1),
DeltaLayerTestDesc::new_with_inferred_key_range(Lsn(0x30)..Lsn(0x50), delta2),
DeltaLayerTestDesc::new_with_inferred_key_range(Lsn(0x10)..Lsn(0x50), delta3),
], // delta layers
vec![(Lsn(0x10), img_layer)], // image layers
Lsn(0x50),
)
.await?;
{
// Update GC info
let mut guard = tline.gc_info.write().unwrap();
*guard = GcInfo {
retain_lsns: vec![
(Lsn(0x10), tline.timeline_id),
(Lsn(0x20), tline.timeline_id),
],
cutoffs: GcCutoffs {
time: Lsn(0x30),
space: Lsn(0x30),
},
leases: Default::default(),
within_ancestor_pitr: false,
};
}
let expected_result = [
Bytes::from_static(b"value 0@0x10"),
Bytes::from_static(b"value 1@0x10@0x20@0x28@0x30@0x38"),
Bytes::from_static(b"value 2@0x10"),
Bytes::from_static(b"value 3@0x10"),
Bytes::from_static(b"value 4@0x10"),
Bytes::from_static(b"value 5@0x10"),
Bytes::from_static(b"value 6@0x10"),
Bytes::from_static(b"value 7@0x10"),
Bytes::from_static(b"value 8@0x10@0x48"),
Bytes::from_static(b"value 9@0x10@0x48"),
];
let expected_result_at_gc_horizon = [
Bytes::from_static(b"value 0@0x10"),
Bytes::from_static(b"value 1@0x10@0x20@0x28@0x30"),
Bytes::from_static(b"value 2@0x10"),
Bytes::from_static(b"value 3@0x10"),
Bytes::from_static(b"value 4@0x10"),
Bytes::from_static(b"value 5@0x10"),
Bytes::from_static(b"value 6@0x10"),
Bytes::from_static(b"value 7@0x10"),
Bytes::from_static(b"value 8@0x10"),
Bytes::from_static(b"value 9@0x10"),
];
let expected_result_at_lsn_20 = [
Bytes::from_static(b"value 0@0x10"),
Bytes::from_static(b"value 1@0x10@0x20"),
Bytes::from_static(b"value 2@0x10"),
Bytes::from_static(b"value 3@0x10"),
Bytes::from_static(b"value 4@0x10"),
Bytes::from_static(b"value 5@0x10"),
Bytes::from_static(b"value 6@0x10"),
Bytes::from_static(b"value 7@0x10"),
Bytes::from_static(b"value 8@0x10"),
Bytes::from_static(b"value 9@0x10"),
];
let expected_result_at_lsn_10 = [
Bytes::from_static(b"value 0@0x10"),
Bytes::from_static(b"value 1@0x10"),
Bytes::from_static(b"value 2@0x10"),
Bytes::from_static(b"value 3@0x10"),
Bytes::from_static(b"value 4@0x10"),
Bytes::from_static(b"value 5@0x10"),
Bytes::from_static(b"value 6@0x10"),
Bytes::from_static(b"value 7@0x10"),
Bytes::from_static(b"value 8@0x10"),
Bytes::from_static(b"value 9@0x10"),
];
let verify_result = || async {
let gc_horizon = {
let gc_info = tline.gc_info.read().unwrap();
gc_info.cutoffs.time
};
for idx in 0..10 {
assert_eq!(
tline
.get(get_key(idx as u32), Lsn(0x50), &ctx)
.await
.unwrap(),
&expected_result[idx]
);
assert_eq!(
tline
.get(get_key(idx as u32), gc_horizon, &ctx)
.await
.unwrap(),
&expected_result_at_gc_horizon[idx]
);
assert_eq!(
tline
.get(get_key(idx as u32), Lsn(0x20), &ctx)
.await
.unwrap(),
&expected_result_at_lsn_20[idx]
);
assert_eq!(
tline
.get(get_key(idx as u32), Lsn(0x10), &ctx)
.await
.unwrap(),
&expected_result_at_lsn_10[idx]
);
}
};
verify_result().await;
let cancel = CancellationToken::new();
let mut dryrun_flags = EnumSet::new();
dryrun_flags.insert(CompactFlags::DryRun);
tline
.compact_with_gc(&cancel, dryrun_flags, &ctx)
.await
.unwrap();
// We expect the layer map to be the same b/c of the dry run flag, but we don't know whether there will be other background jobs
// cleaning things up, and therefore, we don't do sanity checks on the layer map during unit tests.
verify_result().await;
tline
.compact_with_gc(&cancel, EnumSet::new(), &ctx)
.await
.unwrap();
verify_result().await;
// compact again
tline
.compact_with_gc(&cancel, EnumSet::new(), &ctx)
.await
.unwrap();
verify_result().await;
Ok(())
}
#[tokio::test]
async fn test_simple_bottom_most_compaction_on_branch() -> anyhow::Result<()> {
let harness = TenantHarness::create("test_simple_bottom_most_compaction_on_branch").await?;

View File

@@ -79,8 +79,6 @@ impl EphemeralFile {
self.rw.read_blk(blknum, ctx).await
}
#[cfg(test)]
// This is a test helper: outside of tests, we are always written to via a pre-serialized batch.
pub(crate) async fn write_blob(
&mut self,
srcbuf: &[u8],
@@ -88,30 +86,17 @@ impl EphemeralFile {
) -> Result<u64, io::Error> {
let pos = self.rw.bytes_written();
let mut len_bytes = std::io::Cursor::new(Vec::new());
crate::tenant::storage_layer::inmemory_layer::SerializedBatch::write_blob_length(
srcbuf.len(),
&mut len_bytes,
);
let len_bytes = len_bytes.into_inner();
// Write the length field
self.rw.write_all_borrowed(&len_bytes, ctx).await?;
if srcbuf.len() < 0x80 {
// short one-byte length header
let len_buf = [srcbuf.len() as u8];
// Write the payload
self.rw.write_all_borrowed(srcbuf, ctx).await?;
Ok(pos)
}
/// Returns the offset at which the first byte of the input was written, for use
/// in constructing indices over the written value.
pub(crate) async fn write_raw(
&mut self,
srcbuf: &[u8],
ctx: &RequestContext,
) -> Result<u64, io::Error> {
let pos = self.rw.bytes_written();
self.rw.write_all_borrowed(&len_buf, ctx).await?;
} else {
let mut len_buf = u32::to_be_bytes(srcbuf.len() as u32);
len_buf[0] |= 0x80;
self.rw.write_all_borrowed(&len_buf, ctx).await?;
}
// Write the payload
self.rw.write_all_borrowed(srcbuf, ctx).await?;
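For reference, a self-contained sketch of the blob length header the write path above emits: lengths below 0x80 take a single byte, anything larger takes a 4-byte big-endian length with the top bit set. The decode side is my reconstruction, not code from this diff:

fn encode_len(len: usize) -> Vec<u8> {
    if len < 0x80 {
        // short one-byte length header
        vec![len as u8]
    } else {
        let mut buf = u32::to_be_bytes(len as u32);
        buf[0] |= 0x80;
        buf.to_vec()
    }
}

/// Returns (payload length, header bytes consumed).
fn decode_len(header: &[u8]) -> (usize, usize) {
    if header[0] & 0x80 == 0 {
        (header[0] as usize, 1)
    } else {
        let raw = [header[0] & 0x7f, header[1], header[2], header[3]];
        (u32::from_be_bytes(raw) as usize, 4)
    }
}

fn main() {
    assert_eq!(encode_len(0x42), vec![0x42]);
    assert_eq!(encode_len(0x1234), vec![0x80, 0x00, 0x12, 0x34]);
    assert_eq!(decode_len(&encode_len(0x42)), (0x42, 1));
    assert_eq!(decode_len(&encode_len(0x1234)), (0x1234, 4));
}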

View File

@@ -464,7 +464,7 @@ impl LayerMap {
pub(self) fn insert_historic_noflush(&mut self, layer_desc: PersistentLayerDesc) {
// TODO: See #3869, resulting #4088, attempted fix and repro #4094
if Self::is_l0(&layer_desc.key_range, layer_desc.is_delta) {
if Self::is_l0(&layer_desc.key_range) {
self.l0_delta_layers.push(layer_desc.clone().into());
}
@@ -483,7 +483,7 @@ impl LayerMap {
self.historic
.remove(historic_layer_coverage::LayerKey::from(layer_desc));
let layer_key = layer_desc.key();
if Self::is_l0(&layer_desc.key_range, layer_desc.is_delta) {
if Self::is_l0(&layer_desc.key_range) {
let len_before = self.l0_delta_layers.len();
let mut l0_delta_layers = std::mem::take(&mut self.l0_delta_layers);
l0_delta_layers.retain(|other| other.key() != layer_key);
@@ -600,8 +600,8 @@ impl LayerMap {
}
/// Check if the key range resembles that of an L0 layer.
pub fn is_l0(key_range: &Range<Key>, is_delta_layer: bool) -> bool {
is_delta_layer && key_range == &(Key::MIN..Key::MAX)
pub fn is_l0(key_range: &Range<Key>) -> bool {
key_range == &(Key::MIN..Key::MAX)
}
/// This function determines which layers are counted in `count_deltas`:
@@ -628,7 +628,7 @@ impl LayerMap {
/// than just the current partition_range.
pub fn is_reimage_worthy(layer: &PersistentLayerDesc, partition_range: &Range<Key>) -> bool {
// Case 1
if !Self::is_l0(&layer.key_range, layer.is_delta) {
if !Self::is_l0(&layer.key_range) {
return true;
}
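The two is_l0 variants above differ in whether the delta flag participates in the check. A minimal sketch of the flag-aware variant, which avoids classifying a full-key-range image layer as L0 (simplified key type; the real pageserver Key is wider):

use std::ops::Range;

type Key = u128; // stand-in for the pageserver's wider Key type
const KEY_MIN: Key = Key::MIN;
const KEY_MAX: Key = Key::MAX;

fn is_l0(key_range: &Range<Key>, is_delta_layer: bool) -> bool {
    // Only a delta layer spanning the whole key space counts as L0.
    is_delta_layer && *key_range == (KEY_MIN..KEY_MAX)
}

fn main() {
    assert!(is_l0(&(KEY_MIN..KEY_MAX), true));
    assert!(!is_l0(&(KEY_MIN..KEY_MAX), false)); // full-range image layer is not L0
    assert!(!is_l0(&(0..42), true)); // partial-range delta layer is not L0
}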

View File

@@ -2,12 +2,13 @@
pub mod delta_layer;
pub mod image_layer;
pub mod inmemory_layer;
pub(crate) mod inmemory_layer;
pub(crate) mod layer;
mod layer_desc;
mod layer_name;
pub mod merge_iterator;
#[cfg(test)]
pub mod split_writer;
use crate::context::{AccessStatsBehavior, RequestContext};

View File

@@ -36,7 +36,6 @@ use crate::tenant::block_io::{BlockBuf, BlockCursor, BlockLease, BlockReader, Fi
use crate::tenant::disk_btree::{
DiskBtreeBuilder, DiskBtreeIterator, DiskBtreeReader, VisitDirection,
};
use crate::tenant::storage_layer::layer::S3_UPLOAD_LIMIT;
use crate::tenant::timeline::GetVectoredError;
use crate::tenant::vectored_blob_io::{
BlobFlag, MaxVectoredReadBytes, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
@@ -233,18 +232,6 @@ pub struct DeltaLayerInner {
max_vectored_read_bytes: Option<MaxVectoredReadBytes>,
}
impl DeltaLayerInner {
pub(crate) fn layer_dbg_info(&self) -> String {
format!(
"delta {}..{} {}..{}",
self.key_range().start,
self.key_range().end,
self.lsn_range().start,
self.lsn_range().end
)
}
}
impl std::fmt::Debug for DeltaLayerInner {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DeltaLayerInner")
@@ -569,6 +556,7 @@ impl DeltaLayerWriterInner {
// 5GB limit for objects without multipart upload (which we don't want to use)
// Make it a little bit below to account for differing GB units
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/upload-objects.html
const S3_UPLOAD_LIMIT: u64 = 4_500_000_000;
ensure!(
metadata.len() <= S3_UPLOAD_LIMIT,
"Created delta layer file at {} of size {} above limit {S3_UPLOAD_LIMIT}!",
@@ -702,10 +690,12 @@ impl DeltaLayerWriter {
self.inner.take().unwrap().finish(key_end, ctx).await
}
#[cfg(test)]
pub(crate) fn num_keys(&self) -> usize {
self.inner.as_ref().unwrap().num_keys
}
#[cfg(test)]
pub(crate) fn estimated_size(&self) -> u64 {
let inner = self.inner.as_ref().unwrap();
inner.blob_writer.size() + inner.tree.borrow_writer().size() + PAGE_SZ as u64
@@ -1537,10 +1527,6 @@ pub struct DeltaLayerIterator<'a> {
}
impl<'a> DeltaLayerIterator<'a> {
pub(crate) fn layer_dbg_info(&self) -> String {
self.delta_layer.layer_dbg_info()
}
/// Retrieve a batch of key-value pairs into the iterator buffer.
async fn next_batch(&mut self) -> anyhow::Result<()> {
assert!(self.key_values_batch.is_empty());
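As a quick sanity check on the "differing GB units" note attached to S3_UPLOAD_LIMIT earlier in this file's diff, the chosen 4.5e9 bytes sits below the 5GB single-PUT limit under either unit interpretation:

fn main() {
    const S3_UPLOAD_LIMIT: u64 = 4_500_000_000;
    let five_gb = 5_u64 * 1_000_000_000; // 5 decimal gigabytes
    let five_gib = 5_u64 * 1024 * 1024 * 1024; // 5 gibibytes, ~5.37e9
    assert!(S3_UPLOAD_LIMIT < five_gb && S3_UPLOAD_LIMIT < five_gib);
}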

View File

@@ -167,17 +167,6 @@ pub struct ImageLayerInner {
max_vectored_read_bytes: Option<MaxVectoredReadBytes>,
}
impl ImageLayerInner {
pub(crate) fn layer_dbg_info(&self) -> String {
format!(
"image {}..{} {}",
self.key_range().start,
self.key_range().end,
self.lsn()
)
}
}
impl std::fmt::Debug for ImageLayerInner {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ImageLayerInner")
@@ -716,6 +705,10 @@ struct ImageLayerWriterInner {
}
impl ImageLayerWriterInner {
fn size(&self) -> u64 {
self.tree.borrow_writer().size() + self.blob_writer.size()
}
///
/// Start building a new image layer.
///
@@ -850,19 +843,13 @@ impl ImageLayerWriterInner {
res?;
}
let final_key_range = if let Some(end_key) = end_key {
self.key_range.start..end_key
} else {
self.key_range.clone()
};
// Fill in the summary on blk 0
let summary = Summary {
magic: IMAGE_FILE_MAGIC,
format_version: STORAGE_FORMAT_VERSION,
tenant_id: self.tenant_shard_id.tenant_id,
timeline_id: self.timeline_id,
key_range: final_key_range.clone(),
key_range: self.key_range.clone(),
lsn: self.lsn,
index_start_blk,
index_root_blk,
@@ -883,7 +870,11 @@ impl ImageLayerWriterInner {
let desc = PersistentLayerDesc::new_img(
self.tenant_shard_id,
self.timeline_id,
final_key_range,
if let Some(end_key) = end_key {
self.key_range.start..end_key
} else {
self.key_range.clone()
},
self.lsn,
metadata.len(),
);
@@ -972,12 +963,14 @@ impl ImageLayerWriter {
self.inner.as_mut().unwrap().put_image(key, img, ctx).await
}
#[cfg(test)]
/// Estimated size of the image layer.
pub(crate) fn estimated_size(&self) -> u64 {
let inner = self.inner.as_ref().unwrap();
inner.blob_writer.size() + inner.tree.borrow_writer().size() + PAGE_SZ as u64
}
#[cfg(test)]
pub(crate) fn num_keys(&self) -> usize {
self.inner.as_ref().unwrap().num_keys
}
@@ -993,6 +986,7 @@ impl ImageLayerWriter {
self.inner.take().unwrap().finish(timeline, ctx, None).await
}
#[cfg(test)]
/// Finish writing the image layer with an end key, used in [`super::split_writer::SplitImageLayerWriter`]. The end key determines the end of the image layer's covered range and is exclusive.
pub(super) async fn finish_with_end_key(
mut self,
@@ -1006,6 +1000,10 @@ impl ImageLayerWriter {
.finish(timeline, ctx, Some(end_key))
.await
}
pub(crate) fn size(&self) -> u64 {
self.inner.as_ref().unwrap().size()
}
}
impl Drop for ImageLayerWriter {
@@ -1026,10 +1024,6 @@ pub struct ImageLayerIterator<'a> {
}
impl<'a> ImageLayerIterator<'a> {
pub(crate) fn layer_dbg_info(&self) -> String {
self.image_layer.layer_dbg_info()
}
/// Retrieve a batch of key-value pairs into the iterator buffer.
async fn next_batch(&mut self) -> anyhow::Result<()> {
assert!(self.key_values_batch.is_empty());

View File

@@ -33,7 +33,7 @@ use std::fmt::Write;
use std::ops::Range;
use std::sync::atomic::Ordering as AtomicOrdering;
use std::sync::atomic::{AtomicU64, AtomicUsize};
use tokio::sync::RwLock;
use tokio::sync::{RwLock, RwLockWriteGuard};
use super::{
DeltaLayerWriter, PersistentLayerDesc, ValueReconstructSituation, ValuesReconstructState,
@@ -320,82 +320,6 @@ impl InMemoryLayer {
}
}
/// Offset of a particular Value within a serialized batch.
struct SerializedBatchOffset {
key: CompactKey,
lsn: Lsn,
/// offset in bytes from the start of the batch's buffer to the Value's serialized size header.
offset: u64,
}
pub struct SerializedBatch {
/// Blobs serialized in EphemeralFile's native format, ready for passing to [`EphemeralFile::write_raw`].
pub(crate) raw: Vec<u8>,
/// Index of values in [`Self::raw`], using offsets relative to the start of the buffer.
offsets: Vec<SerializedBatchOffset>,
/// The highest LSN of any value in the batch
pub(crate) max_lsn: Lsn,
}
impl SerializedBatch {
/// Write a blob length in the internal format of the EphemeralFile
pub(crate) fn write_blob_length(len: usize, cursor: &mut std::io::Cursor<Vec<u8>>) {
use std::io::Write;
if len < 0x80 {
// short one-byte length header
let len_buf = [len as u8];
cursor
.write_all(&len_buf)
.expect("Writing to Vec is infallible");
} else {
let mut len_buf = u32::to_be_bytes(len as u32);
len_buf[0] |= 0x80;
cursor
.write_all(&len_buf)
.expect("Writing to Vec is infallible");
}
}
pub fn from_values(batch: Vec<(CompactKey, Lsn, usize, Value)>) -> Self {
// Pre-allocate a big flat buffer to write into. This should be large but not huge: it is soft-limited in practice by
// [`crate::pgdatadir_mapping::DatadirModification::MAX_PENDING_BYTES`]
let buffer_size = batch.iter().map(|i| i.2).sum::<usize>() + 4 * batch.len();
let mut cursor = std::io::Cursor::new(Vec::<u8>::with_capacity(buffer_size));
let mut offsets: Vec<SerializedBatchOffset> = Vec::with_capacity(batch.len());
let mut max_lsn: Lsn = Lsn(0);
for (key, lsn, val_ser_size, val) in batch {
let relative_off = cursor.position();
Self::write_blob_length(val_ser_size, &mut cursor);
val.ser_into(&mut cursor)
.expect("Writing into in-memory buffer is infallible");
offsets.push(SerializedBatchOffset {
key,
lsn,
offset: relative_off,
});
max_lsn = std::cmp::max(max_lsn, lsn);
}
let buffer = cursor.into_inner();
// Assert that we didn't do any extra allocations while building buffer.
debug_assert!(buffer.len() <= buffer_size);
Self {
raw: buffer,
offsets,
max_lsn,
}
}
}
fn inmem_layer_display(mut f: impl Write, start_lsn: Lsn, end_lsn: Lsn) -> std::fmt::Result {
write!(f, "inmem-{:016X}-{:016X}", start_lsn.0, end_lsn.0)
}
@@ -456,20 +380,37 @@ impl InMemoryLayer {
})
}
// Write path.
pub async fn put_batch(
// Write operations
/// Common subroutine of the public put_wal_record() and put_page_image() functions.
/// Adds the page version to the in-memory tree
pub async fn put_value(
&self,
serialized_batch: SerializedBatch,
key: CompactKey,
lsn: Lsn,
buf: &[u8],
ctx: &RequestContext,
) -> Result<()> {
let mut inner = self.inner.write().await;
self.assert_writable();
self.put_value_locked(&mut inner, key, lsn, buf, ctx).await
}
let base_off = {
inner
async fn put_value_locked(
&self,
locked_inner: &mut RwLockWriteGuard<'_, InMemoryLayerInner>,
key: CompactKey,
lsn: Lsn,
buf: &[u8],
ctx: &RequestContext,
) -> Result<()> {
trace!("put_value key {} at {}/{}", key, self.timeline_id, lsn);
let off = {
locked_inner
.file
.write_raw(
&serialized_batch.raw,
.write_blob(
buf,
&RequestContextBuilder::extend(ctx)
.page_content_kind(PageContentKind::InMemoryLayer)
.build(),
@@ -477,23 +418,15 @@ impl InMemoryLayer {
.await?
};
for SerializedBatchOffset {
key,
lsn,
offset: relative_off,
} in serialized_batch.offsets
{
let off = base_off + relative_off;
let vec_map = inner.index.entry(key).or_default();
let old = vec_map.append_or_update_last(lsn, off).unwrap().0;
if old.is_some() {
// We already had an entry for this LSN. That's odd..
warn!("Key {} at {} already exists", key, lsn);
}
let vec_map = locked_inner.index.entry(key).or_default();
let old = vec_map.append_or_update_last(lsn, off).unwrap().0;
if old.is_some() {
// We already had an entry for this LSN. That's odd..
warn!("Key {} at {} already exists", key, lsn);
}
let size = inner.file.len();
inner.resource_units.maybe_publish_size(size);
let size = locked_inner.file.len();
locked_inner.resource_units.maybe_publish_size(size);
Ok(())
}
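One variant above builds a SerializedBatch: every value is serialized into one flat buffer, the byte offset of each value is recorded so the in-memory index can point at it after a single raw write, and the maximum LSN is tracked for the caller. A condensed sketch with u64 stand-ins for CompactKey/Lsn and short values only (illustrative names, not the real API):

struct BatchOffset {
    key: u64, // stand-in for CompactKey
    lsn: u64,
    offset: u64, // where this value's length header starts inside `raw`
}

fn build_batch(values: Vec<(u64, u64, Vec<u8>)>) -> (Vec<u8>, Vec<BatchOffset>, u64) {
    let mut raw = Vec::new();
    let mut offsets = Vec::with_capacity(values.len());
    let mut max_lsn = 0;
    for (key, lsn, val) in values {
        offsets.push(BatchOffset { key, lsn, offset: raw.len() as u64 });
        raw.push(val.len() as u8); // assumes values shorter than 0x80 bytes
        raw.extend_from_slice(&val);
        max_lsn = max_lsn.max(lsn);
    }
    (raw, offsets, max_lsn)
}

fn main() {
    let (raw, offsets, max_lsn) = build_batch(vec![(1, 0x10, vec![0xAA; 3]), (2, 0x20, vec![0xBB; 5])]);
    assert_eq!(raw.len(), (1 + 3) + (1 + 5));
    // The second value's header starts right after the first header + payload.
    assert_eq!(offsets[1].offset, 4);
    assert_eq!((offsets[1].key, offsets[1].lsn), (2, 0x20));
    assert_eq!(max_lsn, 0x20);
}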

View File

@@ -35,8 +35,6 @@ mod tests;
#[cfg(test)]
mod failpoints;
pub const S3_UPLOAD_LIMIT: u64 = 4_500_000_000;
/// A Layer contains all data in a "rectangle" consisting of a range of keys and
/// range of LSNs.
///
@@ -1298,10 +1296,7 @@ impl LayerInner {
lsn_end: lsn_range.end,
remote: !resident,
access_stats,
l0: crate::tenant::layer_map::LayerMap::is_l0(
&self.layer_desc().key_range,
self.layer_desc().is_delta,
),
l0: crate::tenant::layer_map::LayerMap::is_l0(&self.layer_desc().key_range),
}
} else {
let lsn = self.desc.image_layer_lsn();

View File

@@ -256,10 +256,6 @@ impl LayerName {
LayerName::Delta(layer) => &layer.key_range,
}
}
pub fn is_delta(&self) -> bool {
matches!(self, LayerName::Delta(_))
}
}
impl fmt::Display for LayerName {

View File

@@ -3,7 +3,6 @@ use std::{
collections::{binary_heap, BinaryHeap},
};
use anyhow::bail;
use pageserver_api::key::Key;
use utils::lsn::Lsn;
@@ -27,13 +26,6 @@ impl<'a> LayerRef<'a> {
Self::Delta(x) => LayerIterRef::Delta(x.iter(ctx)),
}
}
fn layer_dbg_info(&self) -> String {
match self {
Self::Image(x) => x.layer_dbg_info(),
Self::Delta(x) => x.layer_dbg_info(),
}
}
}
enum LayerIterRef<'a> {
@@ -48,13 +40,6 @@ impl LayerIterRef<'_> {
Self::Image(x) => x.next().await,
}
}
fn layer_dbg_info(&self) -> String {
match self {
Self::Image(x) => x.layer_dbg_info(),
Self::Delta(x) => x.layer_dbg_info(),
}
}
}
/// This type plays several roles at once
@@ -90,11 +75,6 @@ impl<'a> PeekableLayerIterRef<'a> {
async fn next(&mut self) -> anyhow::Result<Option<(Key, Lsn, Value)>> {
let result = self.peeked.take();
self.peeked = self.iter.next().await?;
if let (Some((k1, l1, _)), Some((k2, l2, _))) = (&self.peeked, &result) {
if (k1, l1) < (k2, l2) {
bail!("iterator is not ordered: {}", self.iter.layer_dbg_info());
}
}
Ok(result)
}
}
@@ -198,12 +178,7 @@ impl<'a> IteratorWrapper<'a> {
let iter = PeekableLayerIterRef::create(iter).await?;
if let Some((k1, l1, _)) = iter.peek() {
let (k2, l2) = first_key_lower_bound;
if (k1, l1) < (k2, l2) {
bail!(
"layer key range did not include the first key in the layer: {}",
layer.layer_dbg_info()
);
}
debug_assert!((k1, l1) >= (k2, l2));
}
*self = Self::Loaded { iter };
Ok(())

View File

@@ -1,4 +1,4 @@
use std::{future::Future, ops::Range, sync::Arc};
use std::{ops::Range, sync::Arc};
use bytes::Bytes;
use pageserver_api::key::{Key, KEY_SIZE};
@@ -7,32 +7,7 @@ use utils::{id::TimelineId, lsn::Lsn, shard::TenantShardId};
use crate::tenant::storage_layer::Layer;
use crate::{config::PageServerConf, context::RequestContext, repository::Value, tenant::Timeline};
use super::layer::S3_UPLOAD_LIMIT;
use super::{
DeltaLayerWriter, ImageLayerWriter, PersistentLayerDesc, PersistentLayerKey, ResidentLayer,
};
pub(crate) enum SplitWriterResult {
Produced(ResidentLayer),
Discarded(PersistentLayerKey),
}
#[cfg(test)]
impl SplitWriterResult {
fn into_resident_layer(self) -> ResidentLayer {
match self {
SplitWriterResult::Produced(layer) => layer,
SplitWriterResult::Discarded(_) => panic!("unexpected discarded layer"),
}
}
fn into_discarded_layer(self) -> PersistentLayerKey {
match self {
SplitWriterResult::Produced(_) => panic!("unexpected produced layer"),
SplitWriterResult::Discarded(layer) => layer,
}
}
}
use super::{DeltaLayerWriter, ImageLayerWriter, ResidentLayer};
/// An image writer that takes images and produces multiple image layers. The interface does not
/// guarantee atomicity (i.e., if the image layer generation fails, there might be leftover files
@@ -41,12 +16,11 @@ impl SplitWriterResult {
pub struct SplitImageLayerWriter {
inner: ImageLayerWriter,
target_layer_size: u64,
generated_layers: Vec<SplitWriterResult>,
generated_layers: Vec<ResidentLayer>,
conf: &'static PageServerConf,
timeline_id: TimelineId,
tenant_shard_id: TenantShardId,
lsn: Lsn,
start_key: Key,
}
impl SplitImageLayerWriter {
@@ -75,22 +49,16 @@ impl SplitImageLayerWriter {
timeline_id,
tenant_shard_id,
lsn,
start_key,
})
}
pub async fn put_image_with_discard_fn<D, F>(
pub async fn put_image(
&mut self,
key: Key,
img: Bytes,
tline: &Arc<Timeline>,
ctx: &RequestContext,
discard: D,
) -> anyhow::Result<()>
where
D: FnOnce(&PersistentLayerKey) -> F,
F: Future<Output = bool>,
{
) -> anyhow::Result<()> {
// The current estimation is an upper bound of the space that the key/image could take
// because we did not consider compression in this estimation. The resulting image layer
// could be smaller than the target size.
@@ -108,87 +76,33 @@ impl SplitImageLayerWriter {
)
.await?;
let prev_image_writer = std::mem::replace(&mut self.inner, next_image_writer);
let layer_key = PersistentLayerKey {
key_range: self.start_key..key,
lsn_range: PersistentLayerDesc::image_layer_lsn_range(self.lsn),
is_delta: false,
};
self.start_key = key;
if discard(&layer_key).await {
drop(prev_image_writer);
self.generated_layers
.push(SplitWriterResult::Discarded(layer_key));
} else {
self.generated_layers.push(SplitWriterResult::Produced(
prev_image_writer
.finish_with_end_key(tline, key, ctx)
.await?,
));
}
self.generated_layers.push(
prev_image_writer
.finish_with_end_key(tline, key, ctx)
.await?,
);
}
self.inner.put_image(key, img, ctx).await
}
#[cfg(test)]
pub async fn put_image(
&mut self,
key: Key,
img: Bytes,
tline: &Arc<Timeline>,
ctx: &RequestContext,
) -> anyhow::Result<()> {
self.put_image_with_discard_fn(key, img, tline, ctx, |_| async { false })
.await
}
pub(crate) async fn finish_with_discard_fn<D, F>(
self,
tline: &Arc<Timeline>,
ctx: &RequestContext,
end_key: Key,
discard: D,
) -> anyhow::Result<Vec<SplitWriterResult>>
where
D: FnOnce(&PersistentLayerKey) -> F,
F: Future<Output = bool>,
{
let Self {
mut generated_layers,
inner,
..
} = self;
if inner.num_keys() == 0 {
return Ok(generated_layers);
}
let layer_key = PersistentLayerKey {
key_range: self.start_key..end_key,
lsn_range: PersistentLayerDesc::image_layer_lsn_range(self.lsn),
is_delta: false,
};
if discard(&layer_key).await {
generated_layers.push(SplitWriterResult::Discarded(layer_key));
} else {
generated_layers.push(SplitWriterResult::Produced(
inner.finish_with_end_key(tline, end_key, ctx).await?,
));
}
Ok(generated_layers)
}
#[cfg(test)]
pub(crate) async fn finish(
self,
tline: &Arc<Timeline>,
ctx: &RequestContext,
end_key: Key,
) -> anyhow::Result<Vec<SplitWriterResult>> {
self.finish_with_discard_fn(tline, ctx, end_key, |_| async { false })
.await
) -> anyhow::Result<Vec<ResidentLayer>> {
let Self {
mut generated_layers,
inner,
..
} = self;
generated_layers.push(inner.finish_with_end_key(tline, end_key, ctx).await?);
Ok(generated_layers)
}
/// When split writer fails, the caller should call this function and handle partially generated layers.
pub(crate) fn take(self) -> anyhow::Result<(Vec<SplitWriterResult>, ImageLayerWriter)> {
#[allow(dead_code)]
pub(crate) async fn take(self) -> anyhow::Result<(Vec<ResidentLayer>, ImageLayerWriter)> {
Ok((self.generated_layers, self.inner))
}
}
@@ -196,21 +110,15 @@ impl SplitImageLayerWriter {
/// A delta writer that takes key-lsn-values and produces multiple delta layers. The interface does not
/// guarantee atomicity (i.e., if the delta layer generation fails, there might be leftover files
/// to be cleaned up).
///
/// Note that if the updates of a single key exceed the target size limit, all of those updates will be batched
/// into a single file. This behavior might change in the future. For reference, the legacy compaction algorithm
/// will split them into multiple files based on size.
#[must_use]
pub struct SplitDeltaLayerWriter {
inner: DeltaLayerWriter,
target_layer_size: u64,
generated_layers: Vec<SplitWriterResult>,
generated_layers: Vec<ResidentLayer>,
conf: &'static PageServerConf,
timeline_id: TimelineId,
tenant_shard_id: TenantShardId,
lsn_range: Range<Lsn>,
last_key_written: Key,
start_key: Key,
}
impl SplitDeltaLayerWriter {
@@ -239,74 +147,9 @@ impl SplitDeltaLayerWriter {
timeline_id,
tenant_shard_id,
lsn_range,
last_key_written: Key::MIN,
start_key,
})
}
/// Put value into the layer writer. In the case the writer decides to produce a layer, and the discard fn returns true, no layer will be written in the end.
pub async fn put_value_with_discard_fn<D, F>(
&mut self,
key: Key,
lsn: Lsn,
val: Value,
tline: &Arc<Timeline>,
ctx: &RequestContext,
discard: D,
) -> anyhow::Result<()>
where
D: FnOnce(&PersistentLayerKey) -> F,
F: Future<Output = bool>,
{
// The current estimation is key size plus LSN size plus value size estimation. This is not an accurate
// number, and therefore the final layer size could be a little bit larger or smaller than the target.
//
// Also, keep all updates of a single key in a single file. TODO: split them using the legacy compaction
// strategy. https://github.com/neondatabase/neon/issues/8837
let addition_size_estimation = KEY_SIZE as u64 + 8 /* LSN u64 size */ + 80 /* value size estimation */;
if self.inner.num_keys() >= 1
&& self.inner.estimated_size() + addition_size_estimation >= self.target_layer_size
{
if key != self.last_key_written {
let next_delta_writer = DeltaLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
key,
self.lsn_range.clone(),
ctx,
)
.await?;
let prev_delta_writer = std::mem::replace(&mut self.inner, next_delta_writer);
let layer_key = PersistentLayerKey {
key_range: self.start_key..key,
lsn_range: self.lsn_range.clone(),
is_delta: true,
};
self.start_key = key;
if discard(&layer_key).await {
drop(prev_delta_writer);
self.generated_layers
.push(SplitWriterResult::Discarded(layer_key));
} else {
let (desc, path) = prev_delta_writer.finish(key, ctx).await?;
let delta_layer = Layer::finish_creating(self.conf, tline, desc, &path)?;
self.generated_layers
.push(SplitWriterResult::Produced(delta_layer));
}
} else if self.inner.estimated_size() >= S3_UPLOAD_LIMIT {
// We have to produce a very large file b/c a key is updated too often.
anyhow::bail!(
"a single key is updated too often: key={}, estimated_size={}, and the layer file cannot be produced",
key,
self.inner.estimated_size()
);
}
}
self.last_key_written = key;
self.inner.put_value(key, lsn, val, ctx).await
}
pub async fn put_value(
&mut self,
key: Key,
@@ -315,64 +158,56 @@ impl SplitDeltaLayerWriter {
tline: &Arc<Timeline>,
ctx: &RequestContext,
) -> anyhow::Result<()> {
self.put_value_with_discard_fn(key, lsn, val, tline, ctx, |_| async { false })
.await
}
pub(crate) async fn finish_with_discard_fn<D, F>(
self,
tline: &Arc<Timeline>,
ctx: &RequestContext,
end_key: Key,
discard: D,
) -> anyhow::Result<Vec<SplitWriterResult>>
where
D: FnOnce(&PersistentLayerKey) -> F,
F: Future<Output = bool>,
{
let Self {
mut generated_layers,
inner,
..
} = self;
if inner.num_keys() == 0 {
return Ok(generated_layers);
}
let layer_key = PersistentLayerKey {
key_range: self.start_key..end_key,
lsn_range: self.lsn_range.clone(),
is_delta: true,
};
if discard(&layer_key).await {
generated_layers.push(SplitWriterResult::Discarded(layer_key));
} else {
let (desc, path) = inner.finish(end_key, ctx).await?;
// The current estimation is key size plus LSN size plus value size estimation. This is not an accurate
// number, and therefore the final layer size could be a little bit larger or smaller than the target.
let addition_size_estimation = KEY_SIZE as u64 + 8 /* LSN u64 size */ + 80 /* value size estimation */;
if self.inner.num_keys() >= 1
&& self.inner.estimated_size() + addition_size_estimation >= self.target_layer_size
{
let next_delta_writer = DeltaLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
key,
self.lsn_range.clone(),
ctx,
)
.await?;
let prev_delta_writer = std::mem::replace(&mut self.inner, next_delta_writer);
let (desc, path) = prev_delta_writer.finish(key, ctx).await?;
let delta_layer = Layer::finish_creating(self.conf, tline, desc, &path)?;
generated_layers.push(SplitWriterResult::Produced(delta_layer));
self.generated_layers.push(delta_layer);
}
Ok(generated_layers)
self.inner.put_value(key, lsn, val, ctx).await
}
#[allow(dead_code)]
pub(crate) async fn finish(
self,
tline: &Arc<Timeline>,
ctx: &RequestContext,
end_key: Key,
) -> anyhow::Result<Vec<SplitWriterResult>> {
self.finish_with_discard_fn(tline, ctx, end_key, |_| async { false })
.await
) -> anyhow::Result<Vec<ResidentLayer>> {
let Self {
mut generated_layers,
inner,
..
} = self;
let (desc, path) = inner.finish(end_key, ctx).await?;
let delta_layer = Layer::finish_creating(self.conf, tline, desc, &path)?;
generated_layers.push(delta_layer);
Ok(generated_layers)
}
/// When split writer fails, the caller should call this function and handle partially generated layers.
pub(crate) fn take(self) -> anyhow::Result<(Vec<SplitWriterResult>, DeltaLayerWriter)> {
#[allow(dead_code)]
pub(crate) async fn take(self) -> anyhow::Result<(Vec<ResidentLayer>, DeltaLayerWriter)> {
Ok((self.generated_layers, self.inner))
}
}
#[cfg(test)]
mod tests {
use itertools::Itertools;
use rand::{RngCore, SeedableRng};
use crate::{
@@ -467,16 +302,9 @@ mod tests {
#[tokio::test]
async fn write_split() {
write_split_helper("split_writer_write_split", false).await;
}
#[tokio::test]
async fn write_split_discard() {
write_split_helper("split_writer_write_split_discard", false).await;
}
async fn write_split_helper(harness_name: &'static str, discard: bool) {
let harness = TenantHarness::create(harness_name).await.unwrap();
let harness = TenantHarness::create("split_writer_write_split")
.await
.unwrap();
let (tenant, ctx) = harness.load().await;
let tline = tenant
@@ -510,19 +338,16 @@ mod tests {
for i in 0..N {
let i = i as u32;
image_writer
.put_image_with_discard_fn(get_key(i), get_large_img(), &tline, &ctx, |_| async {
discard
})
.put_image(get_key(i), get_large_img(), &tline, &ctx)
.await
.unwrap();
delta_writer
.put_value_with_discard_fn(
.put_value(
get_key(i),
Lsn(0x20),
Value::Image(get_large_img()),
&tline,
&ctx,
|_| async { discard },
)
.await
.unwrap();
@@ -535,39 +360,22 @@ mod tests {
.finish(&tline, &ctx, get_key(N as u32))
.await
.unwrap();
if discard {
for layer in image_layers {
layer.into_discarded_layer();
}
for layer in delta_layers {
layer.into_discarded_layer();
}
} else {
let image_layers = image_layers
.into_iter()
.map(|x| x.into_resident_layer())
.collect_vec();
let delta_layers = delta_layers
.into_iter()
.map(|x| x.into_resident_layer())
.collect_vec();
assert_eq!(image_layers.len(), N / 512 + 1);
assert_eq!(delta_layers.len(), N / 512 + 1);
for idx in 0..image_layers.len() {
assert_ne!(image_layers[idx].layer_desc().key_range.start, Key::MIN);
assert_ne!(image_layers[idx].layer_desc().key_range.end, Key::MAX);
assert_ne!(delta_layers[idx].layer_desc().key_range.start, Key::MIN);
assert_ne!(delta_layers[idx].layer_desc().key_range.end, Key::MAX);
if idx > 0 {
assert_eq!(
image_layers[idx - 1].layer_desc().key_range.end,
image_layers[idx].layer_desc().key_range.start
);
assert_eq!(
delta_layers[idx - 1].layer_desc().key_range.end,
delta_layers[idx].layer_desc().key_range.start
);
}
assert_eq!(image_layers.len(), N / 512 + 1);
assert_eq!(delta_layers.len(), N / 512 + 1);
for idx in 0..image_layers.len() {
assert_ne!(image_layers[idx].layer_desc().key_range.start, Key::MIN);
assert_ne!(image_layers[idx].layer_desc().key_range.end, Key::MAX);
assert_ne!(delta_layers[idx].layer_desc().key_range.start, Key::MIN);
assert_ne!(delta_layers[idx].layer_desc().key_range.end, Key::MAX);
if idx > 0 {
assert_eq!(
image_layers[idx - 1].layer_desc().key_range.end,
image_layers[idx].layer_desc().key_range.start
);
assert_eq!(
delta_layers[idx - 1].layer_desc().key_range.end,
delta_layers[idx].layer_desc().key_range.start
);
}
}
}
@@ -648,49 +456,4 @@ mod tests {
.unwrap();
assert_eq!(layers.len(), 2);
}
#[tokio::test]
async fn write_split_single_key() {
let harness = TenantHarness::create("split_writer_write_split_single_key")
.await
.unwrap();
let (tenant, ctx) = harness.load().await;
let tline = tenant
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
.await
.unwrap();
const N: usize = 2000;
let mut delta_writer = SplitDeltaLayerWriter::new(
tenant.conf,
tline.timeline_id,
tenant.tenant_shard_id,
get_key(0),
Lsn(0x10)..Lsn(N as u64 * 16 + 0x10),
4 * 1024 * 1024,
&ctx,
)
.await
.unwrap();
for i in 0..N {
let i = i as u32;
delta_writer
.put_value(
get_key(0),
Lsn(i as u64 * 16 + 0x10),
Value::Image(get_large_img()),
&tline,
&ctx,
)
.await
.unwrap();
}
let delta_layers = delta_writer
.finish(&tline, &ctx, get_key(N as u32))
.await
.unwrap();
assert_eq!(delta_layers.len(), 1);
}
}
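The discard-aware variant of the delta split writer above rolls over to a new layer only when the current one is non-empty, the size estimate crosses the target, and the writer is not in the middle of a key (all updates of one key stay in one file). A simplified sketch of that decision with illustrative numbers and a u64 stand-in for Key:

const KEY_SIZE: u64 = 18; // on-disk key width assumed by the estimate
const ADDITION_SIZE_ESTIMATION: u64 = KEY_SIZE + 8 /* LSN */ + 80 /* rough value size */;

struct SplitState {
    num_keys: usize,
    estimated_size: u64,
    target_layer_size: u64,
    last_key_written: u64,
}

impl SplitState {
    fn should_split(&self, key: u64) -> bool {
        self.num_keys >= 1
            && self.estimated_size + ADDITION_SIZE_ESTIMATION >= self.target_layer_size
            && key != self.last_key_written // keep all updates of one key in one file
    }
}

fn main() {
    let state = SplitState {
        num_keys: 100,
        estimated_size: 4 * 1024 * 1024,
        target_layer_size: 4 * 1024 * 1024,
        last_key_written: 41,
    };
    assert!(state.should_split(42)); // new key and over target: start a new layer
    assert!(!state.should_split(41)); // same key still streaming: stay in this layer
}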

View File

@@ -22,8 +22,8 @@ use handle::ShardTimelineId;
use once_cell::sync::Lazy;
use pageserver_api::{
key::{
CompactKey, KEY_SIZE, METADATA_KEY_BEGIN_PREFIX, METADATA_KEY_END_PREFIX,
NON_INHERITED_RANGE, NON_INHERITED_SPARSE_RANGE,
KEY_SIZE, METADATA_KEY_BEGIN_PREFIX, METADATA_KEY_END_PREFIX, NON_INHERITED_RANGE,
NON_INHERITED_SPARSE_RANGE,
},
keyspace::{KeySpaceAccum, KeySpaceRandomAccum, SparseKeyPartitioning},
models::{
@@ -44,8 +44,10 @@ use tokio::{
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::{
bin_ser::BeSer,
fs_ext, pausable_failpoint,
sync::gate::{Gate, GateGuard},
vec_map::VecMap,
};
use std::pin::pin;
@@ -135,10 +137,7 @@ use self::layer_manager::LayerManager;
use self::logical_size::LogicalSize;
use self::walreceiver::{WalReceiver, WalReceiverConf};
use super::{
config::TenantConf, storage_layer::inmemory_layer, storage_layer::LayerVisibilityHint,
upload_queue::NotInitialized,
};
use super::{config::TenantConf, storage_layer::LayerVisibilityHint, upload_queue::NotInitialized};
use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf};
use super::{remote_timeline_client::index::IndexPart, storage_layer::LayerFringe};
use super::{
@@ -2234,11 +2233,6 @@ impl Timeline {
handles: Default::default(),
};
if aux_file_policy == Some(AuxFilePolicy::V1) {
warn!("this timeline is using deprecated aux file policy V1");
}
result.repartition_threshold =
result.get_checkpoint_distance() / REPARTITION_FREQ_IN_CHECKPOINT_DISTANCE;
@@ -3002,10 +2996,7 @@ impl Timeline {
// - For L1 & image layers, download most recent LSNs first: the older the LSN, the sooner
// the layer is likely to be covered by an image layer during compaction.
layers.sort_by_key(|(desc, _meta, _atime)| {
std::cmp::Reverse((
!LayerMap::is_l0(&desc.key_range, desc.is_delta),
desc.lsn_range.end,
))
std::cmp::Reverse((!LayerMap::is_l0(&desc.key_range), desc.lsn_range.end))
});
let layers = layers
@@ -3598,6 +3589,34 @@ impl Timeline {
return Err(FlushLayerError::Cancelled);
}
// FIXME(auxfilesv2): supporting multiple metadata key partitions might need initdb support as well?
// This code path will not be hit during regression tests. After #7099 we have a single partition
// with two key ranges. If someone wants to fix initdb optimization in the future, this might need
// to be fixed.
// For metadata, always create delta layers.
let delta_layer = if !metadata_partition.parts.is_empty() {
assert_eq!(
metadata_partition.parts.len(),
1,
"currently sparse keyspace should only contain a single metadata keyspace"
);
let metadata_keyspace = &metadata_partition.parts[0];
self.create_delta_layer(
&frozen_layer,
Some(
metadata_keyspace.0.ranges.first().unwrap().start
..metadata_keyspace.0.ranges.last().unwrap().end,
),
ctx,
)
.await
.map_err(|e| FlushLayerError::from_anyhow(self, e))?
} else {
None
};
// For image layers, we add them immediately into the layer map.
let mut layers_to_upload = Vec::new();
layers_to_upload.extend(
self.create_image_layers(
@@ -3608,27 +3627,13 @@ impl Timeline {
)
.await?,
);
if !metadata_partition.parts.is_empty() {
assert_eq!(
metadata_partition.parts.len(),
1,
"currently sparse keyspace should only contain a single metadata keyspace"
);
layers_to_upload.extend(
self.create_image_layers(
// Safety: create_image_layers treats sparse keyspaces differently in that it does not scan
// every single key within the keyspace, and therefore it's safe to force-convert it
// into a dense keyspace before calling this function.
&metadata_partition.into_dense(),
self.initdb_lsn,
ImageLayerCreationMode::Initial,
ctx,
)
.await?,
);
}
(layers_to_upload, None)
if let Some(delta_layer) = delta_layer {
layers_to_upload.push(delta_layer.clone());
(layers_to_upload, Some(delta_layer))
} else {
(layers_to_upload, None)
}
} else {
// Normal case, write out a L0 delta layer file.
// `create_delta_layer` will not modify the layer map.
@@ -4038,6 +4043,8 @@ impl Timeline {
mode: ImageLayerCreationMode,
start: Key,
) -> Result<ImageLayerCreationOutcome, CreateImageLayersError> {
assert!(!matches!(mode, ImageLayerCreationMode::Initial));
// Metadata keys image layer creation.
let mut reconstruct_state = ValuesReconstructState::default();
let data = self
@@ -4203,13 +4210,15 @@ impl Timeline {
"metadata keys must be partitioned separately"
);
}
if mode == ImageLayerCreationMode::Initial {
return Err(CreateImageLayersError::Other(anyhow::anyhow!("no image layer should be created for metadata keys when flushing frozen layers")));
}
if mode == ImageLayerCreationMode::Try && !check_for_image_layers {
// Skip compaction if there are not enough updates. Metadata compaction will do a scan and
// might interfere with evictions.
start = img_range.end;
continue;
}
// For initial and force modes, we always generate image layers for metadata keys.
} else if let ImageLayerCreationMode::Try = mode {
// check_for_image_layers = false -> skip
// check_for_image_layers = true -> check time_for_new_image_layer -> skip/generate
@@ -4217,8 +4226,7 @@ impl Timeline {
start = img_range.end;
continue;
}
}
if let ImageLayerCreationMode::Force = mode {
} else if let ImageLayerCreationMode::Force = mode {
// When forced to create image layers, we might try and create them where they already
// exist. This mode is only used in tests/debug.
let layers = self.layers.read().await;
@@ -4232,7 +4240,6 @@ impl Timeline {
img_range.start,
img_range.end
);
start = img_range.end;
continue;
}
}
@@ -4588,7 +4595,7 @@ impl Timeline {
// for compact_level0_phase1 creating an L0, which does not happen in practice
// because we have not implemented L0 => L0 compaction.
duplicated_layers.insert(l.layer_desc().key());
} else if LayerMap::is_l0(&l.layer_desc().key_range, l.layer_desc().is_delta) {
} else if LayerMap::is_l0(&l.layer_desc().key_range) {
return Err(CompactionError::Other(anyhow::anyhow!("compaction generates a L0 layer file as output, which will cause infinite compaction.")));
} else {
insert_layers.push(l.clone());
@@ -5444,17 +5451,12 @@ impl Timeline {
!(a.end <= b.start || b.end <= a.start)
}
if deltas.key_range.start.next() != deltas.key_range.end {
let guard = self.layers.read().await;
let mut invalid_layers =
guard.layer_map()?.iter_historic_layers().filter(|layer| {
layer.is_delta()
&& overlaps_with(&layer.lsn_range, &deltas.lsn_range)
&& layer.lsn_range != deltas.lsn_range
// skip single-key layer files
&& layer.key_range.start.next() != layer.key_range.end
});
if let Some(layer) = invalid_layers.next() {
let guard = self.layers.read().await;
for layer in guard.layer_map()?.iter_historic_layers() {
if layer.is_delta()
&& overlaps_with(&layer.lsn_range, &deltas.lsn_range)
&& layer.lsn_range != deltas.lsn_range
{
// If a delta layer overlaps with another delta layer AND their LSN range is not the same, panic
panic!(
"inserted layer violates delta layer LSN invariant: current_lsn_range={}..{}, conflict_lsn_range={}..{}",
@@ -5588,6 +5590,44 @@ enum OpenLayerAction {
}
impl<'a> TimelineWriter<'a> {
/// Put a new page version that can be constructed from a WAL record
///
/// This will implicitly extend the relation, if the page is beyond the
/// current end-of-file.
pub(crate) async fn put(
&mut self,
key: Key,
lsn: Lsn,
value: &Value,
ctx: &RequestContext,
) -> anyhow::Result<()> {
// Avoid doing allocations for "small" values.
// In the regression test suite, the limit of 256 avoided allocations in 95% of cases:
// https://github.com/neondatabase/neon/pull/5056#discussion_r1301975061
let mut buf = smallvec::SmallVec::<[u8; 256]>::new();
value.ser_into(&mut buf)?;
let buf_size: u64 = buf.len().try_into().expect("oversized value buf");
let action = self.get_open_layer_action(lsn, buf_size);
let layer = self.handle_open_layer_action(lsn, action, ctx).await?;
let res = layer.put_value(key.to_compact(), lsn, &buf, ctx).await;
if res.is_ok() {
// Update the current size only when the entire write was ok.
// In case of failures, we may have had partial writes which
// render the size tracking out of sync. That's ok because
// the checkpoint distance should be significantly smaller
// than the S3 single shot upload limit of 5GiB.
let state = self.write_guard.as_mut().unwrap();
state.current_size += buf_size;
state.prev_lsn = Some(lsn);
state.max_lsn = std::cmp::max(state.max_lsn, Some(lsn));
}
res
}
async fn handle_open_layer_action(
&mut self,
at: Lsn,
@@ -5693,58 +5733,18 @@ impl<'a> TimelineWriter<'a> {
}
/// Put a batch of keys at the specified Lsns.
///
/// The batch is sorted by Lsn (enforced by usage of [`utils::vec_map::VecMap`]).
pub(crate) async fn put_batch(
&mut self,
batch: Vec<(CompactKey, Lsn, usize, Value)>,
batch: VecMap<Lsn, (Key, Value)>,
ctx: &RequestContext,
) -> anyhow::Result<()> {
if batch.is_empty() {
return Ok(());
for (lsn, (key, val)) in batch {
self.put(key, lsn, &val, ctx).await?
}
let serialized_batch = inmemory_layer::SerializedBatch::from_values(batch);
let batch_max_lsn = serialized_batch.max_lsn;
let buf_size: u64 = serialized_batch.raw.len() as u64;
let action = self.get_open_layer_action(batch_max_lsn, buf_size);
let layer = self
.handle_open_layer_action(batch_max_lsn, action, ctx)
.await?;
let res = layer.put_batch(serialized_batch, ctx).await;
if res.is_ok() {
// Update the current size only when the entire write was ok.
// In case of failures, we may have had partial writes which
// render the size tracking out of sync. That's ok because
// the checkpoint distance should be significantly smaller
// than the S3 single shot upload limit of 5GiB.
let state = self.write_guard.as_mut().unwrap();
state.current_size += buf_size;
state.prev_lsn = Some(batch_max_lsn);
state.max_lsn = std::cmp::max(state.max_lsn, Some(batch_max_lsn));
}
res
}
#[cfg(test)]
/// Test helper, for tests that would like to poke individual values without composing a batch
pub(crate) async fn put(
&mut self,
key: Key,
lsn: Lsn,
value: &Value,
ctx: &RequestContext,
) -> anyhow::Result<()> {
use utils::bin_ser::BeSer;
let val_ser_size = value.serialized_size().unwrap() as usize;
self.put_batch(
vec![(key.to_compact(), lsn, val_ser_size, value.clone())],
ctx,
)
.await
Ok(())
}
pub(crate) async fn delete_batch(
@@ -5885,7 +5885,7 @@ mod tests {
};
// Apart from L0s, newest Layers should come first
if !LayerMap::is_l0(layer.name.key_range(), layer.name.is_delta()) {
if !LayerMap::is_l0(layer.name.key_range()) {
assert!(layer_lsn <= last_lsn);
last_lsn = layer_lsn;
}
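Earlier in this file's diff, one variant of TimelineWriter::put serializes the value into a smallvec::SmallVec<[u8; 256]> so typical small values avoid a heap allocation. A minimal sketch of that buffer choice, with extend_from_slice standing in for the real ser_into serialization:

use smallvec::SmallVec;

fn serialize_value(value: &[u8]) -> SmallVec<[u8; 256]> {
    let mut buf: SmallVec<[u8; 256]> = SmallVec::new();
    buf.extend_from_slice(value);
    buf
}

fn main() {
    // Fits in the 256-byte inline buffer: no heap allocation.
    assert!(!serialize_value(&[0u8; 64]).spilled());
    // Larger than 256 bytes: SmallVec spills to the heap.
    assert!(serialize_value(&[0u8; 1024]).spilled());
}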

View File

@@ -14,7 +14,7 @@ use super::{
RecordedDuration, Timeline,
};
use anyhow::{anyhow, bail, Context};
use anyhow::{anyhow, Context};
use bytes::Bytes;
use enumset::EnumSet;
use fail::fail_point;
@@ -32,9 +32,6 @@ use crate::page_cache;
use crate::tenant::config::defaults::{DEFAULT_CHECKPOINT_DISTANCE, DEFAULT_COMPACTION_THRESHOLD};
use crate::tenant::remote_timeline_client::WaitCompletionError;
use crate::tenant::storage_layer::merge_iterator::MergeIterator;
use crate::tenant::storage_layer::split_writer::{
SplitDeltaLayerWriter, SplitImageLayerWriter, SplitWriterResult,
};
use crate::tenant::storage_layer::{
AsLayerDesc, PersistentLayerDesc, PersistentLayerKey, ValueReconstructState,
};
@@ -74,60 +71,15 @@ pub(crate) struct KeyHistoryRetention {
}
impl KeyHistoryRetention {
/// Hack: skip the delta layer if we need to produce a layer with the same key-lsn range.
///
/// This can happen if we have removed some deltas in "the middle" of some existing layer's key-lsn-range.
/// For example, consider the case where a single delta with range [0x10,0x50) exists.
/// And we have branches at LSN 0x10, 0x20, 0x30.
/// Then we delete branch @ 0x20.
/// Bottom-most compaction may now delete the delta [0x20,0x30).
/// And that wouldn't change the shape of the layer.
///
/// Note that bottom-most-gc-compaction never _adds_ new data in that case, only removes.
///
/// `discard_key` will only be called when the writer reaches its target (instead of for every key), so it's fine to grab a lock inside.
async fn discard_key(key: &PersistentLayerKey, tline: &Arc<Timeline>, dry_run: bool) -> bool {
if dry_run {
return true;
}
let guard = tline.layers.read().await;
if !guard.contains_key(key) {
return false;
}
let layer_generation = guard.get_from_key(key).metadata().generation;
drop(guard);
if layer_generation == tline.generation {
info!(
key=%key,
?layer_generation,
"discard layer due to duplicated layer key in the same generation",
);
true
} else {
false
}
}
/// Pipe a history of a single key to the writers.
///
/// If `image_writer` is none, the images will be placed into the delta layers.
/// The delta writer will contain all images and deltas (below and above the horizon) except the bottom-most images.
#[allow(clippy::too_many_arguments)]
async fn pipe_to(
self,
key: Key,
tline: &Arc<Timeline>,
delta_writer: &mut SplitDeltaLayerWriter,
mut image_writer: Option<&mut SplitImageLayerWriter>,
delta_writer: &mut Vec<(Key, Lsn, Value)>,
mut image_writer: Option<&mut ImageLayerWriter>,
stat: &mut CompactionStatistics,
dry_run: bool,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let mut first_batch = true;
let discard = |key: &PersistentLayerKey| {
let key = key.clone();
async move { Self::discard_key(&key, tline, dry_run).await }
};
for (cutoff_lsn, KeyLogAtLsn(logs)) in self.below_horizon {
if first_batch {
if logs.len() == 1 && logs[0].1.is_image() {
@@ -136,45 +88,28 @@ impl KeyHistoryRetention {
};
stat.produce_image_key(img);
if let Some(image_writer) = image_writer.as_mut() {
image_writer
.put_image_with_discard_fn(key, img.clone(), tline, ctx, discard)
.await?;
image_writer.put_image(key, img.clone(), ctx).await?;
} else {
delta_writer
.put_value_with_discard_fn(
key,
cutoff_lsn,
Value::Image(img.clone()),
tline,
ctx,
discard,
)
.await?;
delta_writer.push((key, cutoff_lsn, Value::Image(img.clone())));
}
} else {
for (lsn, val) in logs {
stat.produce_key(&val);
delta_writer
.put_value_with_discard_fn(key, lsn, val, tline, ctx, discard)
.await?;
delta_writer.push((key, lsn, val));
}
}
first_batch = false;
} else {
for (lsn, val) in logs {
stat.produce_key(&val);
delta_writer
.put_value_with_discard_fn(key, lsn, val, tline, ctx, discard)
.await?;
delta_writer.push((key, lsn, val));
}
}
}
let KeyLogAtLsn(above_horizon_logs) = self.above_horizon;
for (lsn, val) in above_horizon_logs {
stat.produce_key(&val);
delta_writer
.put_value_with_discard_fn(key, lsn, val, tline, ctx, discard)
.await?;
delta_writer.push((key, lsn, val));
}
Ok(())
}
@@ -1879,27 +1814,11 @@ impl Timeline {
}
let mut selected_layers = Vec::new();
drop(gc_info);
// Pick all the layers that intersect or are below the gc_cutoff, and get the largest LSN among the selected layers.
let Some(max_layer_lsn) = layers
.iter_historic_layers()
.filter(|desc| desc.get_lsn_range().start <= gc_cutoff)
.map(|desc| desc.get_lsn_range().end)
.max()
else {
info!("no layers to compact with gc");
return Ok(());
};
// Then, pick all the layers that are below the max_layer_lsn. This is to ensure we can pick all single-key
// layers to compact.
for desc in layers.iter_historic_layers() {
if desc.get_lsn_range().end <= max_layer_lsn {
if desc.get_lsn_range().start <= gc_cutoff {
selected_layers.push(guard.get_from_desc(&desc));
}
}
if selected_layers.is_empty() {
info!("no layers to compact with gc");
return Ok(());
}
retain_lsns_below_horizon.sort();
(selected_layers, gc_cutoff, retain_lsns_below_horizon)
};
@@ -1929,53 +1848,27 @@ impl Timeline {
lowest_retain_lsn
);
// Step 1: (In the future) construct a k-merge iterator over all layers. For now, simply collect all keys + LSNs.
// Also, verify if the layer map can be split by drawing a horizontal line at every LSN start/end split point.
let mut lsn_split_point = BTreeSet::new(); // TODO: use a better data structure (range tree / range set?)
// Also, collect the layer information to decide when to split the new delta layers.
let mut downloaded_layers = Vec::new();
let mut delta_split_points = BTreeSet::new();
for layer in &layer_selection {
let resident_layer = layer.download_and_keep_resident().await?;
downloaded_layers.push(resident_layer);
let desc = layer.layer_desc();
if desc.is_delta() {
// ignore single-key layer files
if desc.key_range.start.next() != desc.key_range.end {
let lsn_range = &desc.lsn_range;
lsn_split_point.insert(lsn_range.start);
lsn_split_point.insert(lsn_range.end);
}
// TODO: is it correct to only record split points for deltas intersecting with the GC horizon? (exclude those below/above the horizon)
// so that we can avoid having too many small delta layers.
let key_range = desc.get_key_range();
delta_split_points.insert(key_range.start);
delta_split_points.insert(key_range.end);
stat.visit_delta_layer(desc.file_size());
} else {
stat.visit_image_layer(desc.file_size());
}
}
for layer in &layer_selection {
let desc = layer.layer_desc();
let key_range = &desc.key_range;
if desc.is_delta() && key_range.start.next() != key_range.end {
let lsn_range = desc.lsn_range.clone();
let intersects = lsn_split_point.range(lsn_range).collect_vec();
if intersects.len() > 1 {
bail!(
"cannot run gc-compaction because it violates the layer map LSN split assumption: layer {} intersects with LSN [{}]",
desc.key(),
intersects.into_iter().map(|lsn| lsn.to_string()).join(", ")
);
}
}
}
// The maximum LSN we are processing in this compaction loop
let end_lsn = layer_selection
.iter()
.map(|l| l.layer_desc().lsn_range.end)
.max()
.unwrap();
// We don't want any of the produced layers to cover the full key range (i.e., MIN..MAX) b/c it will then be recognized
// as an L0 layer.
let hack_end_key = Key::NON_L0_MAX;
let mut delta_layers = Vec::new();
let mut image_layers = Vec::new();
let mut downloaded_layers = Vec::new();
for layer in &layer_selection {
let resident_layer = layer.download_and_keep_resident().await?;
downloaded_layers.push(resident_layer);
}
for resident_layer in &downloaded_layers {
if resident_layer.layer_desc().is_delta() {
let layer = resident_layer.get_as_delta(ctx).await?;
@@ -1991,17 +1884,138 @@ impl Timeline {
let mut accumulated_values = Vec::new();
let mut last_key: Option<Key> = None;
enum FlushDeltaResult {
/// Create a new resident layer
CreateResidentLayer(ResidentLayer),
/// Keep an original delta layer
KeepLayer(PersistentLayerKey),
}
#[allow(clippy::too_many_arguments)]
async fn flush_deltas(
deltas: &mut Vec<(Key, Lsn, crate::repository::Value)>,
last_key: Key,
delta_split_points: &[Key],
current_delta_split_point: &mut usize,
tline: &Arc<Timeline>,
lowest_retain_lsn: Lsn,
ctx: &RequestContext,
stats: &mut CompactionStatistics,
dry_run: bool,
last_batch: bool,
) -> anyhow::Result<Option<FlushDeltaResult>> {
// Check if we need to split the delta layer. We split at the original delta layer boundary to avoid
// overlapping layers.
//
// If we have a structure like this:
//
// | Delta 1 | | Delta 4 |
// |---------| Delta 2 |---------|
// | Delta 3 | | Delta 5 |
//
// And we choose to compact delta 2+3+5. We will get an overlapping delta layer with delta 1+4.
// A simple solution here is to split the delta layers using the original boundary, while this
// might produce a lot of small layers. This should be improved and fixed in the future.
let mut need_split = false;
while *current_delta_split_point < delta_split_points.len()
&& last_key >= delta_split_points[*current_delta_split_point]
{
*current_delta_split_point += 1;
need_split = true;
}
if !need_split && !last_batch {
return Ok(None);
}
let deltas: Vec<(Key, Lsn, Value)> = std::mem::take(deltas);
if deltas.is_empty() {
return Ok(None);
}
let end_lsn = deltas.iter().map(|(_, lsn, _)| lsn).max().copied().unwrap() + 1;
let delta_key = PersistentLayerKey {
key_range: {
let key_start = deltas.first().unwrap().0;
let key_end = deltas.last().unwrap().0.next();
key_start..key_end
},
lsn_range: lowest_retain_lsn..end_lsn,
is_delta: true,
};
{
// Hack: skip the delta layer if we need to produce a layer with the same key-lsn.
//
// This can happen if we have removed some deltas in "the middle" of some existing layer's key-lsn-range.
// For example, consider the case where a single delta with range [0x10,0x50) exists.
// And we have branches at LSN 0x10, 0x20, 0x30.
// Then we delete branch @ 0x20.
// Bottom-most compaction may now delete the delta [0x20,0x30).
// And that wouldn't change the shape of the layer.
//
// Note that bottom-most-gc-compaction never _adds_ new data in that case, only removes.
// That's why it's safe to skip.
let guard = tline.layers.read().await;
if guard.contains_key(&delta_key) {
let layer_generation = guard.get_from_key(&delta_key).metadata().generation;
drop(guard);
if layer_generation == tline.generation {
stats.discard_delta_layer();
// TODO: depending on whether we design this compaction process to run along with
// other compactions, there could be layer map modifications after we drop the
// layer guard, and if that creates a duplicated layer key, we will still error
// out in the end.
info!(
key=%delta_key,
?layer_generation,
"discard delta layer due to duplicated layer in the same generation"
);
return Ok(Some(FlushDeltaResult::KeepLayer(delta_key)));
}
}
}
let mut delta_layer_writer = DeltaLayerWriter::new(
tline.conf,
tline.timeline_id,
tline.tenant_shard_id,
delta_key.key_range.start,
lowest_retain_lsn..end_lsn,
ctx,
)
.await?;
for (key, lsn, val) in deltas {
delta_layer_writer.put_value(key, lsn, val, ctx).await?;
}
stats.produce_delta_layer(delta_layer_writer.size());
if dry_run {
return Ok(None);
}
let (desc, path) = delta_layer_writer
.finish(delta_key.key_range.end, ctx)
.await?;
let delta_layer = Layer::finish_creating(tline.conf, tline, desc, &path)?;
Ok(Some(FlushDeltaResult::CreateResidentLayer(delta_layer)))
}
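// Stand-alone mirror of the split-point check in `flush_deltas` above, with keys
// reduced to plain u64 purely for brevity: walk the sorted original-layer boundaries
// and signal a cut once the last flushed key has passed one of them.
fn passed_split_point(split_points: &[u64], cursor: &mut usize, last_key: u64) -> bool {
    let mut split = false;
    while *cursor < split_points.len() && last_key >= split_points[*cursor] {
        *cursor += 1;
        split = true;
    }
    split
}
fn main() {
    let points = [10, 20, 30];
    let mut cursor = 0;
    assert!(!passed_split_point(&points, &mut cursor, 5)); // still before the first boundary
    assert!(passed_split_point(&points, &mut cursor, 25)); // crossed 10 and 20 in one step
    assert_eq!(cursor, 2);
}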
// Hack the key range to be min..(max-1). Otherwise, the image layer will be
// interpreted as an L0 delta layer.
let hack_image_layer_range = {
let mut end_key = Key::MAX;
end_key.field6 -= 1;
Key::MIN..end_key
};
// Only create image layers when there are no ancestor branches. TODO: create a covering image layer
// when some conditions are met.
let mut image_layer_writer = if self.ancestor_timeline.is_none() {
Some(
SplitImageLayerWriter::new(
ImageLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
Key::MIN,
&hack_image_layer_range, // covers the full key range
lowest_retain_lsn,
self.get_compaction_target_size(),
ctx,
)
.await?,
@@ -2010,17 +2024,6 @@ impl Timeline {
None
};
let mut delta_layer_writer = SplitDeltaLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
Key::MIN,
lowest_retain_lsn..end_lsn,
self.get_compaction_target_size(),
ctx,
)
.await?;
/// Returns None if there is no ancestor branch. Throw an error when the key is not found.
///
/// Currently, we always get the ancestor image for each key in the child branch no matter whether the image
@@ -2041,11 +2044,47 @@ impl Timeline {
let img = tline.get(key, tline.ancestor_lsn, ctx).await?;
Ok(Some((key, tline.ancestor_lsn, img)))
}
let image_layer_key = PersistentLayerKey {
key_range: hack_image_layer_range,
lsn_range: PersistentLayerDesc::image_layer_lsn_range(lowest_retain_lsn),
is_delta: false,
};
// Like with delta layers, it can happen that we re-produce an already existing image layer.
// This could happen when a user triggers force compaction and image generation. In this case,
// it's always safe to rewrite the layer.
let discard_image_layer = {
let guard = self.layers.read().await;
if guard.contains_key(&image_layer_key) {
let layer_generation = guard.get_from_key(&image_layer_key).metadata().generation;
drop(guard);
if layer_generation == self.generation {
// TODO: depending on whether we design this compaction process to run along with
// other compactions, there could be layer map modifications after we drop the
// layer guard, and if that creates a duplicated layer key, we will still error
// out in the end.
info!(
key=%image_layer_key,
?layer_generation,
"discard image layer due to duplicated layer key in the same generation",
);
true
} else {
false
}
} else {
false
}
};
// Actually, we could decide not to write to the image layer at all at this point, because
// the key and LSN range are already determined. However, to keep things simple here, we still
// create the writer and discard it in the end.
let mut delta_values = Vec::new();
let delta_split_points = delta_split_points.into_iter().collect_vec();
let mut current_delta_split_point = 0;
let mut delta_layers = Vec::new();
while let Some((key, lsn, val)) = merge_iter.next().await? {
if cancel.is_cancelled() {
return Err(anyhow!("cancelled")); // TODO: refactor to CompactionError and pass cancel error
@@ -2076,14 +2115,27 @@ impl Timeline {
retention
.pipe_to(
*last_key,
self,
&mut delta_layer_writer,
&mut delta_values,
image_layer_writer.as_mut(),
&mut stat,
dry_run,
ctx,
)
.await?;
delta_layers.extend(
flush_deltas(
&mut delta_values,
*last_key,
&delta_split_points,
&mut current_delta_split_point,
self,
lowest_retain_lsn,
ctx,
&mut stat,
dry_run,
false,
)
.await?,
);
accumulated_values.clear();
*last_key = key;
accumulated_values.push((key, lsn, val));
@@ -2107,75 +2159,43 @@ impl Timeline {
retention
.pipe_to(
last_key,
self,
&mut delta_layer_writer,
&mut delta_values,
image_layer_writer.as_mut(),
&mut stat,
dry_run,
ctx,
)
.await?;
delta_layers.extend(
flush_deltas(
&mut delta_values,
last_key,
&delta_split_points,
&mut current_delta_split_point,
self,
lowest_retain_lsn,
ctx,
&mut stat,
dry_run,
true,
)
.await?,
);
assert!(delta_values.is_empty(), "unprocessed keys");
let discard = |key: &PersistentLayerKey| {
let key = key.clone();
async move { KeyHistoryRetention::discard_key(&key, self, dry_run).await }
};
let produced_image_layers = if let Some(writer) = image_layer_writer {
let image_layer = if discard_image_layer {
stat.discard_image_layer();
None
} else if let Some(writer) = image_layer_writer {
stat.produce_image_layer(writer.size());
if !dry_run {
writer
.finish_with_discard_fn(self, ctx, hack_end_key, discard)
.await?
Some(writer.finish(self, ctx).await?)
} else {
let (layers, _) = writer.take()?;
assert!(layers.is_empty(), "image layers produced in dry run mode?");
Vec::new()
None
}
} else {
Vec::new()
None
};
let produced_delta_layers = if !dry_run {
delta_layer_writer
.finish_with_discard_fn(self, ctx, hack_end_key, discard)
.await?
} else {
let (layers, _) = delta_layer_writer.take()?;
assert!(layers.is_empty(), "delta layers produced in dry run mode?");
Vec::new()
};
let mut compact_to = Vec::new();
let mut keep_layers = HashSet::new();
let produced_delta_layers_len = produced_delta_layers.len();
let produced_image_layers_len = produced_image_layers.len();
for action in produced_delta_layers {
match action {
SplitWriterResult::Produced(layer) => {
stat.produce_delta_layer(layer.layer_desc().file_size());
compact_to.push(layer);
}
SplitWriterResult::Discarded(l) => {
keep_layers.insert(l);
stat.discard_delta_layer();
}
}
}
for action in produced_image_layers {
match action {
SplitWriterResult::Produced(layer) => {
stat.produce_image_layer(layer.layer_desc().file_size());
compact_to.push(layer);
}
SplitWriterResult::Discarded(l) => {
keep_layers.insert(l);
stat.discard_image_layer();
}
}
}
let mut layer_selection = layer_selection;
layer_selection.retain(|x| !keep_layers.contains(&x.layer_desc().key()));
info!(
"gc-compaction statistics: {}",
serde_json::to_string(&stat)?
@@ -2186,11 +2206,28 @@ impl Timeline {
}
info!(
"produced {} delta layers and {} image layers, {} layers are kept",
produced_delta_layers_len,
produced_image_layers_len,
layer_selection.len()
"produced {} delta layers and {} image layers",
delta_layers.len(),
if image_layer.is_some() { 1 } else { 0 }
);
let mut compact_to = Vec::new();
let mut keep_layers = HashSet::new();
for action in delta_layers {
match action {
FlushDeltaResult::CreateResidentLayer(layer) => {
compact_to.push(layer);
}
FlushDeltaResult::KeepLayer(l) => {
keep_layers.insert(l);
}
}
}
if discard_image_layer {
keep_layers.insert(image_layer_key);
}
let mut layer_selection = layer_selection;
layer_selection.retain(|x| !keep_layers.contains(&x.layer_desc().key()));
compact_to.extend(image_layer);
// Step 3: Place back to the layer map.
{

View File

@@ -27,8 +27,8 @@ use super::TaskStateUpdate;
use crate::{
context::RequestContext,
metrics::{LIVE_CONNECTIONS, WALRECEIVER_STARTED_CONNECTIONS, WAL_INGEST},
pgdatadir_mapping::DatadirModification,
task_mgr::{TaskKind, WALRECEIVER_RUNTIME},
task_mgr::TaskKind,
task_mgr::WALRECEIVER_RUNTIME,
tenant::{debug_assert_current_span_has_tenant_and_timeline_id, Timeline, WalReceiverInfo},
walingest::WalIngest,
walrecord::DecodedWALRecord,
@@ -345,10 +345,7 @@ pub(super) async fn handle_walreceiver_connection(
// Commit every ingest_batch_size records. Even if we filtered out
// all records, we still need to call commit to advance the LSN.
uncommitted_records += 1;
if uncommitted_records >= ingest_batch_size
|| modification.approx_pending_bytes()
> DatadirModification::MAX_PENDING_BYTES
{
if uncommitted_records >= ingest_batch_size {
WAL_INGEST
.records_committed
.inc_by(uncommitted_records - filtered_records);
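// Illustrative stand-alone sketch of the batching rule above: every record counts
// toward the batch, even ones that were filtered out, so the commit point keeps
// advancing on a stream where everything gets filtered. `commit` is a stand-in
// callback, not the pageserver API.
fn ingest_records<R>(
    records: impl IntoIterator<Item = R>,
    batch_size: u64,
    mut commit: impl FnMut(u64),
) {
    let mut uncommitted = 0u64;
    for _record in records {
        // ... decode/apply the record here, possibly filtering it out ...
        uncommitted += 1;
        if uncommitted >= batch_size {
            commit(uncommitted);
            uncommitted = 0;
        }
    }
    if uncommitted > 0 {
        commit(uncommitted);
    }
}
fn main() {
    let mut commits = Vec::new();
    ingest_records(0..10, 4, |n| commits.push(n));
    assert_eq!(commits, vec![4, 4, 2]);
}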

View File

@@ -9,7 +9,7 @@ use utils::serde_percent::Percent;
use pageserver_api::models::PageserverUtilization;
use crate::{config::PageServerConf, metrics::NODE_UTILIZATION_SCORE, tenant::mgr::TenantManager};
use crate::{config::PageServerConf, tenant::mgr::TenantManager};
pub(crate) fn regenerate(
conf: &PageServerConf,
@@ -58,13 +58,13 @@ pub(crate) fn regenerate(
disk_usable_pct,
shard_count,
max_shard_count: MAX_SHARDS,
utilization_score: None,
utilization_score: 0,
captured_at: utils::serde_system_time::SystemTime(captured_at),
};
// Initialize `PageserverUtilization::utilization_score`
let score = doc.cached_score();
NODE_UTILIZATION_SCORE.set(score);
doc.refresh_score();
// TODO: make utilization_score into a metric
Ok(doc)
}

View File

@@ -110,8 +110,7 @@ get_cached_relsize(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber *size)
tag.rinfo = rinfo;
tag.forknum = forknum;
/* We need exclusive lock here because of LRU list manipulation */
LWLockAcquire(relsize_lock, LW_EXCLUSIVE);
LWLockAcquire(relsize_lock, LW_SHARED);
entry = hash_search(relsize_hash, &tag, HASH_FIND, NULL);
if (entry != NULL)
{

View File

@@ -2,7 +2,6 @@
import argparse
import enum
import os
import subprocess
import sys
from typing import List
@@ -94,7 +93,7 @@ if __name__ == "__main__":
"--no-color",
action="store_true",
help="disable colored output",
default=not sys.stdout.isatty() or os.getenv("TERM") == "dumb",
default=not sys.stdout.isatty(),
)
args = parser.parse_args()

View File

@@ -114,6 +114,9 @@ rsa = "0.9"
workspace_hack.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
ktls = "6"
[dev-dependencies]
camino-tempfile.workspace = true
fallible-iterator.workspace = true

View File

@@ -36,7 +36,7 @@ To play with it locally one may start proxy over a local postgres installation
```
If both postgres and proxy are running you may send a SQL query:
```console
```json
curl -k -X POST 'https://proxy.localtest.me:4444/sql' \
-H 'Neon-Connection-String: postgres://stas:pass@proxy.localtest.me:4444/postgres' \
-H 'Content-Type: application/json' \
@@ -44,8 +44,7 @@ curl -k -X POST 'https://proxy.localtest.me:4444/sql' \
"query":"SELECT $1::int[] as arr, $2::jsonb as obj, 42 as num",
"params":[ "{{1,2},{\"3\",4}}", {"key":"val", "ikey":4242}]
}' | jq
```
```json
{
"command": "SELECT",
"fields": [

View File

@@ -2,15 +2,14 @@ mod classic;
mod hacks;
pub mod jwt;
mod link;
pub mod local;
use std::net::IpAddr;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use std::time::Duration;
use ipnet::{Ipv4Net, Ipv6Net};
pub use link::LinkAuthError;
use local::LocalBackend;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
@@ -25,6 +24,7 @@ use crate::context::RequestMonitoring;
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::handshake::KtlsAsyncReadReady;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::stream::Stream;
@@ -70,8 +70,6 @@ pub enum BackendType<'a, T, D> {
Console(MaybeOwned<'a, ConsoleBackend>, T),
/// Authentication via a web browser.
Link(MaybeOwned<'a, url::ApiUrl>, D),
/// Local proxy uses configured auth credentials and does not wake compute
Local(MaybeOwned<'a, LocalBackend>),
}
pub trait TestBackend: Send + Sync + 'static {
@@ -85,7 +83,7 @@ pub trait TestBackend: Send + Sync + 'static {
impl std::fmt::Display for BackendType<'_, (), ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Console(api, ()) => match &**api {
Self::Console(api, _) => match &**api {
ConsoleBackend::Console(endpoint) => {
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
}
@@ -96,8 +94,7 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Self::Link(url, ()) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
Self::Local(_) => fmt.debug_tuple("Local").finish(),
Self::Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
}
@@ -109,7 +106,6 @@ impl<T, D> BackendType<'_, T, D> {
match self {
Self::Console(c, x) => BackendType::Console(MaybeOwned::Borrowed(c), x),
Self::Link(c, x) => BackendType::Link(MaybeOwned::Borrowed(c), x),
Self::Local(l) => BackendType::Local(MaybeOwned::Borrowed(l)),
}
}
}
@@ -122,7 +118,6 @@ impl<'a, T, D> BackendType<'a, T, D> {
match self {
Self::Console(c, x) => BackendType::Console(c, f(x)),
Self::Link(c, x) => BackendType::Link(c, x),
Self::Local(l) => BackendType::Local(l),
}
}
}
@@ -133,7 +128,6 @@ impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
match self {
Self::Console(c, x) => x.map(|x| BackendType::Console(c, x)),
Self::Link(c, x) => Ok(BackendType::Link(c, x)),
Self::Local(l) => Ok(BackendType::Local(l)),
}
}
}
@@ -165,7 +159,6 @@ impl ComputeUserInfo {
pub enum ComputeCredentialKeys {
Password(Vec<u8>),
AuthKeys(AuthKeys),
None,
}
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
@@ -283,7 +276,9 @@ async fn auth_quirks(
ctx: &RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -298,7 +293,7 @@ async fn auth_quirks(
ctx.set_endpoint_id(res.info.endpoint.clone());
let password = match res.keys {
ComputeCredentialKeys::Password(p) => p,
ComputeCredentialKeys::AuthKeys(_) | ComputeCredentialKeys::None => {
ComputeCredentialKeys::AuthKeys(_) => {
unreachable!("password hack should return a password")
}
};
@@ -324,20 +319,21 @@ async fn auth_quirks(
};
let (cached_entry, secret) = cached_secret.take_value();
let secret = if let Some(secret) = secret {
config.check_rate_limit(
let secret = match secret {
Some(secret) => config.check_rate_limit(
ctx,
config,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
)?
} else {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
)?,
None => {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
}
};
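// Illustrative sketch of the mocking idea above: when the role has no stored secret,
// substitute a random, never-matching one so the rest of the authentication flow
// behaves identically and a caller cannot probe which roles exist. The `Secret` type
// is a simplified stand-in, not the real AuthSecret/SCRAM machinery.
#[derive(Debug)]
enum Secret {
    Scram([u8; 32]),
}
fn lookup_or_mock(stored: Option<Secret>, random_seed: [u8; 32]) -> Secret {
    match stored {
        Some(secret) => secret,
        // The mocked secret is derived from fresh randomness, so authentication
        // against it can only fail, and it fails the same way a wrong password
        // would for an existing role.
        None => Secret::Scram(random_seed),
    }
}
fn main() {
    let secret = lookup_or_mock(None, [0x5a; 32]);
    println!("{secret:?}");
}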
match authenticate_with_secret(
@@ -366,7 +362,9 @@ async fn authenticate_with_secret(
ctx: &RequestMonitoring,
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
@@ -408,8 +406,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
pub fn get_endpoint(&self) -> Option<EndpointId> {
match self {
Self::Console(_, user_info) => user_info.endpoint_id.clone(),
Self::Link(_, ()) => Some("link".into()),
Self::Local(_) => Some("local".into()),
Self::Link(_, _) => Some("link".into()),
}
}
@@ -417,8 +414,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
pub fn get_user(&self) -> &str {
match self {
Self::Console(_, user_info) => &user_info.user,
Self::Link(_, ()) => "link",
Self::Local(_) => "local",
Self::Link(_, _) => "link",
}
}
@@ -427,7 +423,9 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
pub async fn authenticate(
self,
ctx: &RequestMonitoring,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -453,16 +451,13 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
BackendType::Console(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Self::Link(url, ()) => {
Self::Link(url, _) => {
info!("performing link authentication");
let info = link::authenticate(ctx, &url, client).await?;
BackendType::Link(url, info)
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
}
};
info!("user successfully authenticated");
@@ -477,8 +472,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
) -> Result<CachedRoleSecret, GetAuthInfoError> {
match self {
Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::Link(_, ()) => Ok(Cached::new_uncached(None)),
Self::Local(_) => Ok(Cached::new_uncached(None)),
Self::Link(_, _) => Ok(Cached::new_uncached(None)),
}
}
@@ -488,8 +482,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
match self {
Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Self::Link(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
}
}
}
@@ -503,15 +496,13 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
match self {
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Link(_, info) => Ok(Cached::new_uncached(info.clone())),
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
Self::Console(_, creds) => &creds.keys,
Self::Link(_, _) => &ComputeCredentialKeys::None,
Self::Local(_) => &ComputeCredentialKeys::None,
Self::Console(_, creds) => Some(&creds.keys),
Self::Link(_, _) => None,
}
}
}
@@ -524,16 +515,14 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
match self {
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Link(_, ()) => unreachable!("link auth flow doesn't support waking the compute"),
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
Self::Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
}
}
fn get_keys(&self) -> &ComputeCredentialKeys {
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
match self {
Self::Console(_, creds) => &creds.keys,
Self::Link(_, ()) => &ComputeCredentialKeys::None,
Self::Local(_) => &ComputeCredentialKeys::None,
Self::Console(_, creds) => Some(&creds.keys),
Self::Link(_, _) => None,
}
}
}
@@ -561,7 +550,7 @@ mod tests {
CachedNodeInfo,
},
context::RequestMonitoring,
proxy::NeonOptions,
proxy::{tests::DummyClient, NeonOptions},
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
@@ -669,7 +658,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {
@@ -746,7 +735,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {
@@ -798,7 +787,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use super::{ComputeCredentials, ComputeUserInfo};
use crate::{
auth::{self, backend::ComputeCredentialKeys, AuthFlow},
@@ -5,6 +7,7 @@ use crate::{
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
proxy::handshake::KtlsAsyncReadReady,
sasl,
stream::{PqStream, Stream},
};
@@ -14,7 +17,9 @@ use tracing::{info, warn};
pub(super) async fn authenticate(
ctx: &RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
@@ -7,6 +9,7 @@ use crate::{
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
proxy::handshake::KtlsAsyncReadReady,
sasl,
stream::{self, Stream},
};
@@ -20,7 +23,9 @@ use tracing::{info, warn};
pub async fn authenticate_cleartext(
ctx: &RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
@@ -62,7 +67,9 @@ pub async fn authenticate_cleartext(
pub async fn password_hack_no_authentication(
ctx: &RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
) -> auth::Result<ComputeCredentials> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);

View File

@@ -1,21 +1,15 @@
use std::{
future::Future,
sync::Arc,
time::{Duration, SystemTime},
};
use std::{future::Future, sync::Arc, time::Duration};
use anyhow::{bail, ensure, Context};
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use serde::{Deserialize, Deserializer};
use signature::Verifier;
use tokio::time::Instant;
use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName};
use crate::{http::parse_json_body_with_limit, intern::EndpointIdInt};
// TODO(conrad): make these configurable.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
const MIN_RENEW: Duration = Duration::from_secs(30);
const AUTO_RENEW: Duration = Duration::from_secs(300);
const MAX_RENEW: Duration = Duration::from_secs(3600);
@@ -23,56 +17,30 @@ const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
/// How to get the JWT auth rules
pub trait FetchAuthRules: Clone + Send + Sync + 'static {
fn fetch_auth_rules(
&self,
role_name: RoleName,
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
fn fetch_auth_rules(&self) -> impl Future<Output = anyhow::Result<AuthRules>> + Send;
}
pub struct AuthRule {
pub id: String,
pub jwks_url: url::Url,
pub audience: Option<String>,
#[derive(Clone)]
struct FetchAuthRulesFromCplane {
#[allow(dead_code)]
endpoint: EndpointIdInt,
}
impl FetchAuthRules for FetchAuthRulesFromCplane {
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
Err(anyhow::anyhow!("not yet implemented"))
}
}
pub struct AuthRules {
jwks_urls: Vec<url::Url>,
}
#[derive(Default)]
pub struct JwkCache {
client: reqwest::Client,
map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
}
pub struct JwkCacheEntry {
/// Should refetch at least every hour to verify when old keys have been removed.
/// Should refetch only every 5 minutes or so when new key IDs are seen
last_retrieved: Instant,
/// cplane will return multiple JWKs urls that we need to scrape.
key_sets: ahash::HashMap<String, KeySet>,
}
impl JwkCacheEntry {
fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
self.key_sets.values().find_map(|key_set| {
key_set
.find_key(key_id)
.map(|jwk| (jwk, key_set.audience.as_deref()))
})
}
}
struct KeySet {
jwks: jose_jwk::JwkSet,
audience: Option<String>,
}
impl KeySet {
fn find_key(&self, key_id: &str) -> Option<&jose_jwk::Jwk> {
self.jwks
.keys
.iter()
.find(|jwk| jwk.prm.kid.as_deref() == Some(key_id))
}
map: DashMap<EndpointIdInt, Arc<JwkCacheEntryLock>>,
}
pub struct JwkCacheEntryLock {
@@ -89,6 +57,15 @@ impl Default for JwkCacheEntryLock {
}
}
pub struct JwkCacheEntry {
/// Should refetch at least every hour to verify when old keys have been removed.
/// Should refetch only every 5 minutes or so when new key IDs are seen
last_retrieved: Instant,
/// cplane will return multiple JWKs urls that we need to scrape.
key_sets: ahash::HashMap<url::Url, jose_jwk::JwkSet>,
}
impl JwkCacheEntryLock {
async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
JwkRenewalPermit::acquire_permit(self).await
@@ -102,7 +79,6 @@ impl JwkCacheEntryLock {
&self,
_permit: JwkRenewalPermit<'_>,
client: &reqwest::Client,
role_name: RoleName,
auth_rules: &F,
) -> anyhow::Result<Arc<JwkCacheEntry>> {
// double check that no one beat us to updating the cache.
@@ -115,19 +91,20 @@ impl JwkCacheEntryLock {
}
}
let rules = auth_rules.fetch_auth_rules(role_name).await?;
let mut key_sets =
ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
let rules = auth_rules.fetch_auth_rules().await?;
let mut key_sets = ahash::HashMap::with_capacity_and_hasher(
rules.jwks_urls.len(),
ahash::RandomState::new(),
);
// TODO(conrad): run concurrently
// TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
for rule in rules {
let req = client.get(rule.jwks_url.clone());
for url in rules.jwks_urls {
let req = client.get(url.clone());
// TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
match req.send().await.and_then(|r| r.error_for_status()) {
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
// I expect these failures would be quite sparse.
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
Err(e) => tracing::warn!(?url, error=?e, "could not fetch JWKs"),
Ok(r) => {
let resp: http::Response<reqwest::Body> = r.into();
match parse_json_body_with_limit::<jose_jwk::JwkSet>(
@@ -136,17 +113,9 @@ impl JwkCacheEntryLock {
)
.await
{
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
}
Err(e) => tracing::warn!(?url, error=?e, "could not decode JWKs"),
Ok(jwks) => {
key_sets.insert(
rule.id,
KeySet {
jwks,
audience: rule.audience,
},
);
key_sets.insert(url, jwks);
}
}
}
@@ -164,9 +133,7 @@ impl JwkCacheEntryLock {
async fn get_or_update_jwk_cache<F: FetchAuthRules>(
self: &Arc<Self>,
ctx: &RequestMonitoring,
client: &reqwest::Client,
role_name: RoleName,
fetch: &F,
) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
let now = Instant::now();
@@ -174,20 +141,18 @@ impl JwkCacheEntryLock {
// if we have no cached JWKs, try and get some
let Some(cached) = guard else {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
return self.renew_jwks(permit, client, role_name, fetch).await;
return self.renew_jwks(permit, client, fetch).await;
};
let last_update = now.duration_since(cached.last_retrieved);
// check if the cached JWKs need updating.
if last_update > MAX_RENEW {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
// it's been too long since we checked the keys. wait for them to update.
return self.renew_jwks(permit, client, role_name, fetch).await;
return self.renew_jwks(permit, client, fetch).await;
}
// every 5 minutes we should spawn a job to eagerly update the token.
@@ -199,7 +164,7 @@ impl JwkCacheEntryLock {
let client = client.clone();
let fetch = fetch.clone();
tokio::spawn(async move {
if let Err(e) = entry.renew_jwks(permit, &client, role_name, &fetch).await {
if let Err(e) = entry.renew_jwks(permit, &client, &fetch).await {
tracing::warn!(error=?e, "could not fetch JWKs in background job");
}
});
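// Stand-alone sketch of the refresh policy implemented above, roughly: past MAX_RENEW
// the caller waits for a refresh, past AUTO_RENEW (per the "every 5 minutes" comment)
// a background refresh is spawned, otherwise the cached keys are served as-is. The
// thresholds match the constants at the top of this file and are redeclared here only
// so the sketch compiles on its own.
use std::time::Duration;
const AUTO_RENEW: Duration = Duration::from_secs(300);
const MAX_RENEW: Duration = Duration::from_secs(3600);
#[derive(Debug, PartialEq)]
enum Renewal {
    UseCached,
    RefreshInBackground,
    RefreshAndWait,
}
fn renewal_action(since_last_fetch: Duration) -> Renewal {
    if since_last_fetch > MAX_RENEW {
        Renewal::RefreshAndWait
    } else if since_last_fetch > AUTO_RENEW {
        Renewal::RefreshInBackground
    } else {
        Renewal::UseCached
    }
}
fn main() {
    assert_eq!(renewal_action(Duration::from_secs(60)), Renewal::UseCached);
    assert_eq!(renewal_action(Duration::from_secs(600)), Renewal::RefreshInBackground);
    assert_eq!(renewal_action(Duration::from_secs(7200)), Renewal::RefreshAndWait);
}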
@@ -213,10 +178,8 @@ impl JwkCacheEntryLock {
async fn check_jwt<F: FetchAuthRules>(
self: &Arc<Self>,
ctx: &RequestMonitoring,
jwt: &str,
jwt: String,
client: &reqwest::Client,
role_name: RoleName,
fetch: &F,
) -> Result<(), anyhow::Error> {
// JWT compact form is defined to be
@@ -224,38 +187,38 @@ impl JwkCacheEntryLock {
// where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
let (header_payload, signature) = jwt
.rsplit_once('.')
.rsplit_once(".")
.context("Provided authentication token is not a valid JWT encoding")?;
let (header, payload) = header_payload
.split_once('.')
let (header, _payload) = header_payload
.split_once(".")
.context("Provided authentication token is not a valid JWT encoding")?;
let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
let header = serde_json::from_slice::<JwtHeader<'_>>(&header)
let header = serde_json::from_slice::<JWTHeader<'_>>(&header)
.context("Provided authentication token is not a valid JWT encoding")?;
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
ensure!(header.typ == "JWT");
let kid = header.key_id.context("missing key id")?;
let kid = header.kid.context("missing key id")?;
let mut guard = self
.get_or_update_jwk_cache(ctx, client, role_name.clone(), fetch)
.await?;
let mut guard = self.get_or_update_jwk_cache(client, fetch).await?;
// get the key from the JWKs if possible. If not, wait for the keys to update.
let (jwk, expected_audience) = loop {
match guard.find_jwk_and_audience(kid) {
let jwk = loop {
let jwk = guard
.key_sets
.values()
.flat_map(|jwks| &jwks.keys)
.find(|jwk| jwk.prm.kid.as_deref() == Some(kid));
match jwk {
Some(jwk) => break jwk,
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
guard = self
.renew_jwks(permit, client, role_name.clone(), fetch)
.await?;
guard = self.renew_jwks(permit, client, fetch).await?;
}
_ => {
bail!("jwk not found");
@@ -264,7 +227,7 @@ impl JwkCacheEntryLock {
};
ensure!(
jwk.is_supported(&header.algorithm),
jwk.is_supported(&header.alg),
"signature algorithm not supported"
);
@@ -278,57 +241,31 @@ impl JwkCacheEntryLock {
key => bail!("unsupported key type {key:?}"),
};
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
.context("Provided authentication token is not a valid JWT encoding")?;
tracing::debug!(?payload, "JWT signature valid with claims");
match (expected_audience, payload.audience) {
// check the audience matches
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
// the audience is expected but is missing
(Some(_), None) => bail!("invalid JWT token audience"),
// we don't care for the audience field
(None, _) => {}
}
let now = SystemTime::now();
if let Some(exp) = payload.expiration {
ensure!(now < exp + CLOCK_SKEW_LEEWAY);
}
if let Some(nbf) = payload.not_before {
ensure!(nbf < now + CLOCK_SKEW_LEEWAY);
}
// TODO(conrad): verify iss, exp, nbf, etc...
Ok(())
}
}
impl JwkCache {
pub async fn check_jwt<F: FetchAuthRules>(
pub async fn check_jwt(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
role_name: RoleName,
fetch: &F,
jwt: &str,
endpoint: EndpointIdInt,
jwt: String,
) -> Result<(), anyhow::Error> {
// try with just a read lock first
let key = (endpoint, role_name.clone());
let entry = self.map.get(&key).as_deref().map(Arc::clone);
let entry = entry.unwrap_or_else(|| {
// acquire a write lock afterwards to insert.
let entry = self.map.entry(key).or_default();
Arc::clone(&*entry)
});
let entry = self.map.get(&endpoint).as_deref().map(Arc::clone);
let entry = match entry {
Some(entry) => entry,
None => {
// acquire a write lock afterwards to insert.
let entry = self.map.entry(endpoint).or_default();
Arc::clone(&*entry)
}
};
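// Stand-alone sketch of the read-then-write lookup above, expressed with std types
// instead of DashMap: take the cheap shared-read path first and only fall back to the
// write lock when the entry has to be created.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
#[derive(Default)]
struct CacheEntry;
fn get_or_insert(map: &RwLock<HashMap<String, Arc<CacheEntry>>>, key: &str) -> Arc<CacheEntry> {
    if let Some(entry) = map.read().unwrap().get(key) {
        return Arc::clone(entry);
    }
    // Entry is missing: upgrade to a write lock and insert a default one.
    let mut guard = map.write().unwrap();
    Arc::clone(guard.entry(key.to_owned()).or_default())
}
fn main() {
    let map = RwLock::new(HashMap::new());
    let a = get_or_insert(&map, "endpoint-1");
    let b = get_or_insert(&map, "endpoint-1");
    assert!(Arc::ptr_eq(&a, &b));
}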
entry
.check_jwt(ctx, jwt, &self.client, role_name, fetch)
.await
let fetch = FetchAuthRulesFromCplane { endpoint };
entry.check_jwt(jwt, &self.client, &fetch).await
}
}
@@ -378,49 +315,13 @@ fn verify_rsa_signature(
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
struct JWTHeader<'a> {
/// must be "JWT"
#[serde(rename = "typ")]
typ: &'a str,
/// must be a supported alg
#[serde(rename = "alg")]
algorithm: jose_jwa::Algorithm,
alg: jose_jwa::Algorithm,
/// key id, must be provided for our usecase
#[serde(rename = "kid")]
key_id: Option<&'a str>,
}
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
#[derive(serde::Deserialize, serde::Serialize, Debug)]
struct JwtPayload<'a> {
/// Audience - Recipient for which the JWT is intended
#[serde(rename = "aud")]
audience: Option<&'a str>,
/// Expiration - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
expiration: Option<SystemTime>,
/// Not before - Time before which the JWT must not be accepted
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
not_before: Option<SystemTime>,
// the following entries are only extracted for the sake of debug logging.
/// Issuer of the JWT
#[serde(rename = "iss")]
issuer: Option<&'a str>,
/// Subject of the JWT (the user)
#[serde(rename = "sub")]
subject: Option<&'a str>,
/// Unique token identifier
#[serde(rename = "jti")]
jwt_id: Option<&'a str>,
/// Unique session identifier
#[serde(rename = "sid")]
session_id: Option<&'a str>,
}
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
let d = <Option<u64>>::deserialize(d)?;
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
kid: Option<&'a str>,
}
struct JwkRenewalPermit<'a> {
@@ -487,8 +388,6 @@ impl Drop for JwkRenewalPermit<'_> {
#[cfg(test)]
mod tests {
use crate::RoleName;
use super::*;
use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};
@@ -532,10 +431,10 @@ mod tests {
}
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
let header = JwtHeader {
let header = JWTHeader {
typ: "JWT",
algorithm: jose_jwa::Algorithm::Signing(sig),
key_id: Some(&kid),
alg: jose_jwa::Algorithm::Signing(sig),
kid: Some(&kid),
};
let body = typed_json::json! {{
"exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
@@ -625,40 +524,33 @@ mod tests {
struct Fetch(SocketAddr);
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(
&self,
_role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
Ok(vec![
AuthRule {
id: "foo".to_owned(),
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
audience: None,
},
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
audience: None,
},
])
async fn fetch_auth_rules(&self) -> anyhow::Result<AuthRules> {
Ok(AuthRules {
jwks_urls: vec![
format!("http://{}/foo", self.0).parse().unwrap(),
format!("http://{}/bar", self.0).parse().unwrap(),
],
})
}
}
let role_name = RoleName::from("user");
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
for token in [jwt1, jwt2, jwt3, jwt4] {
jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&token,
&client,
role_name.clone(),
&Fetch(addr),
)
.await
.unwrap();
}
jwk_cache
.check_jwt(jwt1, &client, &Fetch(addr))
.await
.unwrap();
jwk_cache
.check_jwt(jwt2, &client, &Fetch(addr))
.await
.unwrap();
jwk_cache
.check_jwt(jwt3, &client, &Fetch(addr))
.await
.unwrap();
jwk_cache
.check_jwt(jwt4, &client, &Fetch(addr))
.await
.unwrap();
}
}

View File

@@ -1,79 +0,0 @@
use std::{collections::HashMap, net::SocketAddr};
use anyhow::Context;
use arc_swap::ArcSwapOption;
use crate::{
compute::ConnCfg,
console::{
messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo},
NodeInfo,
},
intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag},
RoleName,
};
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
pub struct LocalBackend {
pub jwks_cache: JwkCache,
pub postgres_addr: SocketAddr,
pub node_info: NodeInfo,
}
impl LocalBackend {
pub fn new(postgres_addr: SocketAddr) -> Self {
LocalBackend {
jwks_cache: JwkCache::default(),
postgres_addr,
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();
cfg.host(&postgres_addr.ip().to_string());
cfg.port(postgres_addr.port());
cfg
},
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
project_id: ProjectIdTag::get_interner().get_or_intern("local"),
branch_id: BranchIdTag::get_interner().get_or_intern("local"),
cold_start_info: ColdStartInfo::WarmCached,
},
allow_self_signed_compute: false,
},
}
}
}
#[derive(Clone, Copy)]
pub struct StaticAuthRules;
pub static JWKS_ROLE_MAP: ArcSwapOption<JwksRoleSettings> = ArcSwapOption::const_empty();
#[derive(Debug, Clone)]
pub struct JwksRoleSettings {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
}
impl FetchAuthRules for StaticAuthRules {
async fn fetch_auth_rules(&self, role_name: RoleName) -> anyhow::Result<Vec<AuthRule>> {
let mappings = JWKS_ROLE_MAP.load();
let role_mappings = mappings
.as_deref()
.and_then(|m| m.roles.get(&role_name))
.context("JWKs settings for this role were not configured")?;
let mut rules = vec![];
for setting in &role_mappings.jwks {
rules.push(AuthRule {
id: setting.id.clone(),
jwks_url: setting.jwks_url.clone(),
audience: setting.jwt_audience.clone(),
});
}
Ok(rules)
}
}

View File

@@ -86,8 +86,7 @@ impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
common_names: Option<&HashSet<String>>,
endpoint_from_domain: Option<EndpointId>,
) -> Result<Self, ComputeUserInfoParseError> {
// Some parameters are stored in the startup message.
let get_param = |key| {
@@ -111,16 +110,7 @@ impl ComputeUserInfoMaybeEndpoint {
})
.map(|name| name.into());
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
endpoint_sni(sni_str, cn)?
} else {
None
}
} else {
None
};
let is_sni = endpoint_from_domain.is_some();
let endpoint = match (endpoint_option, endpoint_from_domain) {
// Invariant: if we have both project name variants, they should match.
(Some(option), Some(domain)) if option != domain => {
@@ -130,12 +120,9 @@ impl ComputeUserInfoMaybeEndpoint {
}))
}
// Invariant: project name may not contain certain characters.
(a, b) => a.or(b).map(|name| {
if project_name_valid(name.as_ref()) {
Ok(name)
} else {
Err(ComputeUserInfoParseError::MalformedProjectName(name))
}
(a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
false => Err(ComputeUserInfoParseError::MalformedProjectName(name)),
true => Ok(name),
}),
}
.transpose()?;
@@ -146,7 +133,7 @@ impl ComputeUserInfoMaybeEndpoint {
let metrics = Metrics::get();
info!(%user, "credentials");
if sni.is_some() {
if is_sni {
info!("Connection with sni");
metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
} else if endpoint.is_some() {
@@ -258,7 +245,7 @@ mod tests {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -273,7 +260,7 @@ mod tests {
("foo", "bar"), // should be ignored
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -284,12 +271,8 @@ mod tests {
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("foo.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("foo".into()))?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
@@ -305,7 +288,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -320,7 +303,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -338,7 +321,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -353,7 +336,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -364,49 +347,21 @@ mod tests {
fn parse_projects_identical() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
let sni = Some("baz.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("baz".into()))?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
Ok(())
}
#[test]
fn parse_multi_common_names() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
Ok(())
}
#[test]
fn parse_projects_different() {
let options =
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
let sni = Some("second.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("second".into()))
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
@@ -417,24 +372,6 @@ mod tests {
}
}
#[test]
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
}
_ => panic!("bad error: {err:?}"),
}
}
#[test]
fn parse_neon_options() -> anyhow::Result<()> {
let options = StartupMessageParams::new([
@@ -442,11 +379,9 @@ mod tests {
("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
]);
let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("project".into()))?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),

View File

@@ -6,13 +6,14 @@ use crate::{
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
proxy::handshake::KtlsAsyncReadReady,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::{io, sync::Arc};
use std::{io, os::fd::AsRawFd, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -70,7 +71,7 @@ impl AuthMethod for CleartextPassword {
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, S, State> {
pub struct AuthFlow<'a, S: AsRawFd, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream<S>>,
/// State might contain ancillary data (see [`Self::begin`]).
@@ -79,7 +80,7 @@ pub struct AuthFlow<'a, S, State> {
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
impl<'a, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'a, S, Begin> {
/// Create a new wrapper for client authentication.
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
@@ -105,7 +106,9 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>
AuthFlow<'_, S, PasswordHack>
{
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
let msg = self.stream.read_password_message().await?;
@@ -124,7 +127,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>
AuthFlow<'_, S, CleartextPassword>
{
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
let msg = self.stream.read_password_message().await?;
@@ -149,7 +154,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;

View File

@@ -1,316 +0,0 @@
use std::{
net::SocketAddr,
path::{Path, PathBuf},
pin::pin,
sync::Arc,
time::Duration,
};
use anyhow::{bail, ensure};
use dashmap::DashMap;
use futures::{future::Either, FutureExt};
use proxy::{
auth::backend::local::{JwksRoleSettings, LocalBackend, JWKS_ROLE_MAP},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
console::{locks::ApiLocks, messages::JwksRoleMapping},
http::health_server::AppMetrics,
metrics::{Metrics, ThreadPoolMetrics},
rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo},
scram::threadpool::ThreadPool,
serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions},
};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use clap::Parser;
use tokio::{net::TcpListener, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
/// Neon proxy/router
#[derive(Parser)]
#[command(version = GIT_VERSION, about)]
struct LocalProxyCliArgs {
/// listen for incoming metrics connections on ip:port
#[clap(long, default_value = "127.0.0.1:7001")]
metrics: String,
/// listen for incoming http connections on ip:port
#[clap(long)]
http: String,
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// lock for `connect_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
#[clap(long, default_value = config::ConcurrencyLockOptions::DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK)]
connect_compute_lock: String,
#[clap(flatten)]
sql_over_http: SqlOverHttpArgs,
/// User rate limiter max number of requests per second.
///
/// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
user_rps_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
/// Address of the postgres server
#[clap(long, default_value = "127.0.0.1:5432")]
compute: SocketAddr,
/// File address of the local proxy config file
#[clap(long, default_value = "./localproxy.json")]
config_path: PathBuf,
}
#[derive(clap::Args, Clone, Copy, Debug)]
struct SqlOverHttpArgs {
/// How many connections to pool for each endpoint. Excess connections are discarded
#[clap(long, default_value_t = 200)]
sql_over_http_pool_max_total_conns: usize,
/// How long pooled connections should remain idle for before closing
#[clap(long, default_value = "5m", value_parser = humantime::parse_duration)]
sql_over_http_idle_timeout: tokio::time::Duration,
#[clap(long, default_value_t = 100)]
sql_over_http_client_conn_threshold: u64,
#[clap(long, default_value_t = 16)]
sql_over_http_cancel_set_shards: usize,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let _logging_guard = proxy::logging::init().await?;
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
info!("Version: {GIT_VERSION}");
info!("Build_tag: {BUILD_TAG}");
let neon_metrics = ::metrics::NeonMetrics::new(::metrics::BuildInfo {
revision: GIT_VERSION,
build_tag: BUILD_TAG,
});
let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
Ok(t) => Some(t),
Err(e) => {
tracing::error!(error = ?e, "could not start jemalloc metrics loop");
None
}
};
let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
let http_listener = TcpListener::bind(args.http).await?;
let shutdown = CancellationToken::new();
// todo: should scale with CU
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
LeakyBucketConfig {
rps: 10.0,
max: 100.0,
},
16,
));
refresh_config(args.config_path.clone()).await;
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), move || {
refresh_config(args.config_path.clone()).map(Ok)
}));
maintenance_tasks.spawn(proxy::http::health_server::task_main(
metrics_listener,
AppMetrics {
jemalloc,
neon_metrics,
proxy: proxy::metrics::Metrics::get(),
},
));
let task = serverless::task_main(
config,
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
Arc::new(DashMap::new()),
None,
proxy::metrics::CancellationSource::Local,
)),
endpoint_rate_limiter,
);
match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
// exit immediately on maintenance task completion
Either::Left((Some(res), _)) => match proxy::flatten_err(res)? {},
// exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
// exit immediately on client task error
Either::Right((res, _)) => res?,
}
Ok(())
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.connect_compute_lock.parse()?;
info!(
?limiter,
shards,
?epoch,
"Using NodeLocks (connect_compute)"
);
let connect_compute_locks = ApiLocks::new(
"connect_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,
)?;
let http_config = HttpConfig {
accept_websockets: false,
pool_options: GlobalConnPoolOptions {
gc_epoch: Duration::from_secs(60),
pool_shards: 2,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: false,
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_total_conns,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
};
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
auth_backend: proxy::auth::BackendType::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
)),
metric_collection: None,
allow_self_signed_compute: false,
http_config,
authentication_config: AuthenticationConfig {
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
},
require_client_ip: false,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute_retry_config: RetryConfig::parse(
RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES,
)?,
})))
}
async fn refresh_config(path: PathBuf) {
match refresh_config_inner(&path).await {
Ok(()) => {}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
}
async fn refresh_config_inner(path: &Path) -> anyhow::Result<()> {
let bytes = tokio::fs::read(&path).await?;
let mut data: JwksRoleMapping = serde_json::from_slice(&bytes)?;
let mut settings = None;
for mapping in data.roles.values_mut() {
for jwks in &mut mapping.jwks {
ensure!(
jwks.jwks_url.has_authority()
&& (jwks.jwks_url.scheme() == "http" || jwks.jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks.jwks_url
.host()
.is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks.jwks_url.set_username("").expect(
"url can be a base and has a valid host and is not a file. should not error",
);
jwks.jwks_url.set_password(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks.jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks.jwks_url.set_fragment(None);
jwks.jwks_url.query_pairs_mut().clear().finish();
if jwks.jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks.jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks.jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
let (pr, br) = settings.get_or_insert((jwks.project_id, jwks.branch_id));
ensure!(
*pr == jwks.project_id,
"inconsistent project IDs configured"
);
ensure!(*br == jwks.branch_id, "inconsistent branch IDs configured");
}
}
if let Some((project_id, branch_id)) = settings {
JWKS_ROLE_MAP.store(Some(Arc::new(JwksRoleSettings {
roles: data.roles,
project_id,
branch_id,
})));
}
Ok(())
}

View File

@@ -1,3 +1,4 @@
use std::os::fd::AsRawFd;
/// A stand-alone program that routes connections, e.g. from
/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
///
@@ -7,9 +8,9 @@ use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::proxy::handshake::KtlsAsyncReadReady;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;
@@ -20,6 +21,7 @@ use futures::TryFutureExt;
use proxy::stream::{PqStream, Stream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
use utils::{project_git_version, sentry_init::init_sentry};
@@ -72,7 +74,7 @@ async fn main() -> anyhow::Result<()> {
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
// Configure TLS
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
let tls_config = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -102,19 +104,14 @@ async fn main() -> anyhow::Result<()> {
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
(tls_config, tls_server_end_point)
Arc::new(
rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?,
)
}
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -129,13 +126,10 @@ async fn main() -> anyhow::Result<()> {
let main = tokio::spawn(task_main(
Arc::new(destination),
tls_config,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || async {
Ok(())
}));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
// the signal task can't ever succeed.
// the main task can error, or can succeed on cancellation.
@@ -153,7 +147,6 @@ async fn main() -> anyhow::Result<()> {
async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -185,7 +178,7 @@ async fn task_main(
proxy::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
handle_client(ctx, dest_suffix, tls_config, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -206,12 +199,11 @@ async fn task_main(
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>(
ctx: &RequestMonitoring,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
) -> anyhow::Result<Box<TlsStream<S>>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
let msg = stream.read_startup_packet().await?;
@@ -237,13 +229,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
Ok(Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
))
}
unexpected => {
info!(
@@ -261,15 +250,18 @@ async fn handle_client(
ctx: RequestMonitoring,
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
stream: impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
let sni = tls_stream
.get_ref()
.1
.server_name()
.ok_or(anyhow!("SNI missing"))?;
let dest: Vec<&str> = sni
.split_once('.')
.context("invalid SNI")?

View File

@@ -148,7 +148,7 @@ struct ProxyCliArgs {
disable_dynamic_rate_limiter: bool,
/// Endpoint rate limiter max number of requests per second.
///
/// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
/// Provided in the form '<Requests Per Second>@<Bucket Duration Size>'.
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
@@ -285,7 +285,7 @@ async fn main() -> anyhow::Result<()> {
};
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let config = build_config(&args).await?;
info!("Authentication backend: {}", config.auth_backend);
info!("Using region: {}", args.aws_region);
@@ -447,10 +447,7 @@ async fn main() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(proxy::handle_signals(
cancellation_token.clone(),
|| async { Ok(()) },
));
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone()));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -532,16 +529,14 @@ async fn main() -> anyhow::Result<()> {
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
async fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
cert_path,
args.certs_dir.as_ref(),
)?),
(Some(key_path), Some(cert_path)) => {
Some(config::configure_tls(key_path, cert_path, args.certs_dir.as_ref()).await?)
}
(None, None) => None,
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
};

View File

@@ -274,13 +274,13 @@ impl ProjectInfoCacheImpl {
let ttl_disabled_since_us = self
.ttl_disabled_since_us
.load(std::sync::atomic::Ordering::Relaxed);
let ignore_cache_since = if ttl_disabled_since_us == u64::MAX {
None
} else {
let ignore_cache_since = if ttl_disabled_since_us != u64::MAX {
let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
// We are fine if entry is not older than ttl or was added before we are getting notifications.
valid_since = valid_since.min(ignore_cache_since);
Some(ignore_cache_since)
} else {
None
};
(valid_since, ignore_cache_since)
}
@@ -306,7 +306,7 @@ impl ProjectInfoCacheImpl {
let mut removed = 0;
let shard = self.project2ep.shards()[shard].write();
for (_, endpoints) in shard.iter() {
for endpoint in endpoints.get() {
for endpoint in endpoints.get().iter() {
self.cache.remove(endpoint);
removed += 1;
}

View File

@@ -220,8 +220,7 @@ mod tests {
#[tokio::test]
async fn cancel_session_noop_regression() {
let handler =
CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local);
let handler = CancellationHandler::<()>::new(Default::default(), CancellationSource::Local);
handler
.cancel_session(
CancelKeyData {

View File

@@ -286,7 +286,7 @@ impl ConnCfg {
let client_config = if allow_self_signed_compute {
// Allow all certificates for creating the connection
let verifier = Arc::new(AcceptEverythingVerifier);
let verifier = Arc::new(AcceptEverythingVerifier) as Arc<dyn ServerCertVerifier>;
rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(verifier)

View File

@@ -10,7 +10,7 @@ use anyhow::{bail, ensure, Context, Ok};
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::{
crypto::ring::sign,
crypto::aws_lc_rs::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use sha2::{Digest, Sha256};
@@ -76,7 +76,7 @@ impl TlsConfig {
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Configure TLS for the main endpoint.
pub fn configure_tls(
pub async fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
@@ -110,13 +110,20 @@ pub fn configure_tls(
let cert_resolver = Arc::new(cert_resolver);
let provider = rustls::crypto::aws_lc_rs::default_provider();
#[cfg(target_os = "linux")]
let provider = {
let mut provider = provider;
let compat = ktls::CompatibleCiphers::new().await?;
provider.cipher_suites.retain(|s| compat.is_compatible(*s));
provider
};
// allow TLS 1.2 to be compatible with older client libraries
let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
let mut config = rustls::ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
@@ -318,7 +325,7 @@ impl CertResolver {
// a) Instead of multi-cert approach use single cert with extra
// domains listed in Subject Alternative Name (SAN).
// b) Deploy separate proxy instances for extra domains.
self.default.clone()
self.default.as_ref().cloned()
}
}
}

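A minimal illustrative sketch (not part of this change set) of the cipher-suite filtering introduced in configure_tls above, assuming the ktls::CompatibleCiphers and rustls CryptoProvider APIs used in the diff; the helper name is hypothetical. Dropping suites the kernel cannot offload up front means the later kTLS upgrade cannot fail on an unsupported suite.

use std::sync::Arc;
use rustls::crypto::CryptoProvider;

// Hypothetical helper, Linux-only: restrict a provider to kTLS-offloadable suites.
#[cfg(target_os = "linux")]
async fn ktls_compatible_provider(mut provider: CryptoProvider) -> anyhow::Result<Arc<CryptoProvider>> {
    // Probe the kernel once and drop every cipher suite it cannot offload.
    let compat = ktls::CompatibleCiphers::new().await?;
    provider.cipher_suites.retain(|s| compat.is_compatible(*s));
    Ok(Arc::new(provider))
}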
View File

@@ -1,13 +1,11 @@
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{self, Display};
use crate::auth::IpPattern;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::proxy::retry::CouldRetry;
use crate::RoleName;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
@@ -343,26 +341,6 @@ impl ColdStartInfo {
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct JwksRoleMapping {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct EndpointJwksResponse {
pub jwks: Vec<JwksSettings>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct JwksSettings {
pub id: String,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
pub jwks_url: url::Url,
pub provider_name: String,
pub jwt_audience: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -305,7 +305,6 @@ impl NodeInfo {
match keys {
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::None => &mut self.config,
};
}
}

View File

@@ -64,7 +64,7 @@ impl Api {
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
tokio::spawn(connection);
let secret = if let Some(entry) = get_execute_postgres_query(
let secret = match get_execute_postgres_query(
&client,
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
&[&&*user_info.user],
@@ -72,12 +72,15 @@ impl Api {
)
.await?
{
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{}' does not exist", user_info.user);
None
Some(entry) => {
info!("got a secret: {entry}"); // safe since it's not a prod scenario
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
}
None => {
warn!("user '{}' does not exist", user_info.user);
None
}
};
let allowed_ips = match get_execute_postgres_query(
&client,
@@ -139,11 +142,12 @@ async fn get_execute_postgres_query(
let rows = client.query(query, params).await?;
// We can get at most one row, because `rolname` is unique.
let Some(row) = rows.first() else {
let row = match rows.first() {
Some(row) => row,
// This means that the user doesn't exist, so there can be no secret.
// However, this is still a *valid* outcome which is very similar
// to getting `404 Not found` from the Neon console.
return Ok(None);
None => return Ok(None),
};
let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;

View File

@@ -38,9 +38,9 @@ impl Api {
locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
) -> Self {
let jwt = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
Ok(v) => v,
Err(_) => String::new(),
Err(_) => "".to_string(),
};
Self {
endpoint,
@@ -96,10 +96,10 @@ impl Api {
// Error 404 is special: it's ok not to have a secret.
// TODO(anna): retry
Err(e) => {
return if e.get_reason().is_not_found() {
Ok(AuthInfo::default())
if e.get_reason().is_not_found() {
return Ok(AuthInfo::default());
} else {
Err(e.into())
return Err(e.into());
}
}
};

View File

@@ -1,83 +1,6 @@
// rustc lints/lint groups
// https://doc.rust-lang.org/rustc/lints/groups.html
#![deny(
deprecated,
future_incompatible,
// TODO: consider let_underscore
nonstandard_style,
rust_2024_compatibility
)]
#![warn(clippy::all, clippy::pedantic, clippy::cargo)]
// List of denied lints from the clippy::restriction group.
// https://rust-lang.github.io/rust-clippy/master/index.html#?groups=restriction
#![warn(
clippy::undocumented_unsafe_blocks,
// TODO: Enable once all individual checks are enabled.
//clippy::as_conversions,
clippy::dbg_macro,
clippy::empty_enum_variants_with_brackets,
clippy::exit,
clippy::float_cmp_const,
clippy::lossy_float_literal,
clippy::macro_use_imports,
clippy::manual_ok_or,
// TODO: consider clippy::map_err_ignore
// TODO: consider clippy::mem_forget
clippy::rc_mutex,
clippy::rest_pat_in_fully_bound_structs,
clippy::string_add,
clippy::string_to_string,
clippy::todo,
// TODO: consider clippy::unimplemented
// TODO: consider clippy::unwrap_used
)]
// List of permanently allowed lints.
#![allow(
// It's ok to cast bool to u8, etc.
clippy::cast_lossless,
// Seems unavoidable.
clippy::multiple_crate_versions,
// While #[must_use] is a great feature this check is too noisy.
clippy::must_use_candidate,
// Inline consts, structs, fns, imports, etc. are ok if they're used by
// the following statement(s).
clippy::items_after_statements,
)]
// List of temporarily allowed lints.
// TODO: Switch to expect() once stable with 1.81.
// TODO: fix code and reduce list or move to permanent list above.
#![allow(
clippy::cargo_common_metadata,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_precision_loss,
clippy::cast_sign_loss,
clippy::doc_markdown,
clippy::implicit_hasher,
clippy::inline_always,
clippy::match_same_arms,
clippy::match_wild_err_arm,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::needless_pass_by_value,
clippy::needless_raw_string_hashes,
clippy::redundant_closure_for_method_calls,
clippy::return_self_not_must_use,
clippy::similar_names,
clippy::single_match_else,
clippy::struct_excessive_bools,
clippy::struct_field_names,
clippy::too_many_lines,
clippy::unreadable_literal,
clippy::unused_async,
clippy::unused_self,
clippy::wildcard_imports
)]
// List of temporarily allowed lints to unblock beta/nightly.
#![allow(unknown_lints, clippy::manual_inspect)]
#![deny(clippy::undocumented_unsafe_blocks)]
use std::{convert::Infallible, future::Future};
use std::convert::Infallible;
use anyhow::{bail, Context};
use intern::{EndpointIdInt, EndpointIdTag, InternId};
@@ -112,14 +35,7 @@ pub mod usage_metrics;
pub mod waiters;
/// Handle unix signals appropriately.
pub async fn handle_signals<F, Fut>(
token: CancellationToken,
mut refresh_config: F,
) -> anyhow::Result<Infallible>
where
F: FnMut() -> Fut,
Fut: Future<Output = anyhow::Result<()>>,
{
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
use tokio::signal::unix::{signal, SignalKind};
let mut hangup = signal(SignalKind::hangup())?;
@@ -130,8 +46,7 @@ where
tokio::select! {
// Hangup is commonly used for config reload.
_ = hangup.recv() => {
warn!("received SIGHUP");
refresh_config().await?;
warn!("received SIGHUP; config reload is not supported");
}
// Shut down the whole application.
_ = interrupt.recv() => {

View File

@@ -3,6 +3,7 @@
use std::{
io,
net::SocketAddr,
os::fd::AsRawFd,
pin::Pin,
task::{Context, Poll},
};
@@ -20,6 +21,23 @@ pin_project! {
}
}
impl<S: AsRawFd> AsRawFd for ChainRW<S> {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
self.inner.as_raw_fd()
}
}
#[cfg(all(target_os = "linux", not(test)))]
impl<S: ktls::AsyncReadReady> ktls::AsyncReadReady for ChainRW<S> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if self.buf.is_empty() {
self.inner.poll_read_ready(cx)
} else {
Poll::Ready(Ok(()))
}
}
}
impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
#[inline]
fn poll_write(

View File

@@ -1,5 +1,5 @@
#[cfg(test)]
mod tests;
pub mod tests;
pub mod connect_compute;
mod copy_bidirectional;
@@ -9,6 +9,7 @@ pub mod retry;
pub mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use handshake::KtlsAsyncReadReady;
use crate::{
auth,
@@ -21,7 +22,7 @@ use crate::{
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
stream::PqStream,
EndpointCacheKey,
};
use futures::TryFutureExt;
@@ -30,6 +31,7 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::os::fd::AsRawFd;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -191,13 +193,6 @@ impl ClientMode {
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
@@ -238,7 +233,7 @@ impl ReportableError for ClientRequestError {
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -254,16 +249,16 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let _request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) =
let (mut stream, ep, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(stream, ep, params) => (stream, ep, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
@@ -275,15 +270,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, ep))
.transpose();
let user_info = match result {
@@ -340,7 +331,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
client: stream,
aux: node.aux.clone(),
compute: node,
req: request_gauge,
req: _request_gauge,
conn: conn_gauge,
cancel: session,
}))

View File

@@ -30,10 +30,9 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
if is_cached {
warn!("invalidating stalled compute node info cache entry");
}
let label = if is_cached {
ConnectionFailureKind::ComputeCached
} else {
ConnectionFailureKind::ComputeUncached
let label = match is_cached {
true => ConnectionFailureKind::ComputeCached,
false => ConnectionFailureKind::ComputeUncached,
};
Metrics::get().proxy.connection_failures_total.inc(label);
@@ -62,7 +61,7 @@ pub trait ComputeConnectBackend {
ctx: &RequestMonitoring,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_keys(&self) -> &ComputeCredentialKeys;
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
}
pub struct TcpMechanism<'a> {
@@ -113,8 +112,9 @@ where
let mut num_retries = 0;
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.set_keys(user_info.get_keys());
if let Some(keys) = user_info.get_keys() {
node_info.set_keys(keys);
}
node_info.allow_self_signed_compute = allow_self_signed_compute;
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
mechanism.update_connect_config(&mut node_info.config);

View File

@@ -230,10 +230,11 @@ impl CopyBuffer {
io::ErrorKind::WriteZero,
"write zero byte into writer",
))));
} else {
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
// If pos is larger than cap, this loop will never stop.

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use bytes::Buf;
use pq_proto::{
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
@@ -15,6 +17,7 @@ use crate::{
metrics::Metrics,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
EndpointId,
};
#[derive(Error, Debug)]
@@ -31,6 +34,10 @@ pub enum HandshakeError {
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
#[cfg(all(target_os = "linux", not(test)))]
#[error("{0}")]
KtlsUpgradeError(#[from] ktls::Error),
#[error("{0}")]
Io(#[from] std::io::Error),
@@ -43,6 +50,8 @@ impl ReportableError for HandshakeError {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
#[cfg(all(target_os = "linux", not(test)))]
HandshakeError::KtlsUpgradeError(_) => crate::error::ErrorKind::Service,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
@@ -57,22 +66,39 @@ impl ReportableError for HandshakeError {
}
}
pub enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
pub enum HandshakeData<S: AsRawFd> {
Startup(
PqStream<Stream<S>>,
Option<EndpointId>,
StartupMessageParams,
),
Cancel(CancelKeyData),
}
#[cfg(any(not(target_os = "linux"), test))]
pub trait KtlsAsyncReadReady {}
#[cfg(all(target_os = "linux", not(test)))]
pub trait KtlsAsyncReadReady: ktls::AsyncReadReady {}
#[cfg(any(not(target_os = "linux"), test))]
impl<K: AsyncRead> KtlsAsyncReadReady for K {}
#[cfg(all(target_os = "linux", not(test)))]
impl<K: ktls::AsyncReadReady> KtlsAsyncReadReady for K {}
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
pub async fn handshake<S>(
ctx: &RequestMonitoring,
stream: S,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
) -> Result<HandshakeData<S>, HandshakeError>
where
S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady,
{
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -80,6 +106,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
let mut ep = None;
loop {
let msg = stream.read_startup_packet().await?;
match msg {
@@ -113,6 +140,9 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
));
};
#[cfg(all(target_os = "linux", not(test)))]
let raw = ktls::CorkStream::new(raw);
let mut read_buf = read_buf.reader();
let mut res = Ok(());
let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config())
@@ -145,11 +175,11 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let conn_info = tls_stream.get_ref().1;
// try parse endpoint
let ep = conn_info
ep = conn_info
.server_name()
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
if let Some(ep) = ep {
ctx.set_endpoint_id(ep);
if let Some(ep) = &ep {
ctx.set_endpoint_id(ep.clone());
}
// check the ALPN, if it exists, as required.
@@ -170,7 +200,10 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
#[cfg(any(not(target_os = "linux"), test))]
tls: Box::pin(tls_stream),
#[cfg(all(target_os = "linux", not(test)))]
tls: ktls::config_ktls_server(tls_stream).await?,
tls_server_end_point,
},
read_buf,
@@ -207,7 +240,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, ep, params));
}
// downgrade protocol version
FeStartupPacket::StartupMessage { params, version }
@@ -238,7 +271,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake; unsupported minor version requested"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, ep, params));
}
FeStartupPacket::StartupMessage { version, .. } => {
warn!(

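For reference, the KtlsAsyncReadReady bound added above follows a platform-gated alias-trait pattern. A minimal sketch of that pattern with hypothetical names, assuming only the ktls::AsyncReadReady trait used in the diff:

// On Linux release builds the alias forwards the real ktls requirement.
#[cfg(all(target_os = "linux", not(test)))]
pub trait MaybeKtlsReady: ktls::AsyncReadReady {}
#[cfg(all(target_os = "linux", not(test)))]
impl<T: ktls::AsyncReadReady> MaybeKtlsReady for T {}

// Everywhere else (and under cfg(test)) it is an empty marker satisfied by any reader,
// so generic code keeps a single set of bounds on every platform.
#[cfg(any(not(target_os = "linux"), test))]
pub trait MaybeKtlsReady {}
#[cfg(any(not(target_os = "linux"), test))]
impl<T: tokio::io::AsyncRead> MaybeKtlsReady for T {}

// Callers then write one bound set that compiles everywhere:
fn takes_stream<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + MaybeKtlsReady>(_s: &S) {}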
View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use crate::{
cancellation,
compute::PostgresConnection,
@@ -10,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use super::{copy_bidirectional::ErrorSource, handshake::KtlsAsyncReadReady};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
@@ -57,7 +59,7 @@ pub async fn proxy_pass(
Ok(())
}
pub struct ProxyPassthrough<P, S> {
pub struct ProxyPassthrough<P, S: AsRawFd> {
pub client: Stream<S>,
pub compute: PostgresConnection,
pub aux: MetricsAuxInfo,
@@ -67,7 +69,7 @@ pub struct ProxyPassthrough<P, S> {
pub cancel: cancellation::Session<P>,
}
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
impl<P, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> ProxyPassthrough<P, S> {
pub async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {

View File

@@ -2,6 +2,8 @@
mod mitm;
use std::pin::Pin;
use std::task::Poll;
use std::time::Duration;
use super::connect_compute::ConnectMechanism;
@@ -16,12 +18,14 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::stream::Stream;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
@@ -35,28 +39,73 @@ fn generate_certs(
pki_types::CertificateDer<'static>,
pki_types::PrivateKeyDer<'static>,
)> {
let ca = rcgen::Certificate::from_params({
let ca_key = rcgen::KeyPair::generate()?;
let cert_key = rcgen::KeyPair::generate()?;
let ca = {
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
params
})?;
params.self_signed(&ca_key)?
};
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
let cert = {
let mut params = rcgen::CertificateParams::new(vec![hostname.into()])?;
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;
params.signed_by(&cert_key, &ca, &ca_key)?
};
Ok((
pki_types::CertificateDer::from(ca.serialize_der()?),
pki_types::CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
pki_types::PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
ca.into(),
cert.into(),
pki_types::PrivateKeyDer::Pkcs8(cert_key.serialize_der().into()),
))
}
pub struct DummyClient(pub DuplexStream);
impl AsRawFd for DummyClient {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
unreachable!()
}
}
impl AsyncWrite for DummyClient {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
impl AsyncRead for DummyClient {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
struct ClientConfig<'a> {
config: rustls::ClientConfig,
hostname: &'a str,
@@ -121,7 +170,9 @@ fn generate_tls_config<'a>(
#[async_trait]
trait TestAuth: Sized {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
async fn authenticate<
S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady,
>(
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
@@ -150,7 +201,9 @@ impl Scram {
#[async_trait]
impl TestAuth for Scram {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
async fn authenticate<
S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady,
>(
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
@@ -170,14 +223,14 @@ impl TestAuth for Scram {
/// A dummy proxy impl which performs a handshake and reports auth success.
async fn dummy_proxy(
client: impl AsyncRead + AsyncWrite + Unpin + Send,
client: impl AsyncRead + AsyncWrite + Unpin + Send + AsRawFd,
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let mut stream =
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Startup(stream, ..) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
@@ -196,7 +249,11 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
NoAuth,
));
let client_err = tokio_postgres::Config::new()
.user("john_doe")
@@ -225,7 +282,11 @@ async fn handshake_tls() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
NoAuth,
));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
@@ -241,7 +302,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
async fn handshake_raw() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let proxy = tokio::spawn(dummy_proxy(DummyClient(client), None, NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
@@ -285,7 +346,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new(password).await?,
));
@@ -309,7 +370,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));
@@ -332,7 +393,11 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
Scram::mock(),
));
use rand::{distributions::Alphanumeric, Rng};
let password: String = rand::thread_rng()
@@ -433,7 +498,7 @@ impl ReportableError for TestConnectError {
impl std::fmt::Display for TestConnectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
write!(f, "{:?}", self)
}
}
@@ -475,7 +540,7 @@ impl ConnectMechanism for TestConnectMechanism {
retryable: false,
kind: ErrorKind::Compute,
}),
x => panic!("expecting action {x:?}, connect is called instead"),
x => panic!("expecting action {:?}, connect is called instead", x),
}
}
@@ -515,7 +580,7 @@ impl TestBackend for TestConnectMechanism {
assert!(err.could_retry());
Err(console::errors::WakeComputeError::ApiError(err))
}
x => panic!("expecting action {x:?}, wake_compute is called instead"),
x => panic!("expecting action {:?}, wake_compute is called instead", x),
}
}

View File

@@ -36,14 +36,14 @@ async fn proxy_mitm(
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
let (end_client, startup) = match handshake(
&RequestMonitoring::test(),
client1,
DummyClient(client1),
Some(&server_config1),
false,
)
.await
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(stream, _ep, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};
@@ -115,7 +115,9 @@ where
let mut buf = [0];
stream.read_exact(&mut buf).await.unwrap();
assert!(buf[0] == b'S', "ssl not supported by server");
if buf[0] != b'S' {
panic!("ssl not supported by server");
}
tls.connect(stream).await.unwrap()
}
@@ -152,7 +154,7 @@ impl Encoder<Bytes> for PgFrame {
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));
@@ -235,7 +237,7 @@ async fn connect_failure(
) -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));

View File

@@ -119,7 +119,6 @@ impl Default for LeakyBucketState {
}
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
use std::time::Duration;

View File

@@ -174,8 +174,9 @@ impl DynamicLimiter {
let mut inner = self.inner.lock();
if inner.take(&self.ready).is_some() {
break Ok(Token::new(self.clone()));
} else {
notified.set(self.ready.notified());
}
notified.set(self.ready.notified());
}
notified.as_mut().await;
ready = true;

View File

@@ -150,7 +150,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
}
Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
_ => {
invalidate_cache(self.cache.clone(), msg.clone());
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
Metrics::get()

View File

@@ -89,7 +89,7 @@ impl<'a> ClientFirstMessage<'a> {
write!(&mut message, "r={}", self.nonce).unwrap();
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message

View File

@@ -82,7 +82,13 @@ mod tests {
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}");
let secret = format!(
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
iterations = iterations,
salt = salt,
stored_key = stored_key,
server_key = server_key,
);
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);

View File

@@ -222,11 +222,12 @@ fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
}
for i in 0.. {
let Some(mut job) = worker
let mut job = match worker
.pop()
.or_else(|| pool.steal(&mut rng, index, &worker))
else {
continue 'wait;
{
Some(job) => job,
None => continue 'wait,
};
pool.metrics

View File

@@ -93,11 +93,11 @@ pub async fn task_main(
let mut tls_server_config = rustls::ServerConfig::clone(&config.to_server_config());
// prefer http2, but support http/1.1
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(tls_server_config)
Arc::new(tls_server_config) as Arc<_>
}
None => {
warn!("TLS config is missing");
Arc::new(NoTls)
Arc::new(NoTls) as Arc<_>
}
};
@@ -190,7 +190,19 @@ trait MaybeTlsAcceptor: Send + Sync + 'static {
#[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?))
#[cfg(all(target_os = "linux", not(test)))]
let conn = ktls::CorkStream::new(conn);
let tls = TlsAcceptor::from(self).accept(conn).await?;
#[cfg(all(target_os = "linux", not(test)))]
return ktls::config_ktls_server(tls)
.await
.map(|s| Box::pin(s) as _)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
#[cfg(any(not(target_os = "linux"), test))]
Ok(Box::pin(tls))
}
}
@@ -407,7 +419,7 @@ async fn request_handler(
.header("Access-Control-Allow-Origin", "*")
.header(
"Access-Control-Allow-Headers",
"Authorization, Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Neon-Batch-Read-Only, Neon-Batch-Isolation-Level",
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

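A rough, Linux-only sketch (hypothetical helper, not part of this change set) of the accept-then-offload sequence that MaybeTlsAcceptor::accept performs above, on a plain TcpStream and assuming the ktls::CorkStream / config_ktls_server API shown in the diff:

use std::{io, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::TlsAcceptor;

#[cfg(target_os = "linux")]
async fn accept_and_offload(
    config: Arc<rustls::ServerConfig>,
    socket: TcpStream,
) -> io::Result<impl AsyncRead + AsyncWrite> {
    // CorkStream lets ktls stop reads on a TLS record boundary, so rustls does not
    // buffer post-handshake application data that the kernel should decrypt instead.
    let corked = ktls::CorkStream::new(socket);
    let tls = TlsAcceptor::from(config).accept(corked).await?;
    // Install the negotiated secrets on the socket; reads and writes now bypass rustls.
    ktls::config_ktls_server(tls)
        .await
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}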
View File

@@ -4,10 +4,7 @@ use async_trait::async_trait;
use tracing::{field::display, info};
use crate::{
auth::{
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError,
},
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
compute,
config::{AuthenticationConfig, ProxyConfig},
console::{
@@ -27,7 +24,7 @@ use crate::{
Host,
};
use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
@@ -36,26 +33,21 @@ pub struct PoolingBackend {
}
impl PoolingBackend {
pub async fn authenticate_with_password(
pub async fn authenticate(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
password: &[u8],
conn_info: &ConnInfo,
) -> Result<ComputeCredentials, AuthError> {
let user_info = user_info.clone();
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| user_info.clone());
let user_info = conn_info.user_info.clone();
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
if !self
.endpoint_rate_limiter
.check(user_info.endpoint.clone().into(), 1)
.check(conn_info.user_info.endpoint.clone().into(), 1)
{
return Err(AuthError::too_many_connections());
}
@@ -78,10 +70,14 @@ impl PoolingBackend {
return Err(AuthError::auth_failed(&*user_info.user));
}
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome =
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
.await?;
let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&config.thread_pool,
ep,
&conn_info.password,
secret,
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
@@ -89,7 +85,7 @@ impl PoolingBackend {
}
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
Err(AuthError::auth_failed(&*user_info.user))
Err(AuthError::auth_failed(&*conn_info.user_info.user))
}
};
res.map(|key| ComputeCredentials {
@@ -98,39 +94,6 @@ impl PoolingBackend {
})
}
pub async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
jwt: &str,
) -> Result<ComputeCredentials, AuthError> {
match &self.config.auth_backend {
crate::auth::BackendType::Console(_, ()) => {
Err(AuthError::auth_failed("JWT login is not yet supported"))
}
crate::auth::BackendType::Link(_, ()) => Err(AuthError::auth_failed(
"JWT login over link proxy is not supported",
)),
crate::auth::BackendType::Local(cache) => {
cache
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
user_info.user.clone(),
&StaticAuthRules,
jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
}
}
}
// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare few structures
// that this code expects.
@@ -142,12 +105,12 @@ impl PoolingBackend {
keys: ComputeCredentials,
force_new: bool,
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
let maybe_client = if force_new {
info!("pool: pool is disabled");
None
} else {
let maybe_client = if !force_new {
info!("pool: looking for an existing connection");
self.pool.get(ctx, &conn_info)?
} else {
info!("pool: pool is disabled");
None
};
if let Some(client) = maybe_client {
@@ -156,7 +119,7 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.config.auth_backend.as_ref().map(|()| keys);
let backend = self.config.auth_backend.as_ref().map(|_| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
@@ -269,16 +232,10 @@ impl ConnectMechanism for TokioMechanism {
let mut config = (*node_info.config).clone();
let config = config
.user(&self.conn_info.user_info.user)
.password(&*self.conn_info.password)
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);
match &self.conn_info.auth {
AuthData::Jwt(_) => {}
AuthData::Password(pw) => {
config.password(pw);
}
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(tokio_postgres::NoTls).await;
drop(pause);

View File

@@ -33,13 +33,7 @@ use super::backend::HttpConnError;
pub struct ConnInfo {
pub user_info: ComputeUserInfo,
pub dbname: DbName,
pub auth: AuthData,
}
#[derive(Debug, Clone)]
pub enum AuthData {
Password(SmallVec<[u8; 16]>),
Jwt(String),
pub password: SmallVec<[u8; 16]>,
}
impl ConnInfo {
@@ -339,9 +333,9 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
} = pool.get_mut();
// ensure that closed clients are removed
for db_pool in pools.values_mut() {
pools.iter_mut().for_each(|(_, db_pool)| {
clients_removed += db_pool.clear_closed_clients(total_conns);
}
});
// we only remove this pool if it has no active connections
if *total_conns == 0 {
@@ -405,20 +399,21 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
if client.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
} else {
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
client.session.send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
client.session.send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
Ok(None)
}
@@ -659,7 +654,7 @@ impl<C: ClientInnerExt> Client<C> {
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { conn_info, pool })
(&mut inner.inner, Discard { pool, conn_info })
}
}
@@ -721,9 +716,7 @@ impl<C: ClientInnerExt> Drop for Client<C> {
mod tests {
use std::{mem, sync::atomic::AtomicBool};
use crate::{
proxy::NeonOptions, serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId,
};
use crate::{serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId};
use super::*;
@@ -782,10 +775,10 @@ mod tests {
user_info: ComputeUserInfo {
user: "user".into(),
endpoint: "endpoint".into(),
options: NeonOptions::default(),
options: Default::default(),
},
dbname: "dbname".into(),
auth: AuthData::Password("password".as_bytes().into()),
password: "password".as_bytes().into(),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
@@ -840,10 +833,10 @@ mod tests {
user_info: ComputeUserInfo {
user: "user".into(),
endpoint: "endpoint-2".into(),
options: NeonOptions::default(),
options: Default::default(),
},
dbname: "dbname".into(),
auth: AuthData::Password("password".as_bytes().into()),
password: "password".as_bytes().into(),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),

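This file's GC pass walks every per-database pool, drops closed clients, and only discards the endpoint pool once its connection counter reaches zero. A simplified sketch of that loop over plain std types follows; `DbPool` and `Client` are stand-ins, and a plain `for` loop (rather than `for_each`) keeps the mutable borrow of the shared counter and the accumulator straightforward.

```rust
use std::collections::HashMap;

/// Simplified stand-in for a per-database pool entry.
struct DbPool {
    clients: Vec<Client>,
}

struct Client {
    closed: bool,
}

impl DbPool {
    /// Drop closed clients, decrement the shared connection counter,
    /// and report how many were removed (stand-in for `clear_closed_clients`).
    fn clear_closed_clients(&mut self, total_conns: &mut usize) -> usize {
        let before = self.clients.len();
        self.clients.retain(|c| !c.closed);
        let removed = before - self.clients.len();
        *total_conns -= removed;
        removed
    }
}

/// The GC pass sketched from the hunk above.
fn gc(pools: &mut HashMap<String, DbPool>, total_conns: &mut usize) -> usize {
    let mut clients_removed = 0;
    for db_pool in pools.values_mut() {
        clients_removed += db_pool.clear_closed_clients(total_conns);
    }
    // The caller only drops the whole endpoint pool once *total_conns == 0.
    clients_removed
}
```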
View File

@@ -55,7 +55,7 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
.collect::<Vec<_>>()
.join(",");
Some(format!("{{{vals}}}"))
Some(format!("{{{}}}", vals))
}
}
}

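For context, this helper renders a JSON array as a Postgres array literal such as `{1,2,3}`; the change is only between positional and inline format captures. A hedged, self-contained sketch (element quoting, escaping, and nested arrays are omitted here):

```rust
use serde_json::Value;

/// Minimal sketch: render a JSON array as a Postgres array literal like
/// `{1,2,3}`. Elements are not quoted or escaped here; this only shows the
/// brace formatting that the hunk above touches.
fn json_array_to_pg_array(value: &Value) -> Option<String> {
    let arr = value.as_array()?;
    let vals = arr
        .iter()
        .map(|v| v.to_string())
        .collect::<Vec<_>>()
        .join(",");
    // `{{` / `}}` are escaped braces; `{vals}` is an inline format capture.
    Some(format!("{{{vals}}}"))
}

fn main() {
    let v = serde_json::json!([1, 2, 3]);
    assert_eq!(json_array_to_pg_array(&v).as_deref(), Some("{1,2,3}"));
}
```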
View File

@@ -7,7 +7,6 @@ use futures::future::try_join;
use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http::header::AUTHORIZATION;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
@@ -57,7 +56,6 @@ use crate::DbName;
use crate::RoleName;
use super::backend::PoolingBackend;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::http_util::json_response;
@@ -90,7 +88,6 @@ enum Payload {
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
@@ -112,7 +109,7 @@ where
#[derive(Debug, thiserror::Error)]
pub enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
InvalidHeader(&'static str),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
@@ -156,10 +153,10 @@ fn get_conn_info(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.get("Neon-Connection-String")
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
.map_err(|_| ConnInfoError::InvalidHeader("Neon-Connection-String"))?;
let connection_url = Url::parse(connection_string)?;
@@ -182,23 +179,10 @@ fn get_conn_info(
}
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingPassword)?
.into(),
)
} else if let Some(pass) = connection_url.password() {
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else {
return Err(ConnInfoError::MissingPassword);
};
let password = connection_url
.password()
.ok_or(ConnInfoError::MissingPassword)?;
let password = urlencoding::decode_binary(password.as_bytes());
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
@@ -207,12 +191,12 @@ fn get_conn_info(
.ok_or(ConnInfoError::MalformedEndpoint)?
} else {
hostname
.split_once('.')
.split_once(".")
.map_or(hostname, |(prefix, _)| prefix)
.into()
}
}
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
Some(url::Host::Ipv4(_)) | Some(url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname)
}
};
@@ -241,7 +225,10 @@ fn get_conn_info(
Ok(ConnInfo {
user_info,
dbname,
auth,
password: match password {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
},
})
}
@@ -563,24 +550,9 @@ async fn handle_inner(
let authenticate_and_connect = Box::pin(
async {
let keys = match &conn_info.auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await?
}
};
let keys = backend
.authenticate(ctx, &config.authentication_config, &conn_info)
.await?;
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;

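The auth branch in `get_conn_info` either takes a `Bearer` token from the `Authorization` header or falls back to the percent-encoded password embedded in the connection URL. A compact sketch of that branching, assuming the `http`, `url`, and `urlencoding` crates; the error enum is trimmed down from the real `ConnInfoError`, and `Vec<u8>` stands in for the `SmallVec` used above.

```rust
use http::header::AUTHORIZATION;
use http::HeaderMap;
use url::Url;

/// Mirrors the enum from the hunk above (simplified storage type).
enum AuthData {
    Password(Vec<u8>),
    Jwt(String),
}

#[derive(Debug)]
enum AuthError {
    InvalidHeader,
    MissingPassword,
}

/// Illustrative sketch: prefer a `Bearer` token from the Authorization
/// header, otherwise fall back to the (percent-decoded) URL password.
fn extract_auth(headers: &HeaderMap, connection_url: &Url) -> Result<AuthData, AuthError> {
    if let Some(auth) = headers.get(&AUTHORIZATION) {
        let auth = auth.to_str().map_err(|_| AuthError::InvalidHeader)?;
        let jwt = auth
            .strip_prefix("Bearer ")
            .ok_or(AuthError::MissingPassword)?;
        return Ok(AuthData::Jwt(jwt.to_owned()));
    }
    let password = connection_url
        .password()
        .ok_or(AuthError::MissingPassword)?;
    // Passwords in URLs are percent-encoded; decode to raw bytes.
    let decoded = urlencoding::decode_binary(password.as_bytes()).into_owned();
    Ok(AuthData::Password(decoded))
}
```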
View File

@@ -16,6 +16,7 @@ use hyper1::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use std::os::fd::AsRawFd;
use std::{
pin::Pin,
sync::Arc,
@@ -45,6 +46,18 @@ impl<S> WebSocketRw<S> {
}
}
impl<S> AsRawFd for WebSocketRw<S> {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
unreachable!("ktls should not need to be used for websocket rw")
}
}
#[cfg(all(target_os = "linux", not(test)))]
impl<S> ktls::AsyncReadReady for WebSocketRw<S> {
fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!("ktls should not need to be used for websocket rw")
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
fn poll_write(
self: Pin<&mut Self>,

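`WebSocketRw` never carries a raw kernel socket, so its `AsRawFd` and `ktls::AsyncReadReady` impls above are deliberately `unreachable!()`. For contrast, here is a hedged sketch of what delegating impls look like when the wrapper does own a `tokio::net::TcpStream`; the `Inner` type is hypothetical, and the trait is redeclared locally (with the same signature as in the hunk) so the snippet builds without the ktls crate.

```rust
use std::io;
use std::os::fd::{AsRawFd, RawFd};
use std::task::{Context, Poll};

use tokio::net::TcpStream;

/// Hypothetical wrapper that *does* own a kernel socket, unlike
/// `WebSocketRw`, so both traits can simply delegate.
struct Inner {
    stream: TcpStream,
}

impl AsRawFd for Inner {
    fn as_raw_fd(&self) -> RawFd {
        // kTLS needs the raw fd to install the kernel TLS ULP on the socket.
        self.stream.as_raw_fd()
    }
}

/// Same shape as `ktls::AsyncReadReady` in the hunk above, declared locally
/// here purely so the sketch is self-contained.
trait AsyncReadReady {
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}

impl AsyncReadReady for Inner {
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // tokio's TcpStream exposes readiness polling directly.
        self.stream.poll_read_ready(cx)
    }
}
```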
View File

@@ -1,11 +1,13 @@
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::proxy::handshake::KtlsAsyncReadReady;
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use std::os::fd::AsRawFd;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
@@ -67,7 +69,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
FeMessage::PasswordMessage(msg) => Ok(msg),
bad => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected message type: {bad:?}"),
format!("unexpected message type: {:?}", bad),
)),
}
}
@@ -172,34 +174,31 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
}
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
pub enum Stream<S: AsRawFd> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
tls: Box<TlsStream<S>>,
#[cfg(any(not(target_os = "linux"), test))]
tls: Pin<Box<TlsStream<S>>>,
#[cfg(all(target_os = "linux", not(test)))]
tls: ktls::KtlsStream<S>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
}
impl<S: Unpin> Unpin for Stream<S> {}
impl<S: Unpin + AsRawFd> Unpin for Stream<S> {}
impl<S> Stream<S> {
impl<S: AsRawFd> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
/// Return SNI hostname when it's available.
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
@@ -221,7 +220,7 @@ pub enum StreamUpgradeError {
Io(#[from] io::Error),
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(
self,
@@ -242,7 +241,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AsyncRead for Stream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
@@ -255,7 +254,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AsyncWrite for Stream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,

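Here the TLS variant of `Stream` is cfg-gated: user-space rustls (boxed and pinned) off Linux and in tests, `ktls::KtlsStream` otherwise, which is why the extra `AsRawFd`/`KtlsAsyncReadReady` bounds appear on the impls. The hunk puts the `cfg` attributes directly on the field; a type alias is one alternative way to express the same choice, sketched below with placeholder stream types rather than the real `tokio_rustls`/`ktls` ones.

```rust
// Sketch of the cfg-gating pattern, expressed as a type alias so downstream
// code only ever names `SecureStream<S>`. The placeholder structs stand in
// for tokio_rustls::server::TlsStream<S> and ktls::KtlsStream<S>.

use std::pin::Pin;

struct RustlsStream<S>(S); // stand-in for the user-space TLS stream
struct KernelTlsStream<S>(S); // stand-in for the kernel TLS stream

#[cfg(any(not(target_os = "linux"), test))]
type SecureStream<S> = Pin<Box<RustlsStream<S>>>;

#[cfg(all(target_os = "linux", not(test)))]
type SecureStream<S> = KernelTlsStream<S>;

enum Stream<S> {
    /// Plain stream before the TLS upgrade.
    Raw { raw: S },
    /// Either user-space TLS or kernel TLS, chosen at compile time.
    Tls { tls: SecureStream<S> },
}
```

Whichever spelling is used, exactly one of the two `cfg` predicates is active in any given build, so the rest of the code sees a single concrete TLS type.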
View File

@@ -450,9 +450,12 @@ async fn upload_events_chunk(
remote_path: &RemotePath,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
let Some(storage) = storage else {
error!("no remote storage configured");
return Ok(());
let storage = match storage {
Some(storage) => storage,
None => {
error!("no remote storage configured");
return Ok(());
}
};
let data = serde_json::to_vec(&chunk).context("serialize metrics")?;
let mut encoder = GzipEncoder::new(Vec::new());

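This hunk is purely stylistic: an early return when no remote storage is configured, written either as a `match` or as `let … else` (stable since Rust 1.65). Both equivalent shapes side by side, in a self-contained sketch:

```rust
/// Both forms bail out early when no remote storage is configured; the
/// `let … else` version avoids one level of nesting.
fn upload(storage: Option<&str>) {
    // match form
    let _storage = match storage {
        Some(storage) => storage,
        None => {
            eprintln!("no remote storage configured");
            return;
        }
    };

    // let-else form (equivalent control flow; the else block must diverge)
    let Some(_storage) = storage else {
        eprintln!("no remote storage configured");
        return;
    };
}
```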
View File

@@ -31,7 +31,7 @@ pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
impl<T> Default for Waiters<T> {
fn default() -> Self {
Waiters(Mutex::default())
Waiters(Default::default())
}
}

View File

@@ -114,16 +114,6 @@ fn check_permission(request: &Request<Body>, tenant_id: Option<TenantId>) -> Res
})
}
/// List all (not deleted) timelines.
async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
let res: Vec<TenantTimelineId> = GlobalTimelines::get_all()
.iter()
.map(|tli| tli.ttid)
.collect();
json_response(StatusCode::OK, res)
}
/// Report info about timeline.
async fn timeline_status_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
let ttid = TenantTimelineId::new(
@@ -572,9 +562,6 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
.post("/v1/tenant/timeline", |r| {
request_span(r, timeline_create_handler)
})
.get("/v1/tenant/timeline", |r| {
request_span(r, timeline_list_handler)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
request_span(r, timeline_status_handler)
})

View File

@@ -68,29 +68,16 @@ const parseReportJson = async ({ reportJsonUrl, fetch }) => {
console.info(`Cannot get BUILD_TYPE and Postgres Version from test name: "${test.name}", defaulting to "release" and "14"`)
buildType = "release"
pgVersion = "16"
pgVersion = "14"
}
pgVersions.add(pgVersion)
// We use `arch` as it is returned by GitHub Actions
// (RUNNER_ARCH env var): X86, X64, ARM, or ARM64
// Ref https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
let arch = ""
if (test.parameters.includes("'X64'")) {
arch = "x86-64"
} else if (test.parameters.includes("'ARM64'")) {
arch = "arm64"
} else {
arch = "unknown"
}
// Removing build type and PostgreSQL version from the test name to make it shorter
const testName = test.name.replace(new RegExp(`${buildType}-pg${pgVersion}-?`), "").replace("[]", "")
test.pytestName = `${parentSuite.name.replace(".", "/")}/${suite.name}.py::${testName}`
test.pgVersion = pgVersion
test.buildType = buildType
test.arch = arch
if (test.status === "passed") {
passedTests[pgVersion][testName].push(test)
@@ -157,7 +144,7 @@ const reportSummary = async (params) => {
const links = []
for (const test of tests) {
const allureLink = `${reportUrl}#suites/${test.parentUid}/${test.uid}`
links.push(`[${test.buildType}-${test.arch}](${allureLink})`)
links.push(`[${test.buildType}](${allureLink})`)
}
summary += `- \`${testName}\`: ${links.join(", ")}\n`
}
@@ -188,7 +175,7 @@ const reportSummary = async (params) => {
const links = []
for (const test of tests) {
const allureLink = `${reportUrl}#suites/${test.parentUid}/${test.uid}/retries`
links.push(`[${test.buildType}-${test.arch}](${allureLink})`)
links.push(`[${test.buildType}](${allureLink})`)
}
summary += `- \`${testName}\`: ${links.join(", ")}\n`
}

Some files were not shown because too many files have changed in this diff.