Compare commits


8 Commits

Author SHA1 Message Date
Conrad Ludgate 2cca1b3e4e fix 2024-08-21 18:44:57 +01:00
Conrad Ludgate 471b3b300d fix pin 2024-08-21 16:29:52 +01:00
Conrad Ludgate fbd4b91169 asyncreadready 2024-08-21 16:16:49 +01:00
Conrad Ludgate 8cc45ad9bd asrawfd things 2024-08-21 15:28:25 +01:00
Conrad Ludgate aabbd55187 add ktls handling 2024-08-21 14:42:41 +01:00
Conrad Ludgate 987a859352 start integrating ktls 2024-08-21 14:11:58 +01:00
Conrad Ludgate e171fd805b add ktls dep 2024-08-21 13:51:02 +01:00
Conrad Ludgate 1e4702b26a update rustls 2024-08-21 13:47:19 +01:00
69 changed files with 879 additions and 914 deletions

View File

@@ -23,30 +23,10 @@ platforms = [
]
[final-excludes]
workspace-members = [
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
# it is built primarily in the separate repo neondatabase/autoscaling and is thus excluded
# from depending on workspace-hack because most of the dependencies are not used.
"vm_monitor",
# All of these exist in libs and are not usually built independently.
# Putting workspace hack there adds a bottleneck for cargo builds.
"compute_api",
"consumption_metrics",
"desim",
"metrics",
"pageserver_api",
"postgres_backend",
"postgres_connection",
"postgres_ffi",
"pq_proto",
"remote_storage",
"safekeeper_api",
"tenant_size_model",
"tracing-utils",
"utils",
"wal_craft",
"walproposer",
]
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
# it is built primarily in the separate repo neondatabase/autoscaling and is thus excluded
# from depending on workspace-hack because most of the dependencies are not used.
workspace-members = ["vm_monitor"]
# Write out exact versions rather than a semver range. (Defaults to false.)
# exact-versions = true

View File

@@ -169,8 +169,10 @@ runs:
EXTRA_PARAMS="--durations-path $TEST_OUTPUT/benchmark_durations.json $EXTRA_PARAMS"
fi
if [[ $BUILD_TYPE == "debug" && $RUNNER_ARCH == 'X64' ]]; then
if [[ "${{ inputs.build_type }}" == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
elif [[ "${{ inputs.build_type }}" == "release" ]]; then
cov_prefix=()
else
cov_prefix=()
fi

View File

@@ -94,16 +94,11 @@ jobs:
# We run tests with additional features that are turned off by default (e.g. in release builds), see
# corresponding Cargo.toml files for their descriptions.
- name: Set env variables
env:
ARCH: ${{ inputs.arch }}
run: |
CARGO_FEATURES="--features testing"
if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run"
CARGO_FLAGS="--locked"
elif [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix=""
CARGO_FLAGS="--locked"
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=""
CARGO_FLAGS="--locked --release"
@@ -163,8 +158,6 @@ jobs:
# Do install *before* running rust tests because they might recompile the
# binaries with different features/flags.
- name: Install rust binaries
env:
ARCH: ${{ inputs.arch }}
run: |
# Install target binaries
mkdir -p /tmp/neon/bin/
@@ -179,7 +172,7 @@ jobs:
done
# Install test executables and write list of all binaries (for code coverage)
if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then
if [[ $BUILD_TYPE == "debug" ]]; then
# Keep bloated coverage data files away from the rest of the artifact
mkdir -p /tmp/coverage/
@@ -250,8 +243,8 @@ jobs:
uses: ./.github/actions/save-coverage-data
regress-tests:
# Don't run regression tests on debug arm64 builds
if: inputs.build-type != 'debug' || inputs.arch != 'arm64'
# Run test on x64 only
if: inputs.arch == 'x64'
needs: [ build-neon ]
runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', inputs.arch == 'arm64' && 'large-arm64' || 'large')) }}
container:

View File

@@ -198,7 +198,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ x64, arm64 ]
arch: [ x64 ]
# Do not build or run tests in debug for release branches
build-type: ${{ fromJson((startsWith(github.ref_name, 'release') && github.event_name == 'push') && '["release"]' || '["debug", "release"]') }}
include:

Cargo.lock (generated, 345 lines changed)
View File

@@ -316,6 +316,33 @@ dependencies = [
"zeroize",
]
[[package]]
name = "aws-lc-rs"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ae74d9bd0a7530e8afd1770739ad34b36838829d6ad61818f9230f683f5ad77"
dependencies = [
"aws-lc-sys",
"mirai-annotations",
"paste",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f0e249228c6ad2d240c2dc94b714d711629d52bad946075d8e9b2f5391f0703"
dependencies = [
"bindgen 0.69.4",
"cc",
"cmake",
"dunce",
"fs_extra",
"libc",
"paste",
]
[[package]]
name = "aws-runtime"
version = "1.2.1"
@@ -926,7 +953,30 @@ dependencies = [
"lazycell",
"log",
"peeking_take_while",
"prettyplease 0.2.6",
"prettyplease 0.2.17",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 2.0.52",
"which",
]
[[package]]
name = "bindgen"
version = "0.69.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0"
dependencies = [
"bitflags 2.4.1",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease 0.2.17",
"proc-macro2",
"quote",
"regex",
@@ -1056,6 +1106,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "cgroups-rs"
version = "0.3.3"
@@ -1164,6 +1220,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"
[[package]]
name = "cmake"
version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.0"
@@ -1208,6 +1273,7 @@ dependencies = [
"serde_json",
"serde_with",
"utils",
"workspace_hack",
]
[[package]]
@@ -1320,6 +1386,7 @@ dependencies = [
"serde",
"serde_with",
"utils",
"workspace_hack",
]
[[package]]
@@ -1490,7 +1557,7 @@ dependencies = [
"bitflags 1.3.2",
"crossterm_winapi",
"libc",
"mio",
"mio 0.8.11",
"parking_lot 0.12.1",
"signal-hook",
"signal-hook-mio",
@@ -1668,6 +1735,7 @@ dependencies = [
"smallvec",
"tracing",
"utils",
"workspace_hack",
]
[[package]]
@@ -1765,6 +1833,12 @@ dependencies = [
"syn 2.0.52",
]
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "dyn-clone"
version = "1.0.14"
@@ -2066,6 +2140,12 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsevent-sys"
version = "4.1.0"
@@ -2399,9 +2479,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.3.3"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hex"
@@ -2919,6 +2999,33 @@ dependencies = [
"libc",
]
[[package]]
name = "ktls"
version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebe51e4a53d53b396707537bc8a5277798b720fb71f0d1b9c63eb53199a00fde"
dependencies = [
"futures-util",
"ktls-sys",
"libc",
"memoffset 0.9.1",
"nix 0.29.0",
"num_enum",
"pin-project-lite",
"rustls 0.23.12",
"smallvec",
"thiserror",
"tokio",
"tokio-rustls 0.26.0",
"tracing",
]
[[package]]
name = "ktls-sys"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "095b1fc8d841c3df8c3f2db78b7425cb2ec424568a282cb589a880b99d256e84"
[[package]]
name = "lasso"
version = "0.7.2"
@@ -2957,9 +3064,9 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.150"
version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c"
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]]
name = "libloading"
@@ -3123,9 +3230,9 @@ dependencies = [
[[package]]
name = "memoffset"
version = "0.9.0"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
@@ -3144,6 +3251,7 @@ dependencies = [
"rand 0.8.5",
"rand_distr",
"twox-hash",
"workspace_hack",
]
[[package]]
@@ -3200,6 +3308,24 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "mio"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec"
dependencies = [
"hermit-abi",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
[[package]]
name = "mirai-annotations"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1"
[[package]]
name = "multimap"
version = "0.8.3"
@@ -3240,7 +3366,20 @@ dependencies = [
"bitflags 2.4.1",
"cfg-if",
"libc",
"memoffset 0.9.0",
"memoffset 0.9.1",
]
[[package]]
name = "nix"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags 2.4.1",
"cfg-if",
"cfg_aliases",
"libc",
"memoffset 0.9.1",
]
[[package]]
@@ -3267,7 +3406,7 @@ dependencies = [
"kqueue",
"libc",
"log",
"mio",
"mio 0.8.11",
"walkdir",
"windows-sys 0.48.0",
]
@@ -3389,6 +3528,27 @@ dependencies = [
"libc",
]
[[package]]
name = "num_enum"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179"
dependencies = [
"num_enum_derive",
]
[[package]]
name = "num_enum_derive"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "oauth2"
version = "4.4.2"
@@ -3787,6 +3947,7 @@ dependencies = [
"strum_macros",
"thiserror",
"utils",
"workspace_hack",
]
[[package]]
@@ -4051,9 +4212,9 @@ dependencies = [
[[package]]
name = "pin-project-lite"
version = "0.2.13"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "pin-utils"
@@ -4178,16 +4339,17 @@ dependencies = [
"futures",
"once_cell",
"pq_proto",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-pemfile 2.1.1",
"serde",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-util",
"tracing",
"workspace_hack",
]
[[package]]
@@ -4200,6 +4362,7 @@ dependencies = [
"postgres",
"tokio-postgres",
"url",
"workspace_hack",
]
[[package]]
@@ -4207,7 +4370,7 @@ name = "postgres_ffi"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.65.1",
"byteorder",
"bytes",
"crc32c",
@@ -4222,6 +4385,7 @@ dependencies = [
"serde",
"thiserror",
"utils",
"workspace_hack",
]
[[package]]
@@ -4259,6 +4423,7 @@ dependencies = [
"thiserror",
"tokio",
"tracing",
"workspace_hack",
]
[[package]]
@@ -4273,9 +4438,9 @@ dependencies = [
[[package]]
name = "prettyplease"
version = "0.2.6"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b69d39aab54d069e7f2fe8cb970493e7834601ca2d8c65fd7bbd183578080d1"
checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7"
dependencies = [
"proc-macro2",
"syn 2.0.52",
@@ -4290,6 +4455,15 @@ dependencies = [
"elliptic-curve 0.13.8",
]
[[package]]
name = "proc-macro-crate"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284"
dependencies = [
"toml_edit 0.21.1",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.20+deprecated"
@@ -4448,6 +4622,7 @@ dependencies = [
"itertools 0.10.5",
"jose-jwa",
"jose-jwk",
"ktls",
"lasso",
"md5",
"measured",
@@ -4478,7 +4653,7 @@ dependencies = [
"rsa",
"rstest",
"rustc-hash",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-native-certs 0.7.0",
"rustls-pemfile 2.1.1",
"scopeguard",
@@ -4497,7 +4672,7 @@ dependencies = [
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-tungstenite",
"tokio-util",
"tower-service",
@@ -4663,12 +4838,13 @@ dependencies = [
[[package]]
name = "rcgen"
version = "0.12.1"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48406db8ac1f3cbc7dcdb56ec355343817958a356ff430259bb07baf7607e1e1"
checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779"
dependencies = [
"pem",
"ring 0.17.6",
"rustls-pki-types",
"time",
"yasna",
]
@@ -4823,6 +4999,7 @@ dependencies = [
"toml_edit 0.19.10",
"tracing",
"utils",
"workspace_hack",
]
[[package]]
@@ -5180,7 +5357,22 @@ dependencies = [
"log",
"ring 0.17.6",
"rustls-pki-types",
"rustls-webpki 0.102.2",
"rustls-webpki 0.102.6",
"subtle",
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
dependencies = [
"aws-lc-rs",
"log",
"once_cell",
"rustls-pki-types",
"rustls-webpki 0.102.6",
"subtle",
"zeroize",
]
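The new rustls 0.23 entry is what drags in the aws-lc-rs, aws-lc-sys, bindgen 0.69, cmake, dunce, and fs_extra packages above: 0.23 ships aws-lc-rs as its default crypto provider. A minimal sketch of pinning the process-wide provider explicitly, per the rustls 0.23 public API (the ring variant assumes the crate's "ring" feature is enabled):

```rust
// Hedged sketch: rustls 0.23 lets the application install one process-wide
// CryptoProvider. aws-lc-rs is the default; opting back into ring looks like:
fn init_crypto_provider() {
    rustls::crypto::ring::default_provider()
        .install_default()
        .expect("a CryptoProvider was already installed");
}
```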
@@ -5231,9 +5423,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
version = "1.3.1"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8"
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
[[package]]
name = "rustls-webpki"
@@ -5257,10 +5449,11 @@ dependencies = [
[[package]]
name = "rustls-webpki"
version = "0.102.2"
version = "0.102.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e"
dependencies = [
"aws-lc-rs",
"ring 0.17.6",
"rustls-pki-types",
"untrusted 0.9.0",
@@ -5347,6 +5540,7 @@ dependencies = [
"serde",
"serde_with",
"utils",
"workspace_hack",
]
[[package]]
@@ -5700,9 +5894,9 @@ dependencies = [
[[package]]
name = "sha2-asm"
version = "0.6.3"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27ba7066011e3fb30d808b51affff34f0a66d3a03a58edd787c6e420e40e44e"
checksum = "b845214d6175804686b2bd482bcffe96651bb2d1200742b712003504a2dac1ab"
dependencies = [
"cc",
]
@@ -5739,7 +5933,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af"
dependencies = [
"libc",
"mio",
"mio 0.8.11",
"signal-hook",
]
@@ -5801,9 +5995,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.13.1"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "smol_str"
@@ -5995,7 +6189,7 @@ dependencies = [
"rand 0.8.5",
"remote_storage",
"reqwest 0.12.4",
"rustls 0.22.4",
"rustls 0.23.12",
"rustls-native-certs 0.7.0",
"serde",
"serde_json",
@@ -6005,7 +6199,7 @@ dependencies = [
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-util",
"tracing",
@@ -6182,6 +6376,7 @@ dependencies = [
"anyhow",
"serde",
"serde_json",
"workspace_hack",
]
[[package]]
@@ -6216,18 +6411,18 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.57"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.57"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [
"proc-macro2",
"quote",
@@ -6354,20 +6549,19 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.37.0"
version = "1.39.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5"
dependencies = [
"backtrace",
"bytes",
"libc",
"mio",
"num_cpus",
"mio 1.0.2",
"pin-project-lite",
"signal-hook-registry",
"socket2 0.5.5",
"tokio-macros",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -6398,9 +6592,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.2.0"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
@@ -6432,16 +6626,15 @@ dependencies = [
[[package]]
name = "tokio-postgres-rustls"
version = "0.11.1"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ea13f22eda7127c827983bdaf0d7fff9df21c8817bab02815ac277a21143677"
checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab"
dependencies = [
"futures",
"ring 0.17.6",
"rustls 0.22.4",
"rustls 0.23.12",
"tokio",
"tokio-postgres",
"tokio-rustls 0.25.0",
"tokio-rustls 0.26.0",
"x509-certificate",
]
@@ -6466,6 +6659,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
"rustls 0.23.12",
"rustls-pki-types",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.14"
@@ -6567,6 +6771,17 @@ dependencies = [
"winnow 0.4.6",
]
[[package]]
name = "toml_edit"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1"
dependencies = [
"indexmap 2.0.1",
"toml_datetime",
"winnow 0.5.40",
]
[[package]]
name = "toml_edit"
version = "0.22.14"
@@ -6659,11 +6874,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tracing"
version = "0.1.37"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"cfg-if",
"log",
"pin-project-lite",
"tracing-attributes",
@@ -6683,9 +6897,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
version = "0.1.24"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
@@ -6694,9 +6908,9 @@ dependencies = [
[[package]]
name = "tracing-core"
version = "0.1.31"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
@@ -6782,6 +6996,7 @@ dependencies = [
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"workspace_hack",
]
[[package]]
@@ -6999,6 +7214,7 @@ dependencies = [
"url",
"uuid",
"walkdir",
"workspace_hack",
]
[[package]]
@@ -7077,6 +7293,7 @@ dependencies = [
"postgres_ffi",
"regex",
"utils",
"workspace_hack",
]
[[package]]
@@ -7094,9 +7311,10 @@ name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.65.1",
"postgres_ffi",
"utils",
"workspace_hack",
]
[[package]]
@@ -7547,6 +7765,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "winnow"
version = "0.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876"
dependencies = [
"memchr",
]
[[package]]
name = "winnow"
version = "0.6.13"
@@ -7636,6 +7863,8 @@ dependencies = [
"reqwest 0.11.19",
"reqwest 0.12.4",
"rustls 0.21.11",
"rustls-pki-types",
"rustls-webpki 0.102.6",
"scopeguard",
"serde",
"serde_json",
@@ -7653,6 +7882,8 @@ dependencies = [
"tokio",
"tokio-rustls 0.24.0",
"tokio-util",
"toml_datetime",
"toml_edit 0.19.10",
"tonic",
"tower",
"tracing",

View File

@@ -139,7 +139,7 @@ reqwest-retry = "0.5"
routerify = "3"
rpds = "0.13"
rustc-hash = "1.1.0"
rustls = "0.22"
rustls = "0.23"
rustls-pemfile = "2"
rustls-split = "0.3"
scopeguard = "1.1"
@@ -171,8 +171,8 @@ tikv-jemalloc-ctl = "0.5"
tokio = { version = "1.17", features = ["macros"] }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.11.0"
tokio-rustls = "0.25"
tokio-postgres-rustls = "0.12.0"
tokio-rustls = "0.26"
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
@@ -232,7 +232,7 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
criterion = "0.5.1"
rcgen = "0.12"
rcgen = "0.13"
rstest = "0.18"
camino-tempfile = "1.0.2"
tonic-build = "0.9"
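The rcgen 0.12 -> 0.13 bump above is an API change, not just a lockfile refresh: the self-signed helper now returns a CertifiedKey, and certificates carry rustls-pki-types DER types (matching the new rustls-pki-types dependency in the Cargo.lock entry earlier). A sketch of the 0.13 call shape, under my reading of that release:

```rust
// Sketch assuming the rcgen 0.13 API: generate_simple_self_signed yields a
// CertifiedKey { cert, key_pair } pair rather than a single Certificate.
fn make_test_cert() -> Result<(), rcgen::Error> {
    let rcgen::CertifiedKey { cert, key_pair } =
        rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
    let _cert_der = cert.der();              // rustls_pki_types::CertificateDer
    let _key_der = key_pair.serialize_der(); // PKCS#8 DER bytes
    Ok(())
}
```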

View File

@@ -14,3 +14,5 @@ regex.workspace = true
utils = { path = "../utils" }
remote_storage = { version = "0.1", path = "../remote_storage/" }
workspace_hack.workspace = true

View File

@@ -6,8 +6,10 @@ license = "Apache-2.0"
[dependencies]
anyhow.workspace = true
chrono = { workspace = true, features = ["serde"] }
chrono.workspace = true
rand.workspace = true
serde.workspace = true
serde_with.workspace = true
utils.workspace = true
workspace_hack.workspace = true

View File

@@ -14,3 +14,5 @@ parking_lot.workspace = true
hex.workspace = true
scopeguard.workspace = true
smallvec = { workspace = true, features = ["write"] }
workspace_hack.workspace = true

View File

@@ -12,6 +12,8 @@ chrono.workspace = true
twox-hash.workspace = true
measured.workspace = true
workspace_hack.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
procfs.workspace = true
measured-process.workspace = true

View File

@@ -21,9 +21,11 @@ hex.workspace = true
humantime.workspace = true
thiserror.workspace = true
humantime-serde.workspace = true
chrono = { workspace = true, features = ["serde"] }
chrono.workspace = true
itertools.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
bincode.workspace = true
rand.workspace = true

View File

@@ -18,6 +18,7 @@ tokio-rustls.workspace = true
tracing.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true

View File

@@ -11,5 +11,7 @@ postgres.workspace = true
tokio-postgres.workspace = true
url.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true

View File

@@ -19,6 +19,8 @@ thiserror.workspace = true
serde.workspace = true
utils.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
env_logger.workspace = true
postgres.workspace = true

View File

@@ -14,6 +14,8 @@ postgres.workspace = true
postgres_ffi.workspace = true
camino-tempfile.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
regex.workspace = true
utils.workspace = true

View File

@@ -11,7 +11,9 @@ itertools.workspace = true
pin-project-lite.workspace = true
postgres-protocol.workspace = true
rand.workspace = true
tokio = { workspace = true, features = ["io-util"] }
tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true
serde.workspace = true
workspace_hack.workspace = true

View File

@@ -32,7 +32,7 @@ scopeguard.workspace = true
metrics.workspace = true
utils.workspace = true
pin-project-lite.workspace = true
workspace_hack.workspace = true
azure_core.workspace = true
azure_identity.workspace = true
azure_storage.workspace = true
@@ -46,4 +46,3 @@ sync_wrapper = { workspace = true, features = ["futures"] }
camino-tempfile.workspace = true
test-context.workspace = true
rand.workspace = true
tokio = { workspace = true, features = ["test-util"] }

View File

@@ -9,3 +9,5 @@ serde.workspace = true
serde_with.workspace = true
const_format.workspace = true
utils.workspace = true
workspace_hack.workspace = true

View File

@@ -9,3 +9,5 @@ license.workspace = true
anyhow.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace_hack.workspace = true

View File

@@ -14,3 +14,5 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tracing.workspace = true
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
workspace_hack.workspace = true

View File

@@ -39,7 +39,7 @@ thiserror.workspace = true
tokio.workspace = true
tokio-tar.workspace = true
tokio-util.workspace = true
toml_edit = { workspace = true, features = ["serde"] }
toml_edit.workspace = true
tracing.workspace = true
tracing-error.workspace = true
tracing-subscriber = { workspace = true, features = ["json", "registry"] }
@@ -54,6 +54,7 @@ walkdir.workspace = true
pq_proto.workspace = true
postgres_connection.workspace = true
metrics.workspace = true
workspace_hack.workspace = true
const_format.workspace = true
@@ -70,7 +71,6 @@ criterion.workspace = true
hex-literal.workspace = true
camino-tempfile.workspace = true
serde_assert.workspace = true
tokio = { workspace = true, features = ["test-util"] }
[[bench]]
name = "benchmarks"

View File

@@ -9,6 +9,8 @@ anyhow.workspace = true
utils.workspace = true
postgres_ffi.workspace = true
workspace_hack.workspace = true
[build-dependencies]
anyhow.workspace = true
bindgen.workspace = true

View File

@@ -10,7 +10,6 @@ use pageserver::{
page_cache,
repository::Value,
task_mgr::TaskKind,
tenant::storage_layer::inmemory_layer::SerializedBatch,
tenant::storage_layer::InMemoryLayer,
virtual_file,
};
@@ -68,16 +67,12 @@ async fn ingest(
let layer =
InMemoryLayer::create(conf, timeline_id, tenant_shard_id, lsn, entered, &ctx).await?;
let data = Value::Image(Bytes::from(vec![0u8; put_size]));
let data_ser_size = data.serialized_size().unwrap() as usize;
let data = Value::Image(Bytes::from(vec![0u8; put_size])).ser()?;
let ctx = RequestContext::new(
pageserver::task_mgr::TaskKind::WalReceiverConnectionHandler,
pageserver::context::DownloadBehavior::Download,
);
const BATCH_SIZE: usize = 16;
let mut batch = Vec::new();
for i in 0..put_count {
lsn += put_size as u64;
@@ -100,17 +95,7 @@ async fn ingest(
}
}
batch.push((key.to_compact(), lsn, data_ser_size, data.clone()));
if batch.len() >= BATCH_SIZE {
let this_batch = std::mem::take(&mut batch);
let serialized = SerializedBatch::from_values(this_batch);
layer.put_batch(serialized, &ctx).await?;
}
}
if !batch.is_empty() {
let this_batch = std::mem::take(&mut batch);
let serialized = SerializedBatch::from_values(this_batch);
layer.put_batch(serialized, &ctx).await?;
layer.put_value(key.to_compact(), lsn, &data, &ctx).await?;
}
layer.freeze(lsn + 1).await;

View File

@@ -82,51 +82,12 @@ where
return Ok(None);
}
// Our goal is to find an LSN boundary such that:
// Walk the ranges in LSN order.
//
// - There are no layers that overlap with the LSN. In other words, the LSN neatly
// divides the layers into two sets: those that are above it, and those that are
// below it.
//
// - All the layers above it have LSN range smaller than lsn_max_size
//
// Our strategy is to walk through the layers ordered by end LSNs, and maintain two
// values:
//
// - The current best LSN boundary. We have already determined that there are no
// layers that overlap with it.
//
// - Candidate LSN boundary. We have not found any layers that overlap with this LSN
// yet, but we might still see one later, as we continue iterating.
//
// Whenever we find that there are no layers that overlap with the current candidate,
// we make that the new "current best" and start building a new candidate. We start
// with the caller-supplied `end_lsn` as the current best LSN boundary. All the layers
// are below that LSN, so that's a valid stopping point. If we see a layer that is too
// large, we discard the candidate that we were building and return with the "current
// best" that we have seen so far.
//
// Here's an illustration of the state after processing layers A, B, C, and D:
//
// A B C D E
// 10000 <-- end_lsn passed by caller
// 9000 +
// 8000 | + +
// 7000 + | |
// 6000 | |
// 5000 | +
// 4000 +
// 3000 + <-- current_best_start_lsn
// 3000 | +
// 2000 | |
// 2000 + | <--- candidate_start_lsn
// 1000 +
//
// When we saw layer D with end_lsn==3000, and had not seen any layers that cross that
// LSN, we knew that that is a valid stopping point. It became the new
// current_best_start_lsn, and the layer's start_lsn (2000) became the new
// candidate_start_lsn. Because layer E overlaps with it, it is not a valid return
// value, but we will not know that until processing layer E.
// ----- end_lsn
// |
// |
// v
//
layers.sort_by_key(|l| l.lsn_range().end);
let mut candidate_start_lsn = end_lsn;
@@ -173,8 +134,7 @@ where
// Is it small enough to be considered part of this level?
if r.end.0 - r.start.0 > lsn_max_size {
// Too large, this layer belongs to next level. Discard the candidate
// we were building and return with `current_best`.
// Too large, this layer belongs to next level. Stop.
trace!(
"too large {}, size {} vs {}",
l.short_id(),

View File

@@ -88,8 +88,6 @@ pub async fn shutdown_pageserver(
) {
use std::time::Duration;
let started_at = std::time::Instant::now();
// If the orderly shutdown below takes too long, we still want to make
// sure that all walredo processes are killed and wait()ed on by us, not systemd.
//
@@ -243,10 +241,7 @@ pub async fn shutdown_pageserver(
walredo_extraordinary_shutdown_thread.join().unwrap();
info!("walredo_extraordinary_shutdown_thread done");
info!(
elapsed_ms = started_at.elapsed().as_millis(),
"Shut down successfully completed"
);
info!("Shut down successfully completed");
std::process::exit(exit_code);
}

View File

@@ -15,11 +15,12 @@ use crate::{aux_file, repository::*};
use anyhow::{ensure, Context};
use bytes::{Buf, Bytes, BytesMut};
use enum_map::Enum;
use itertools::Itertools;
use pageserver_api::key::{
dbdir_key_range, rel_block_to_key, rel_dir_to_key, rel_key_range, rel_size_to_key,
relmap_file_key, repl_origin_key, repl_origin_key_range, slru_block_to_key, slru_dir_to_key,
slru_segment_key_range, slru_segment_size_to_key, twophase_file_key, twophase_key_range,
CompactKey, AUX_FILES_KEY, CHECKPOINT_KEY, CONTROLFILE_KEY, DBDIR_KEY, TWOPHASEDIR_KEY,
AUX_FILES_KEY, CHECKPOINT_KEY, CONTROLFILE_KEY, DBDIR_KEY, TWOPHASEDIR_KEY,
};
use pageserver_api::keyspace::SparseKeySpace;
use pageserver_api::models::AuxFilePolicy;
@@ -36,6 +37,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, info, trace, warn};
use utils::bin_ser::DeserializeError;
use utils::pausable_failpoint;
use utils::vec_map::{VecMap, VecMapOrdering};
use utils::{bin_ser::BeSer, lsn::Lsn};
/// Max delta records appended to the AUX_FILES_KEY (for aux v1). The write path will write a full image once this threshold is reached.
@@ -172,7 +174,6 @@ impl Timeline {
pending_deletions: Vec::new(),
pending_nblocks: 0,
pending_directory_entries: Vec::new(),
pending_bytes: 0,
lsn,
}
}
@@ -1021,33 +1022,21 @@ pub struct DatadirModification<'a> {
// The put-functions add the modifications here, and they are flushed to the
// underlying key-value store by the 'finish' function.
pending_lsns: Vec<Lsn>,
pending_updates: HashMap<Key, Vec<(Lsn, usize, Value)>>,
pending_updates: HashMap<Key, Vec<(Lsn, Value)>>,
pending_deletions: Vec<(Range<Key>, Lsn)>,
pending_nblocks: i64,
/// For special "directory" keys that store key-value maps, track the size of the map
/// if it was updated in this modification.
pending_directory_entries: Vec<(DirectoryKind, usize)>,
/// An **approximation** of how large our EphemeralFile write will be when committed.
pending_bytes: usize,
}
impl<'a> DatadirModification<'a> {
// When a DatadirModification is committed, we do a monolithic serialization of all its contents. WAL records can
// contain multiple pages, so the pageserver's record-based batch size isn't sufficient to bound this allocation: we
// additionally specify a limit on how much payload a DatadirModification may contain before it should be committed.
pub(crate) const MAX_PENDING_BYTES: usize = 8 * 1024 * 1024;
/// Get the current lsn
pub(crate) fn get_lsn(&self) -> Lsn {
self.lsn
}
pub(crate) fn approx_pending_bytes(&self) -> usize {
self.pending_bytes
}
/// Set the current lsn
pub(crate) fn set_lsn(&mut self, lsn: Lsn) -> anyhow::Result<()> {
ensure!(
@@ -1780,25 +1769,21 @@ impl<'a> DatadirModification<'a> {
// Flush relation and SLRU data blocks, keep metadata.
let mut retained_pending_updates = HashMap::<_, Vec<_>>::new();
for (key, values) in self.pending_updates.drain() {
let mut write_batch = Vec::new();
for (lsn, value_ser_size, value) in values {
for (lsn, value) in values {
if key.is_rel_block_key() || key.is_slru_block_key() {
// This bails out on first error without modifying pending_updates.
// That's Ok, cf this function's doc comment.
write_batch.push((key.to_compact(), lsn, value_ser_size, value));
writer.put(key, lsn, &value, ctx).await?;
} else {
retained_pending_updates.entry(key).or_default().push((
lsn,
value_ser_size,
value,
));
retained_pending_updates
.entry(key)
.or_default()
.push((lsn, value));
}
}
writer.put_batch(write_batch, ctx).await?;
}
self.pending_updates = retained_pending_updates;
self.pending_bytes = 0;
if pending_nblocks != 0 {
writer.update_current_logical_size(pending_nblocks * i64::from(BLCKSZ));
@@ -1824,20 +1809,17 @@ impl<'a> DatadirModification<'a> {
self.pending_nblocks = 0;
if !self.pending_updates.is_empty() {
// Ordering: the items in this batch do not need to be in any global order, but values for
// a particular Key must be in Lsn order relative to one another. InMemoryLayer relies on
// this to do efficient updates to its index.
let batch: Vec<(CompactKey, Lsn, usize, Value)> = self
.pending_updates
.drain()
.flat_map(|(key, values)| {
values.into_iter().map(move |(lsn, val_ser_size, value)| {
(key.to_compact(), lsn, val_ser_size, value)
})
})
.collect::<Vec<_>>();
// The put_batch call below expects the inputs to be sorted by Lsn,
// so we do that first.
let lsn_ordered_batch: VecMap<Lsn, (Key, Value)> = VecMap::from_iter(
self.pending_updates
.drain()
.map(|(key, vals)| vals.into_iter().map(move |(lsn, val)| (lsn, (key, val))))
.kmerge_by(|lhs, rhs| lhs.0 < rhs.0),
VecMapOrdering::GreaterOrEqual,
);
writer.put_batch(batch, ctx).await?;
writer.put_batch(lsn_ordered_batch, ctx).await?;
}
if !self.pending_deletions.is_empty() {
@@ -1862,8 +1844,6 @@ impl<'a> DatadirModification<'a> {
writer.update_directory_entries_count(kind, count as u64);
}
self.pending_bytes = 0;
Ok(())
}
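Since each per-key Vec in pending_updates is already LSN-sorted, the kmerge_by above suffices to produce a globally LSN-ordered batch without a full sort. A self-contained sketch of the same merge with itertools, using a plain Vec where the real code builds a utils::vec_map::VecMap (stand-in types, not the pageserver's):

```rust
use itertools::Itertools;

type Lsn = u64;

// Merge per-key, LSN-sorted update runs into one globally LSN-ordered batch,
// mirroring the kmerge_by call in DatadirModification::commit above.
fn lsn_ordered(
    pending: Vec<(&'static str, Vec<(Lsn, &'static str)>)>,
) -> Vec<(Lsn, (&'static str, &'static str))> {
    pending
        .into_iter()
        .map(|(key, vals)| vals.into_iter().map(move |(lsn, val)| (lsn, (key, val))))
        .kmerge_by(|lhs, rhs| lhs.0 < rhs.0) // earlier LSN goes first
        .collect()
}

fn main() {
    let merged = lsn_ordered(vec![
        ("rel_a", vec![(10, "img1"), (30, "img2")]),
        ("rel_b", vec![(20, "wal1")]),
    ]);
    assert_eq!(merged.iter().map(|e| e.0).collect::<Vec<_>>(), vec![10, 20, 30]);
}
```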
@@ -1880,7 +1860,7 @@ impl<'a> DatadirModification<'a> {
// Note: we don't check pending_deletions. It is an error to request a
// value that has been removed, deletion only avoids leaking storage.
if let Some(values) = self.pending_updates.get(&key) {
if let Some((_, _, value)) = values.last() {
if let Some((_, value)) = values.last() {
return if let Value::Image(img) = value {
Ok(img.clone())
} else {
@@ -1908,17 +1888,13 @@ impl<'a> DatadirModification<'a> {
fn put(&mut self, key: Key, val: Value) {
let values = self.pending_updates.entry(key).or_default();
// Replace the previous value if it exists at the same lsn
if let Some((last_lsn, last_value_ser_size, last_value)) = values.last_mut() {
if let Some((last_lsn, last_value)) = values.last_mut() {
if *last_lsn == self.lsn {
*last_value_ser_size = val.serialized_size().unwrap() as usize;
*last_value = val;
return;
}
}
let val_serialized_size = val.serialized_size().unwrap() as usize;
self.pending_bytes += val_serialized_size;
values.push((self.lsn, val_serialized_size, val));
values.push((self.lsn, val));
}
fn delete(&mut self, key_range: Range<Key>) {

View File

@@ -79,8 +79,6 @@ impl EphemeralFile {
self.rw.read_blk(blknum, ctx).await
}
#[cfg(test)]
// This is a test helper: outside of tests, we are always written to via a pre-serialized batch.
pub(crate) async fn write_blob(
&mut self,
srcbuf: &[u8],
@@ -88,30 +86,17 @@ impl EphemeralFile {
) -> Result<u64, io::Error> {
let pos = self.rw.bytes_written();
let mut len_bytes = std::io::Cursor::new(Vec::new());
crate::tenant::storage_layer::inmemory_layer::SerializedBatch::write_blob_length(
srcbuf.len(),
&mut len_bytes,
);
let len_bytes = len_bytes.into_inner();
// Write the length field
self.rw.write_all_borrowed(&len_bytes, ctx).await?;
if srcbuf.len() < 0x80 {
// short one-byte length header
let len_buf = [srcbuf.len() as u8];
// Write the payload
self.rw.write_all_borrowed(srcbuf, ctx).await?;
Ok(pos)
}
/// Returns the offset at which the first byte of the input was written, for use
/// in constructing indices over the written value.
pub(crate) async fn write_raw(
&mut self,
srcbuf: &[u8],
ctx: &RequestContext,
) -> Result<u64, io::Error> {
let pos = self.rw.bytes_written();
self.rw.write_all_borrowed(&len_buf, ctx).await?;
} else {
let mut len_buf = u32::to_be_bytes(srcbuf.len() as u32);
len_buf[0] |= 0x80;
self.rw.write_all_borrowed(&len_buf, ctx).await?;
}
// Write the payload
self.rw.write_all_borrowed(srcbuf, ctx).await?;
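The restored write_blob encodes a variable-width length header: lengths below 0x80 fit in a single byte, and anything larger takes four big-endian bytes with the top bit of the first byte set as a marker. A runnable round-trip sketch of that format (function names are mine, not the pageserver's):

```rust
// Encode a blob length in the EphemeralFile header format shown above.
fn encode_len(len: usize) -> Vec<u8> {
    if len < 0x80 {
        vec![len as u8] // short one-byte length header
    } else {
        let mut buf = u32::to_be_bytes(len as u32);
        buf[0] |= 0x80; // mark as a four-byte header
        buf.to_vec()
    }
}

// Decode, returning (blob length, header bytes consumed).
fn decode_len(buf: &[u8]) -> (usize, usize) {
    if buf[0] & 0x80 == 0 {
        (buf[0] as usize, 1)
    } else {
        let mut be = [0u8; 4];
        be.copy_from_slice(&buf[..4]);
        be[0] &= !0x80;
        (u32::from_be_bytes(be) as usize, 4)
    }
}

fn main() {
    assert_eq!(decode_len(&encode_len(0x42)), (0x42, 1));
    assert_eq!(decode_len(&encode_len(8192)), (8192, 4));
}
```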

View File

@@ -2,7 +2,7 @@
pub mod delta_layer;
pub mod image_layer;
pub mod inmemory_layer;
pub(crate) mod inmemory_layer;
pub(crate) mod layer;
mod layer_desc;
mod layer_name;

View File

@@ -33,7 +33,7 @@ use std::fmt::Write;
use std::ops::Range;
use std::sync::atomic::Ordering as AtomicOrdering;
use std::sync::atomic::{AtomicU64, AtomicUsize};
use tokio::sync::RwLock;
use tokio::sync::{RwLock, RwLockWriteGuard};
use super::{
DeltaLayerWriter, PersistentLayerDesc, ValueReconstructSituation, ValuesReconstructState,
@@ -320,82 +320,6 @@ impl InMemoryLayer {
}
}
/// Offset of a particular Value within a serialized batch.
struct SerializedBatchOffset {
key: CompactKey,
lsn: Lsn,
/// offset in bytes from the start of the batch's buffer to the Value's serialized size header.
offset: u64,
}
pub struct SerializedBatch {
/// Blobs serialized in EphemeralFile's native format, ready for passing to [`EphemeralFile::write_raw`].
pub(crate) raw: Vec<u8>,
/// Index of values in [`Self::raw`], using offsets relative to the start of the buffer.
offsets: Vec<SerializedBatchOffset>,
/// The highest LSN of any value in the batch
pub(crate) max_lsn: Lsn,
}
impl SerializedBatch {
/// Write a blob length in the internal format of the EphemeralFile
pub(crate) fn write_blob_length(len: usize, cursor: &mut std::io::Cursor<Vec<u8>>) {
use std::io::Write;
if len < 0x80 {
// short one-byte length header
let len_buf = [len as u8];
cursor
.write_all(&len_buf)
.expect("Writing to Vec is infallible");
} else {
let mut len_buf = u32::to_be_bytes(len as u32);
len_buf[0] |= 0x80;
cursor
.write_all(&len_buf)
.expect("Writing to Vec is infallible");
}
}
pub fn from_values(batch: Vec<(CompactKey, Lsn, usize, Value)>) -> Self {
// Pre-allocate a big flat buffer to write into. This should be large but not huge: it is soft-limited in practice by
// [`crate::pgdatadir_mapping::DatadirModification::MAX_PENDING_BYTES`]
let buffer_size = batch.iter().map(|i| i.2).sum::<usize>() + 4 * batch.len();
let mut cursor = std::io::Cursor::new(Vec::<u8>::with_capacity(buffer_size));
let mut offsets: Vec<SerializedBatchOffset> = Vec::with_capacity(batch.len());
let mut max_lsn: Lsn = Lsn(0);
for (key, lsn, val_ser_size, val) in batch {
let relative_off = cursor.position();
Self::write_blob_length(val_ser_size, &mut cursor);
val.ser_into(&mut cursor)
.expect("Writing into in-memory buffer is infallible");
offsets.push(SerializedBatchOffset {
key,
lsn,
offset: relative_off,
});
max_lsn = std::cmp::max(max_lsn, lsn);
}
let buffer = cursor.into_inner();
// Assert that we didn't do any extra allocations while building buffer.
debug_assert!(buffer.len() <= buffer_size);
Self {
raw: buffer,
offsets,
max_lsn,
}
}
}
fn inmem_layer_display(mut f: impl Write, start_lsn: Lsn, end_lsn: Lsn) -> std::fmt::Result {
write!(f, "inmem-{:016X}-{:016X}", start_lsn.0, end_lsn.0)
}
@@ -456,20 +380,37 @@ impl InMemoryLayer {
})
}
// Write path.
pub async fn put_batch(
// Write operations
/// Common subroutine of the public put_wal_record() and put_page_image() functions.
/// Adds the page version to the in-memory tree
pub async fn put_value(
&self,
serialized_batch: SerializedBatch,
key: CompactKey,
lsn: Lsn,
buf: &[u8],
ctx: &RequestContext,
) -> Result<()> {
let mut inner = self.inner.write().await;
self.assert_writable();
self.put_value_locked(&mut inner, key, lsn, buf, ctx).await
}
let base_off = {
inner
async fn put_value_locked(
&self,
locked_inner: &mut RwLockWriteGuard<'_, InMemoryLayerInner>,
key: CompactKey,
lsn: Lsn,
buf: &[u8],
ctx: &RequestContext,
) -> Result<()> {
trace!("put_value key {} at {}/{}", key, self.timeline_id, lsn);
let off = {
locked_inner
.file
.write_raw(
&serialized_batch.raw,
.write_blob(
buf,
&RequestContextBuilder::extend(ctx)
.page_content_kind(PageContentKind::InMemoryLayer)
.build(),
@@ -477,23 +418,15 @@ impl InMemoryLayer {
.await?
};
for SerializedBatchOffset {
key,
lsn,
offset: relative_off,
} in serialized_batch.offsets
{
let off = base_off + relative_off;
let vec_map = inner.index.entry(key).or_default();
let old = vec_map.append_or_update_last(lsn, off).unwrap().0;
if old.is_some() {
// We already had an entry for this LSN. That's odd..
warn!("Key {} at {} already exists", key, lsn);
}
let vec_map = locked_inner.index.entry(key).or_default();
let old = vec_map.append_or_update_last(lsn, off).unwrap().0;
if old.is_some() {
// We already had an entry for this LSN. That's odd..
warn!("Key {} at {} already exists", key, lsn);
}
let size = inner.file.len();
inner.resource_units.maybe_publish_size(size);
let size = locked_inner.file.len();
locked_inner.resource_units.maybe_publish_size(size);
Ok(())
}

View File

@@ -22,8 +22,8 @@ use handle::ShardTimelineId;
use once_cell::sync::Lazy;
use pageserver_api::{
key::{
CompactKey, KEY_SIZE, METADATA_KEY_BEGIN_PREFIX, METADATA_KEY_END_PREFIX,
NON_INHERITED_RANGE, NON_INHERITED_SPARSE_RANGE,
KEY_SIZE, METADATA_KEY_BEGIN_PREFIX, METADATA_KEY_END_PREFIX, NON_INHERITED_RANGE,
NON_INHERITED_SPARSE_RANGE,
},
keyspace::{KeySpaceAccum, KeySpaceRandomAccum, SparseKeyPartitioning},
models::{
@@ -44,8 +44,10 @@ use tokio::{
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::{
bin_ser::BeSer,
fs_ext, pausable_failpoint,
sync::gate::{Gate, GateGuard},
vec_map::VecMap,
};
use std::pin::pin;
@@ -135,10 +137,7 @@ use self::layer_manager::LayerManager;
use self::logical_size::LogicalSize;
use self::walreceiver::{WalReceiver, WalReceiverConf};
use super::{
config::TenantConf, storage_layer::inmemory_layer, storage_layer::LayerVisibilityHint,
upload_queue::NotInitialized,
};
use super::{config::TenantConf, storage_layer::LayerVisibilityHint, upload_queue::NotInitialized};
use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf};
use super::{remote_timeline_client::index::IndexPart, storage_layer::LayerFringe};
use super::{
@@ -3590,6 +3589,34 @@ impl Timeline {
return Err(FlushLayerError::Cancelled);
}
// FIXME(auxfilesv2): supporting multiple metadata key partitions might need initdb support as well?
// This code path will not be hit during regression tests. After #7099 we have a single partition
// with two key ranges. If someone wants to fix initdb optimization in the future, this might need
// to be fixed.
// For metadata, always create delta layers.
let delta_layer = if !metadata_partition.parts.is_empty() {
assert_eq!(
metadata_partition.parts.len(),
1,
"currently sparse keyspace should only contain a single metadata keyspace"
);
let metadata_keyspace = &metadata_partition.parts[0];
self.create_delta_layer(
&frozen_layer,
Some(
metadata_keyspace.0.ranges.first().unwrap().start
..metadata_keyspace.0.ranges.last().unwrap().end,
),
ctx,
)
.await
.map_err(|e| FlushLayerError::from_anyhow(self, e))?
} else {
None
};
// For image layers, we add them immediately into the layer map.
let mut layers_to_upload = Vec::new();
layers_to_upload.extend(
self.create_image_layers(
@@ -3600,27 +3627,13 @@ impl Timeline {
)
.await?,
);
if !metadata_partition.parts.is_empty() {
assert_eq!(
metadata_partition.parts.len(),
1,
"currently sparse keyspace should only contain a single metadata keyspace"
);
layers_to_upload.extend(
self.create_image_layers(
// Safety: create_image_layers treats sparse keyspaces differently in that it does not scan
// every single key within the keyspace, and therefore, it's safe to force converting it
// into a dense keyspace before calling this function.
&metadata_partition.into_dense(),
self.initdb_lsn,
ImageLayerCreationMode::Initial,
ctx,
)
.await?,
);
}
(layers_to_upload, None)
if let Some(delta_layer) = delta_layer {
layers_to_upload.push(delta_layer.clone());
(layers_to_upload, Some(delta_layer))
} else {
(layers_to_upload, None)
}
} else {
// Normal case, write out a L0 delta layer file.
// `create_delta_layer` will not modify the layer map.
@@ -4030,6 +4043,8 @@ impl Timeline {
mode: ImageLayerCreationMode,
start: Key,
) -> Result<ImageLayerCreationOutcome, CreateImageLayersError> {
assert!(!matches!(mode, ImageLayerCreationMode::Initial));
// Metadata keys image layer creation.
let mut reconstruct_state = ValuesReconstructState::default();
let data = self
@@ -4195,13 +4210,15 @@ impl Timeline {
"metadata keys must be partitioned separately"
);
}
if mode == ImageLayerCreationMode::Initial {
return Err(CreateImageLayersError::Other(anyhow::anyhow!("no image layer should be created for metadata keys when flushing frozen layers")));
}
if mode == ImageLayerCreationMode::Try && !check_for_image_layers {
// Skip compaction if there are not enough updates. Metadata compaction will do a scan and
// might mess up with evictions.
start = img_range.end;
continue;
}
// For initial and force modes, we always generate image layers for metadata keys.
} else if let ImageLayerCreationMode::Try = mode {
// check_for_image_layers = false -> skip
// check_for_image_layers = true -> check time_for_new_image_layer -> skip/generate
@@ -4209,8 +4226,7 @@ impl Timeline {
start = img_range.end;
continue;
}
}
if let ImageLayerCreationMode::Force = mode {
} else if let ImageLayerCreationMode::Force = mode {
// When forced to create image layers, we might try and create them where they already
// exist. This mode is only used in tests/debug.
let layers = self.layers.read().await;
@@ -4224,7 +4240,6 @@ impl Timeline {
img_range.start,
img_range.end
);
start = img_range.end;
continue;
}
}
@@ -5575,6 +5590,44 @@ enum OpenLayerAction {
}
impl<'a> TimelineWriter<'a> {
/// Put a new page version that can be constructed from a WAL record
///
/// This will implicitly extend the relation, if the page is beyond the
/// current end-of-file.
pub(crate) async fn put(
&mut self,
key: Key,
lsn: Lsn,
value: &Value,
ctx: &RequestContext,
) -> anyhow::Result<()> {
// Avoid doing allocations for "small" values.
// In the regression test suite, the limit of 256 avoided allocations in 95% of cases:
// https://github.com/neondatabase/neon/pull/5056#discussion_r1301975061
let mut buf = smallvec::SmallVec::<[u8; 256]>::new();
value.ser_into(&mut buf)?;
let buf_size: u64 = buf.len().try_into().expect("oversized value buf");
let action = self.get_open_layer_action(lsn, buf_size);
let layer = self.handle_open_layer_action(lsn, action, ctx).await?;
let res = layer.put_value(key.to_compact(), lsn, &buf, ctx).await;
if res.is_ok() {
// Update the current size only when the entire write was ok.
// In case of failures, we may have had partial writes which
// render the size tracking out of sync. That's ok because
// the checkpoint distance should be significantly smaller
// than the S3 single shot upload limit of 5GiB.
let state = self.write_guard.as_mut().unwrap();
state.current_size += buf_size;
state.prev_lsn = Some(lsn);
state.max_lsn = std::cmp::max(state.max_lsn, Some(lsn));
}
res
}
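The allocation claim in the comment above is easy to see with smallvec directly: the buffer lives inline until it outgrows 256 bytes (ser_into can write into it because the workspace enables smallvec's "write" feature, as the Cargo.toml diff earlier shows). A minimal sketch:

```rust
use smallvec::SmallVec;

fn main() {
    // Small serialized values stay in the inline 256-byte buffer: no heap allocation.
    let mut small: SmallVec<[u8; 256]> = SmallVec::new();
    small.extend_from_slice(&[0u8; 100]);
    assert!(!small.spilled());

    // A full 8 KiB page image spills to a heap allocation.
    let mut page: SmallVec<[u8; 256]> = SmallVec::new();
    page.extend_from_slice(&[0u8; 8192]);
    assert!(page.spilled());
}
```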
async fn handle_open_layer_action(
&mut self,
at: Lsn,
@@ -5680,58 +5733,18 @@ impl<'a> TimelineWriter<'a> {
}
/// Put a batch of keys at the specified Lsns.
///
/// The batch is sorted by Lsn (enforced by usage of [`utils::vec_map::VecMap`]).
pub(crate) async fn put_batch(
&mut self,
batch: Vec<(CompactKey, Lsn, usize, Value)>,
batch: VecMap<Lsn, (Key, Value)>,
ctx: &RequestContext,
) -> anyhow::Result<()> {
if batch.is_empty() {
return Ok(());
for (lsn, (key, val)) in batch {
self.put(key, lsn, &val, ctx).await?
}
let serialized_batch = inmemory_layer::SerializedBatch::from_values(batch);
let batch_max_lsn = serialized_batch.max_lsn;
let buf_size: u64 = serialized_batch.raw.len() as u64;
let action = self.get_open_layer_action(batch_max_lsn, buf_size);
let layer = self
.handle_open_layer_action(batch_max_lsn, action, ctx)
.await?;
let res = layer.put_batch(serialized_batch, ctx).await;
if res.is_ok() {
// Update the current size only when the entire write was ok.
// In case of failures, we may have had partial writes which
// render the size tracking out of sync. That's ok because
// the checkpoint distance should be significantly smaller
// than the S3 single shot upload limit of 5GiB.
let state = self.write_guard.as_mut().unwrap();
state.current_size += buf_size;
state.prev_lsn = Some(batch_max_lsn);
state.max_lsn = std::cmp::max(state.max_lsn, Some(batch_max_lsn));
}
res
}
#[cfg(test)]
/// Test helper, for tests that would like to poke individual values without composing a batch
pub(crate) async fn put(
&mut self,
key: Key,
lsn: Lsn,
value: &Value,
ctx: &RequestContext,
) -> anyhow::Result<()> {
use utils::bin_ser::BeSer;
let val_ser_size = value.serialized_size().unwrap() as usize;
self.put_batch(
vec![(key.to_compact(), lsn, val_ser_size, value.clone())],
ctx,
)
.await
Ok(())
}
pub(crate) async fn delete_batch(

View File

@@ -27,8 +27,8 @@ use super::TaskStateUpdate;
use crate::{
context::RequestContext,
metrics::{LIVE_CONNECTIONS, WALRECEIVER_STARTED_CONNECTIONS, WAL_INGEST},
pgdatadir_mapping::DatadirModification,
task_mgr::{TaskKind, WALRECEIVER_RUNTIME},
task_mgr::TaskKind,
task_mgr::WALRECEIVER_RUNTIME,
tenant::{debug_assert_current_span_has_tenant_and_timeline_id, Timeline, WalReceiverInfo},
walingest::WalIngest,
walrecord::DecodedWALRecord,
@@ -345,10 +345,7 @@ pub(super) async fn handle_walreceiver_connection(
// Commit every ingest_batch_size records. Even if we filtered out
// all records, we still need to call commit to advance the LSN.
uncommitted_records += 1;
if uncommitted_records >= ingest_batch_size
|| modification.approx_pending_bytes()
> DatadirModification::MAX_PENDING_BYTES
{
if uncommitted_records >= ingest_batch_size {
WAL_INGEST
.records_committed
.inc_by(uncommitted_records - filtered_records);

View File

@@ -114,6 +114,9 @@ rsa = "0.9"
workspace_hack.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
ktls = "6"
[dev-dependencies]
camino-tempfile.workspace = true
fallible-iterator.workspace = true

View File

@@ -4,6 +4,7 @@ pub mod jwt;
mod link;
use std::net::IpAddr;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use std::time::Duration;
@@ -23,6 +24,7 @@ use crate::context::RequestMonitoring;
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::handshake::KtlsAsyncReadReady;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::stream::Stream;
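This diff only shows the KtlsAsyncReadReady bound being threaded through these signatures; its definition in proxy::handshake is not part of the hunks shown. Judging by the "asyncreadready" commit and tokio's own poll_read_ready hook, a plausible shape is the trait below (the method name and semantics are assumptions, not the repo's confirmed API). The DummyClient wrapper used in the tests further down would then exist to give tokio's in-memory duplex streams stub AsRawFd and KtlsAsyncReadReady impls:

```rust
use std::io;
use std::task::{Context, Poll};

// Hypothetical sketch; the real definition lives in proxy::handshake
// and may differ.
pub trait KtlsAsyncReadReady {
    /// Resolve once the (kernel-TLS) socket has data ready to read,
    /// mirroring tokio::net::TcpStream::poll_read_ready.
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
```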
@@ -274,7 +276,9 @@ async fn auth_quirks(
ctx: &RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -358,7 +362,9 @@ async fn authenticate_with_secret(
ctx: &RequestMonitoring,
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
@@ -417,7 +423,9 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
pub async fn authenticate(
self,
ctx: &RequestMonitoring,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -542,7 +550,7 @@ mod tests {
CachedNodeInfo,
},
context::RequestMonitoring,
proxy::NeonOptions,
proxy::{tests::DummyClient, NeonOptions},
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
@@ -650,7 +658,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {
@@ -727,7 +735,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {
@@ -779,7 +787,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(DummyClient(server)));
let ctx = RequestMonitoring::test();
let api = Auth {

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use super::{ComputeCredentials, ComputeUserInfo};
use crate::{
auth::{self, backend::ComputeCredentialKeys, AuthFlow},
@@ -5,6 +7,7 @@ use crate::{
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
proxy::handshake::KtlsAsyncReadReady,
sasl,
stream::{PqStream, Stream},
};
@@ -14,7 +17,9 @@ use tracing::{info, warn};
pub(super) async fn authenticate(
ctx: &RequestMonitoring,
creds: ComputeUserInfo,
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
};
@@ -7,6 +9,7 @@ use crate::{
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
proxy::handshake::KtlsAsyncReadReady,
sasl,
stream::{self, Stream},
};
@@ -20,7 +23,9 @@ use tracing::{info, warn};
pub async fn authenticate_cleartext(
ctx: &RequestMonitoring,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
@@ -62,7 +67,9 @@ pub async fn authenticate_cleartext(
pub async fn password_hack_no_authentication(
ctx: &RequestMonitoring,
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
client: &mut stream::PqStream<
Stream<impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>,
>,
) -> auth::Result<ComputeCredentials> {
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);

View File

@@ -86,8 +86,7 @@ impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
ctx: &RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
common_names: Option<&HashSet<String>>,
endpoint_from_domain: Option<EndpointId>,
) -> Result<Self, ComputeUserInfoParseError> {
// Some parameters are stored in the startup message.
let get_param = |key| {
@@ -111,16 +110,7 @@ impl ComputeUserInfoMaybeEndpoint {
})
.map(|name| name.into());
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
endpoint_sni(sni_str, cn)?
} else {
None
}
} else {
None
};
let is_sni = endpoint_from_domain.is_some();
let endpoint = match (endpoint_option, endpoint_from_domain) {
// Invariant: if we have both project name variants, they should match.
(Some(option), Some(domain)) if option != domain => {
@@ -143,7 +133,7 @@ impl ComputeUserInfoMaybeEndpoint {
let metrics = Metrics::get();
info!(%user, "credentials");
if sni.is_some() {
if is_sni {
info!("Connection with sni");
metrics.proxy.accepted_connections_by_sni.inc(SniKind::Sni);
} else if endpoint.is_some() {
@@ -255,7 +245,7 @@ mod tests {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -270,7 +260,7 @@ mod tests {
("foo", "bar"), // should be ignored
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id, None);
@@ -281,12 +271,8 @@ mod tests {
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("foo.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("foo".into()))?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
@@ -302,7 +288,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -317,7 +303,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
@@ -335,7 +321,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -350,7 +336,7 @@ mod tests {
]);
let ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None, None)?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.endpoint_id.is_none());
@@ -361,49 +347,21 @@ mod tests {
fn parse_projects_identical() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
let sni = Some("baz.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("baz".into()))?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
Ok(())
}
#[test]
fn parse_multi_common_names() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
Ok(())
}
#[test]
fn parse_projects_different() {
let options =
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
let sni = Some("second.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("second".into()))
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
@@ -414,24 +372,6 @@ mod tests {
}
}
#[test]
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());
let ctx = RequestMonitoring::test();
let err = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
}
_ => panic!("bad error: {err:?}"),
}
}
#[test]
fn parse_neon_options() -> anyhow::Result<()> {
let options = StartupMessageParams::new([
@@ -439,11 +379,9 @@ mod tests {
("options", "neon_lsn:0/2 neon_endpoint_type:read_write"),
]);
let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, sni, common_names.as_ref())?;
ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("project".into()))?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),

View File

@@ -6,13 +6,14 @@ use crate::{
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
proxy::handshake::KtlsAsyncReadReady,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::{io, sync::Arc};
use std::{io, os::fd::AsRawFd, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -70,7 +71,7 @@ impl AuthMethod for CleartextPassword {
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, S, State> {
pub struct AuthFlow<'a, S: AsRawFd, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream<S>>,
/// State might contain ancillary data (see [`Self::begin`]).
@@ -79,7 +80,7 @@ pub struct AuthFlow<'a, S, State> {
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
impl<'a, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'a, S, Begin> {
/// Create a new wrapper for client authentication.
pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
@@ -105,7 +106,9 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>
AuthFlow<'_, S, PasswordHack>
{
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
let msg = self.stream.read_password_message().await?;
@@ -124,7 +127,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>
AuthFlow<'_, S, CleartextPassword>
{
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
let msg = self.stream.read_password_message().await?;
@@ -149,7 +154,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;

View File

@@ -1,3 +1,4 @@
use std::os::fd::AsRawFd;
/// A stand-alone program that routes connections, e.g. from
/// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
///
@@ -7,9 +8,9 @@ use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::proxy::handshake::KtlsAsyncReadReady;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;
@@ -20,6 +21,7 @@ use futures::TryFutureExt;
use proxy::stream::{PqStream, Stream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
use utils::{project_git_version, sentry_init::init_sentry};
@@ -72,7 +74,7 @@ async fn main() -> anyhow::Result<()> {
let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
// Configure TLS
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
let tls_config = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -102,19 +104,14 @@ async fn main() -> anyhow::Result<()> {
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?
.into();
(tls_config, tls_server_end_point)
Arc::new(
rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_single_cert(cert_chain, key)?,
)
}
_ => bail!("tls-key and tls-cert must be specified"),
};
@@ -129,7 +126,6 @@ async fn main() -> anyhow::Result<()> {
let main = tokio::spawn(task_main(
Arc::new(destination),
tls_config,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
));
@@ -151,7 +147,6 @@ async fn main() -> anyhow::Result<()> {
async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -183,7 +178,7 @@ async fn task_main(
proxy::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
handle_client(ctx, dest_suffix, tls_config, socket).await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -204,12 +199,11 @@ async fn task_main(
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>(
ctx: &RequestMonitoring,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
) -> anyhow::Result<Box<TlsStream<S>>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
let msg = stream.read_startup_packet().await?;
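
Note: the hunks below drop the `Stream::Tls` wrapper and stop threading `tls_server_end_point` through the SNI router — it performs no SCRAM channel binding, so it can return the `tokio_rustls` stream directly and read the server name straight off the rustls connection. A minimal sketch of that lookup (hedged; `sni_of` is an illustrative helper, not part of this codebase):

```rust
use tokio_rustls::server::TlsStream;

/// Read the SNI server name negotiated during the TLS handshake.
/// `get_ref()` yields `(&IO, &rustls::ServerConnection)`; the name is
/// only `Some(_)` if the client actually sent an SNI extension.
fn sni_of<S>(stream: &TlsStream<S>) -> Option<&str> {
    stream.get_ref().1.server_name()
}
```
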
@@ -235,13 +229,10 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
Ok(Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
))
}
unexpected => {
info!(
@@ -259,15 +250,18 @@ async fn handle_client(
ctx: RequestMonitoring,
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
stream: impl AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
// Cut off the first part of the SNI domain
// We receive the required destination details in the format of
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
let sni = tls_stream
.get_ref()
.1
.server_name()
.ok_or(anyhow!("SNI missing"))?;
let dest: Vec<&str> = sni
.split_once('.')
.context("invalid SNI")?

View File

@@ -285,7 +285,7 @@ async fn main() -> anyhow::Result<()> {
};
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let config = build_config(&args).await?;
info!("Authentication backend: {}", config.auth_backend);
info!("Using region: {}", args.aws_region);
@@ -529,16 +529,14 @@ async fn main() -> anyhow::Result<()> {
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
async fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
cert_path,
args.certs_dir.as_ref(),
)?),
(Some(key_path), Some(cert_path)) => {
Some(config::configure_tls(key_path, cert_path, args.certs_dir.as_ref()).await?)
}
(None, None) => None,
_ => bail!("either both or neither tls-key and tls-cert must be specified"),
};

View File

@@ -10,7 +10,7 @@ use anyhow::{bail, ensure, Context, Ok};
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::{
crypto::ring::sign,
crypto::aws_lc_rs::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use sha2::{Digest, Sha256};
@@ -76,7 +76,7 @@ impl TlsConfig {
pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql";
/// Configure TLS for the main endpoint.
pub fn configure_tls(
pub async fn configure_tls(
key_path: &str,
cert_path: &str,
certs_dir: Option<&String>,
@@ -110,13 +110,20 @@ pub fn configure_tls(
let cert_resolver = Arc::new(cert_resolver);
let provider = rustls::crypto::aws_lc_rs::default_provider();
#[cfg(target_os = "linux")]
let provider = {
let mut provider = provider;
let compat = ktls::CompatibleCiphers::new().await?;
provider.cipher_suites.retain(|s| compat.is_compatible(*s));
provider
};
// allow TLS 1.2 to be compatible with older client libraries
let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS13,
&rustls::version::TLS12,
])
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
let mut config = rustls::ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_cert_resolver(cert_resolver.clone());
config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()];
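
Note: filtering the provider's cipher suites against `ktls::CompatibleCiphers` is what prevents rustls from negotiating a suite the kernel cannot offload. The same idea as a standalone helper (a sketch; assumes the `ktls` probe API used above):

```rust
use std::sync::Arc;

/// Build a crypto provider restricted to kernel-TLS-capable cipher suites.
/// On non-Linux targets the probe is skipped and the stock provider is used.
async fn ktls_capable_provider() -> anyhow::Result<Arc<rustls::crypto::CryptoProvider>> {
    let mut provider = rustls::crypto::aws_lc_rs::default_provider();
    #[cfg(target_os = "linux")]
    {
        let compat = ktls::CompatibleCiphers::new().await?;
        provider.cipher_suites.retain(|s| compat.is_compatible(*s));
    }
    Ok(Arc::new(provider))
}
```
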

View File

@@ -1,92 +1,4 @@
// rustc lints/lint groups
// https://doc.rust-lang.org/rustc/lints/groups.html
#![deny(
deprecated,
future_incompatible,
// TODO: consider let_underscore
nonstandard_style,
rust_2024_compatibility
)]
#![warn(clippy::all, clippy::pedantic, clippy::cargo)]
// List of denied lints from the clippy::restriction group.
// https://rust-lang.github.io/rust-clippy/master/index.html#?groups=restriction
#![warn(
clippy::undocumented_unsafe_blocks,
clippy::dbg_macro,
clippy::empty_enum_variants_with_brackets,
clippy::exit,
clippy::float_cmp_const,
clippy::lossy_float_literal,
clippy::macro_use_imports,
clippy::manual_ok_or,
// TODO: consider clippy::map_err_ignore
// TODO: consider clippy::mem_forget
clippy::rc_mutex,
clippy::rest_pat_in_fully_bound_structs,
clippy::string_add,
clippy::string_to_string,
clippy::todo,
// TODO: consider clippy::unimplemented
// TODO: consider clippy::unwrap_used
)]
// List of permanently allowed lints.
#![allow(
// It's ok to cast u8 to bool, etc.
clippy::cast_lossless,
)]
// List of temporarily allowed lints.
// TODO: Switch to expect() once stable with 1.81.
// TODO: fix code and reduce list or move to permanent list above.
#![allow(
clippy::cargo_common_metadata,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_precision_loss,
clippy::cast_sign_loss,
clippy::default_trait_access,
clippy::doc_markdown,
clippy::explicit_iter_loop,
clippy::float_cmp,
clippy::if_not_else,
clippy::ignored_unit_patterns,
clippy::implicit_hasher,
clippy::inconsistent_struct_constructor,
clippy::inline_always,
clippy::items_after_statements,
clippy::manual_assert,
clippy::manual_let_else,
clippy::manual_string_new,
clippy::match_bool,
clippy::match_same_arms,
clippy::match_wild_err_arm,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::multiple_crate_versions,
clippy::must_use_candidate,
clippy::needless_for_each,
clippy::needless_pass_by_value,
clippy::needless_raw_string_hashes,
clippy::option_as_ref_cloned,
clippy::redundant_closure_for_method_calls,
clippy::redundant_else,
clippy::return_self_not_must_use,
clippy::similar_names,
clippy::single_char_pattern,
clippy::single_match_else,
clippy::struct_excessive_bools,
clippy::struct_field_names,
clippy::too_many_lines,
clippy::uninlined_format_args,
clippy::unnested_or_patterns,
clippy::unreadable_literal,
clippy::unused_async,
clippy::unused_self,
clippy::used_underscore_binding,
clippy::wildcard_imports
)]
// List of temporarily allowed lints to unblock beta/nightly.
#![allow(unknown_lints, clippy::manual_inspect)]
#![deny(clippy::undocumented_unsafe_blocks)]
use std::convert::Infallible;

View File

@@ -3,6 +3,7 @@
use std::{
io,
net::SocketAddr,
os::fd::AsRawFd,
pin::Pin,
task::{Context, Poll},
};
@@ -20,6 +21,23 @@ pin_project! {
}
}
impl<S: AsRawFd> AsRawFd for ChainRW<S> {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
self.inner.as_raw_fd()
}
}
#[cfg(all(target_os = "linux", not(test)))]
impl<S: ktls::AsyncReadReady> ktls::AsyncReadReady for ChainRW<S> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if self.buf.is_empty() {
self.inner.poll_read_ready(cx)
} else {
Poll::Ready(Ok(()))
}
}
}
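
Note: `ChainRW` replays bytes captured while parsing the PROXY protocol header before touching the inner socket, so readiness must follow the same order — buffered data is always ready, and only an empty buffer defers to the kernel. The wrapper pattern in isolation (a sketch; `Prefixed` is a hypothetical stand-in for `ChainRW`):

```rust
use std::io;
use std::os::fd::{AsRawFd, RawFd};
use std::task::{Context, Poll};

/// A reader that drains `buf` before reading from `inner`.
struct Prefixed<S> {
    buf: Vec<u8>,
    inner: S,
}

/// kTLS needs the *socket's* fd, so wrappers must delegate all the way down.
impl<S: AsRawFd> AsRawFd for Prefixed<S> {
    fn as_raw_fd(&self) -> RawFd {
        self.inner.as_raw_fd()
    }
}

#[cfg(target_os = "linux")]
impl<S: ktls::AsyncReadReady> ktls::AsyncReadReady for Prefixed<S> {
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.buf.is_empty() {
            self.inner.poll_read_ready(cx) // nothing replayable left; ask the socket
        } else {
            Poll::Ready(Ok(())) // buffered bytes can satisfy a read right now
        }
    }
}
```
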
impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
#[inline]
fn poll_write(

View File

@@ -1,5 +1,5 @@
#[cfg(test)]
mod tests;
pub mod tests;
pub mod connect_compute;
mod copy_bidirectional;
@@ -9,6 +9,7 @@ pub mod retry;
pub mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use handshake::KtlsAsyncReadReady;
use crate::{
auth,
@@ -21,7 +22,7 @@ use crate::{
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
stream::PqStream,
EndpointCacheKey,
};
use futures::TryFutureExt;
@@ -30,6 +31,7 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::os::fd::AsRawFd;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -191,13 +193,6 @@ impl ClientMode {
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
@@ -238,7 +233,7 @@ impl ReportableError for ClientRequestError {
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady>(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -261,9 +256,9 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) =
let (mut stream, ep, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(stream, ep, params) => (stream, ep, params),
HandshakeData::Cancel(cancel_key_data) => {
return Ok(cancellation_handler
.cancel_session(cancel_key_data, ctx.session_id())
@@ -275,15 +270,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, ep))
.transpose();
let user_info = match result {
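
Note: with the endpoint already resolved during the handshake, credential parsing takes an `Option<EndpointId>` instead of the old `(sni, common_names)` pair. A hedged call-site sketch mirroring the updated tests:

```rust
// Illustrative test body; mirrors the updated call sites in this diff.
let options = StartupMessageParams::new([("user", "john_doe")]);
let ctx = RequestMonitoring::test();
// The Option<EndpointId> argument replaces the old (sni, common_names) pair;
// None means "no SNI-derived endpoint was available".
let user_info = ComputeUserInfoMaybeEndpoint::parse(&ctx, &options, Some("foo".into()))?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
```
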

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use bytes::Buf;
use pq_proto::{
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
@@ -15,6 +17,7 @@ use crate::{
metrics::Metrics,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
EndpointId,
};
#[derive(Error, Debug)]
@@ -31,6 +34,10 @@ pub enum HandshakeError {
#[error("{0}")]
StreamUpgradeError(#[from] StreamUpgradeError),
#[cfg(all(target_os = "linux", not(test)))]
#[error("{0}")]
KtlsUpgradeError(#[from] ktls::Error),
#[error("{0}")]
Io(#[from] std::io::Error),
@@ -43,6 +50,8 @@ impl ReportableError for HandshakeError {
match self {
HandshakeError::EarlyData => crate::error::ErrorKind::User,
HandshakeError::ProtocolViolation => crate::error::ErrorKind::User,
#[cfg(all(target_os = "linux", not(test)))]
HandshakeError::KtlsUpgradeError(_) => crate::error::ErrorKind::Service,
// This error should not happen, but will if we have no default certificate and
// the client sends no SNI extension.
// If they provide SNI then we can be sure there is a certificate that matches.
@@ -57,22 +66,39 @@ impl ReportableError for HandshakeError {
}
}
pub enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
pub enum HandshakeData<S: AsRawFd> {
Startup(
PqStream<Stream<S>>,
Option<EndpointId>,
StartupMessageParams,
),
Cancel(CancelKeyData),
}
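
Note: `HandshakeData::Startup` now also carries the SNI-derived endpoint, so callers destructure it instead of re-deriving it from the TLS stream later. A consumption sketch (`run_session` and `handle_cancel` are hypothetical stand-ins for what `proxy::handle_client` does with the fields):

```rust
match tokio::time::timeout(handshake_timeout, do_handshake).await?? {
    HandshakeData::Startup(stream, endpoint, params) => {
        // `endpoint` is Some(_) only when the client presented a matching SNI.
        if let Some(ep) = &endpoint {
            tracing::info!(%ep, "endpoint resolved during the TLS handshake");
        }
        run_session(stream, endpoint, params).await?;
    }
    HandshakeData::Cancel(cancel_key_data) => handle_cancel(cancel_key_data).await?,
}
```
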
#[cfg(any(not(target_os = "linux"), test))]
pub trait KtlsAsyncReadReady {}
#[cfg(all(target_os = "linux", not(test)))]
pub trait KtlsAsyncReadReady: ktls::AsyncReadReady {}
#[cfg(any(not(target_os = "linux"), test))]
impl<K: AsyncRead> KtlsAsyncReadReady for K {}
#[cfg(all(target_os = "linux", not(test)))]
impl<K: ktls::AsyncReadReady> KtlsAsyncReadReady for K {}
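
Note: the two `cfg` halves turn `KtlsAsyncReadReady` into a portable trait alias — on Linux it forwards the real `ktls::AsyncReadReady` requirement, while elsewhere (and under `cfg(test)`) any `AsyncRead` type satisfies it, so one set of generic bounds compiles everywhere. The same trick, self-contained:

```rust
use tokio::io::AsyncRead;

// On platforms with the real requirement, alias it...
#[cfg(all(target_os = "linux", not(test)))]
pub trait ReadReady: ktls::AsyncReadReady {}
#[cfg(all(target_os = "linux", not(test)))]
impl<T: ktls::AsyncReadReady> ReadReady for T {}

// ...elsewhere make it vacuous so the same bounds still compile.
#[cfg(any(not(target_os = "linux"), test))]
pub trait ReadReady {}
#[cfg(any(not(target_os = "linux"), test))]
impl<T: AsyncRead> ReadReady for T {}

// Callers write one bound and never mention `cfg` again.
fn requires_ready<S: AsyncRead + ReadReady>(_s: &S) {}
```
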
/// Establish a (most probably secure) connection with the client.
/// For a better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with an owned `stream` here, as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
pub async fn handshake<S>(
ctx: &RequestMonitoring,
stream: S,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
) -> Result<HandshakeData<S>, HandshakeError>
where
S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady,
{
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -80,6 +106,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
let mut ep = None;
loop {
let msg = stream.read_startup_packet().await?;
match msg {
@@ -113,6 +140,9 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
));
};
#[cfg(all(target_os = "linux", not(test)))]
let raw = ktls::CorkStream::new(raw);
let mut read_buf = read_buf.reader();
let mut res = Ok(());
let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config())
@@ -145,11 +175,11 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let conn_info = tls_stream.get_ref().1;
// try parse endpoint
let ep = conn_info
ep = conn_info
.server_name()
.and_then(|sni| endpoint_sni(sni, &tls.common_names).ok().flatten());
if let Some(ep) = ep {
ctx.set_endpoint_id(ep);
if let Some(ep) = &ep {
ctx.set_endpoint_id(ep.clone());
}
// check the ALPN, if exists, as required.
@@ -170,7 +200,10 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
#[cfg(any(not(target_os = "linux"), test))]
tls: Box::pin(tls_stream),
#[cfg(all(target_os = "linux", not(test)))]
tls: ktls::config_ktls_server(tls_stream).await?,
tls_server_end_point,
},
read_buf,
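
Note: after rustls completes the handshake over a `ktls::CorkStream`, `config_ktls_server` pushes the negotiated secrets into the kernel and returns a stream whose I/O is offloaded. A minimal accept-path sketch under the same `cfg` split (assumes the `ktls` crate APIs used above):

```rust
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::TlsAcceptor;

#[cfg(target_os = "linux")]
async fn accept_ktls(
    sock: TcpStream,
    cfg: Arc<rustls::ServerConfig>,
) -> anyhow::Result<ktls::KtlsStream<TcpStream>> {
    // CorkStream tracks TLS record boundaries so the kernel takeover
    // starts at a clean point in the byte stream.
    let corked = ktls::CorkStream::new(sock);
    let tls = TlsAcceptor::from(cfg).accept(corked).await?;
    // Push keys into the kernel; subsequent I/O bypasses rustls entirely.
    Ok(ktls::config_ktls_server(tls).await?)
}
```
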
@@ -207,7 +240,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, ep, params));
}
// downgrade protocol version
FeStartupPacket::StartupMessage { params, version }
@@ -238,7 +271,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
session_type = "normal",
"successful handshake; unsupported minor version requested"
);
break Ok(HandshakeData::Startup(stream, params));
break Ok(HandshakeData::Startup(stream, ep, params));
}
FeStartupPacket::StartupMessage { version, .. } => {
warn!(

View File

@@ -1,3 +1,5 @@
use std::os::fd::AsRawFd;
use crate::{
cancellation,
compute::PostgresConnection,
@@ -10,7 +12,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use super::{copy_bidirectional::ErrorSource, handshake::KtlsAsyncReadReady};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
@@ -57,7 +59,7 @@ pub async fn proxy_pass(
Ok(())
}
pub struct ProxyPassthrough<P, S> {
pub struct ProxyPassthrough<P, S: AsRawFd> {
pub client: Stream<S>,
pub compute: PostgresConnection,
pub aux: MetricsAuxInfo,
@@ -67,7 +69,7 @@ pub struct ProxyPassthrough<P, S> {
pub cancel: cancellation::Session<P>,
}
impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
impl<P, S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> ProxyPassthrough<P, S> {
pub async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {

View File

@@ -2,6 +2,8 @@
mod mitm;
use std::pin::Pin;
use std::task::Poll;
use std::time::Duration;
use super::connect_compute::ConnectMechanism;
@@ -16,12 +18,14 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::stream::Stream;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
@@ -35,28 +39,73 @@ fn generate_certs(
pki_types::CertificateDer<'static>,
pki_types::PrivateKeyDer<'static>,
)> {
let ca = rcgen::Certificate::from_params({
let ca_key = rcgen::KeyPair::generate()?;
let cert_key = rcgen::KeyPair::generate()?;
let ca = {
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
params
})?;
params.self_signed(&ca_key)?
};
let cert = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::new(vec![hostname.into()]);
let cert = {
let mut params = rcgen::CertificateParams::new(vec![hostname.into()])?;
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, common_name);
params
})?;
params.signed_by(&cert_key, &ca, &ca_key)?
};
Ok((
pki_types::CertificateDer::from(ca.serialize_der()?),
pki_types::CertificateDer::from(cert.serialize_der_with_signer(&ca)?),
pki_types::PrivateKeyDer::Pkcs8(cert.serialize_private_key_der().into()),
ca.into(),
cert.into(),
pki_types::PrivateKeyDer::Pkcs8(cert_key.serialize_der().into()),
))
}
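
Note: the helper is ported to the newer rcgen API, where key pairs are generated separately and certificates come from `self_signed`/`signed_by` rather than `Certificate::from_params`. The new shape as a standalone sketch (rcgen 0.13-style, as assumed by this diff):

```rust
fn make_chain(host: &str) -> anyhow::Result<(rcgen::Certificate, rcgen::Certificate)> {
    let ca_key = rcgen::KeyPair::generate()?;
    let leaf_key = rcgen::KeyPair::generate()?;

    // Self-signed CA.
    let mut ca_params = rcgen::CertificateParams::default();
    ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
    let ca = ca_params.self_signed(&ca_key)?;

    // Leaf certificate for `host`, signed by the CA above.
    let leaf_params = rcgen::CertificateParams::new(vec![host.into()])?;
    let leaf = leaf_params.signed_by(&leaf_key, &ca, &ca_key)?;
    Ok((ca, leaf))
}
```
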
pub struct DummyClient(pub DuplexStream);
impl AsRawFd for DummyClient {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
unreachable!()
}
}
impl AsyncWrite for DummyClient {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
impl AsyncRead for DummyClient {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
struct ClientConfig<'a> {
config: rustls::ClientConfig,
hostname: &'a str,
@@ -121,7 +170,9 @@ fn generate_tls_config<'a>(
#[async_trait]
trait TestAuth: Sized {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
async fn authenticate<
S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady,
>(
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
@@ -150,7 +201,9 @@ impl Scram {
#[async_trait]
impl TestAuth for Scram {
async fn authenticate<S: AsyncRead + AsyncWrite + Unpin + Send>(
async fn authenticate<
S: AsyncRead + AsyncWrite + Unpin + Send + AsRawFd + KtlsAsyncReadReady,
>(
self,
stream: &mut PqStream<Stream<S>>,
) -> anyhow::Result<()> {
@@ -170,14 +223,14 @@ impl TestAuth for Scram {
/// A dummy proxy impl which performs a handshake and reports auth success.
async fn dummy_proxy(
client: impl AsyncRead + AsyncWrite + Unpin + Send,
client: impl AsyncRead + AsyncWrite + Unpin + Send + AsRawFd,
tls: Option<TlsConfig>,
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let mut stream =
match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Startup(stream, ..) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
@@ -196,7 +249,11 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
NoAuth,
));
let client_err = tokio_postgres::Config::new()
.user("john_doe")
@@ -225,7 +282,11 @@ async fn handshake_tls() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
NoAuth,
));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
@@ -241,7 +302,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
async fn handshake_raw() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let proxy = tokio::spawn(dummy_proxy(DummyClient(client), None, NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
@@ -285,7 +346,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new(password).await?,
));
@@ -309,7 +370,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));
@@ -332,7 +393,11 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
let proxy = tokio::spawn(dummy_proxy(
DummyClient(client),
Some(server_config),
Scram::mock(),
));
use rand::{distributions::Alphanumeric, Rng};
let password: String = rand::thread_rng()

View File

@@ -36,14 +36,14 @@ async fn proxy_mitm(
let end_server = connect_tls(server2, client_config2.make_tls_connect().unwrap()).await;
let (end_client, startup) = match handshake(
&RequestMonitoring::test(),
client1,
DummyClient(client1),
Some(&server_config1),
false,
)
.await
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Startup(stream, _ep, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};
@@ -154,7 +154,7 @@ impl Encoder<Bytes> for PgFrame {
async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(Intercept::None).await;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));
@@ -237,7 +237,7 @@ async fn connect_failure(
) -> anyhow::Result<()> {
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
let proxy = tokio::spawn(dummy_proxy(
client,
DummyClient(client),
Some(server_config),
Scram::new("password").await?,
));

View File

@@ -190,7 +190,19 @@ trait MaybeTlsAcceptor: Send + Sync + 'static {
#[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?))
#[cfg(all(target_os = "linux", not(test)))]
let conn = ktls::CorkStream::new(conn);
let tls = TlsAcceptor::from(self).accept(conn).await?;
#[cfg(all(target_os = "linux", not(test)))]
return ktls::config_ktls_server(tls)
.await
.map(|s| Box::pin(s) as _)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
#[cfg(any(not(target_os = "linux"), test))]
Ok(Box::pin(tls))
}
}
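
Note: both `cfg` branches coerce their concrete stream into the boxed `AsyncRW` trait object with `Box::pin(..) as _`, and the kTLS error is adapted into `std::io::Error` so the trait signature stays platform-independent. The error adaptation in isolation (a sketch):

```rust
use std::io;

/// Wrap a library error into io::Error so a cross-platform trait can expose
/// one error type regardless of which cfg branch produced the failure.
fn to_io<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    io::Error::new(io::ErrorKind::Other, e)
}
```
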

View File

@@ -16,6 +16,7 @@ use hyper1::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use std::os::fd::AsRawFd;
use std::{
pin::Pin,
sync::Arc,
@@ -45,6 +46,18 @@ impl<S> WebSocketRw<S> {
}
}
impl<S> AsRawFd for WebSocketRw<S> {
fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
unreachable!("ktls should not need to be used for websocket rw")
}
}
#[cfg(all(target_os = "linux", not(test)))]
impl<S> ktls::AsyncReadReady for WebSocketRw<S> {
fn poll_read_ready(&self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!("ktls should not need to be used for websocket rw")
}
}
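
Note: websocket traffic terminates TLS at the HTTP layer and never takes the kTLS path, so these impls exist only to satisfy the shared generic bounds; reaching either one would be a logic error, hence `unreachable!`. The stub pattern for any such type (a sketch; the same applies to the `poll_read_ready` stub):

```rust
use std::os::fd::{AsRawFd, RawFd};

struct NeverKtls;

impl AsRawFd for NeverKtls {
    fn as_raw_fd(&self) -> RawFd {
        // Satisfies the bound; this path must never execute at runtime.
        unreachable!("this stream type never reaches kernel TLS")
    }
}
```
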
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for WebSocketRw<S> {
fn poll_write(
self: Pin<&mut Self>,

View File

@@ -1,11 +1,13 @@
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use crate::proxy::handshake::KtlsAsyncReadReady;
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use std::os::fd::AsRawFd;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
@@ -172,34 +174,31 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
}
/// Wrapper for upgrading raw streams into secure streams.
pub enum Stream<S> {
pub enum Stream<S: AsRawFd> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { raw: S },
Tls {
/// We box [`TlsStream`] since it can be quite large.
tls: Box<TlsStream<S>>,
#[cfg(any(not(target_os = "linux"), test))]
tls: Pin<Box<TlsStream<S>>>,
#[cfg(all(target_os = "linux", not(test)))]
tls: ktls::KtlsStream<S>,
/// Channel binding parameter
tls_server_end_point: TlsServerEndPoint,
},
}
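
Note: on non-Linux the TLS arm now holds `Pin<Box<TlsStream<S>>>` so it can be polled without requiring `TlsStream<S>: Unpin`, while on Linux it holds the kernel-offloaded `ktls::KtlsStream<S>`. Polling through a pinned box, sketched in isolation (not the proxy's actual poll impl):

```rust
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};

/// Poll a stream stored as Pin<Box<T>>: `as_mut()` re-borrows the pin,
/// so T itself never needs to be Unpin.
fn poll_pinned<T: AsyncRead>(
    tls: &mut Pin<Box<T>>,
    cx: &mut Context<'_>,
    buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
    tls.as_mut().poll_read(cx, buf)
}
```
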
impl<S: Unpin> Unpin for Stream<S> {}
impl<S: Unpin + AsRawFd> Unpin for Stream<S> {}
impl<S> Stream<S> {
impl<S: AsRawFd> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
/// Return SNI hostname when it's available.
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls, .. } => tls.get_ref().1.server_name(),
}
}
pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
@@ -221,7 +220,7 @@ pub enum StreamUpgradeError {
Io(#[from] io::Error),
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(
self,
@@ -242,7 +241,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AsyncRead for Stream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,
@@ -255,7 +254,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
impl<S: AsyncRead + AsyncWrite + Unpin + AsRawFd + KtlsAsyncReadReady> AsyncWrite for Stream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut task::Context<'_>,

View File

@@ -114,16 +114,6 @@ fn check_permission(request: &Request<Body>, tenant_id: Option<TenantId>) -> Res
})
}
/// List all (not deleted) timelines.
async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
let res: Vec<TenantTimelineId> = GlobalTimelines::get_all()
.iter()
.map(|tli| tli.ttid)
.collect();
json_response(StatusCode::OK, res)
}
/// Report info about timeline.
async fn timeline_status_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
let ttid = TenantTimelineId::new(
@@ -572,9 +562,6 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
.post("/v1/tenant/timeline", |r| {
request_span(r, timeline_create_handler)
})
.get("/v1/tenant/timeline", |r| {
request_span(r, timeline_list_handler)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
request_span(r, timeline_status_handler)
})

View File

@@ -18,7 +18,6 @@ import psycopg2
from psycopg2.extras import execute_values
CREATE_TABLE = """
CREATE TYPE arch AS ENUM ('ARM64', 'X64', 'UNKNOWN');
CREATE TABLE IF NOT EXISTS results (
id BIGSERIAL PRIMARY KEY,
parent_suite TEXT NOT NULL,
@@ -29,7 +28,6 @@ CREATE TABLE IF NOT EXISTS results (
stopped_at TIMESTAMPTZ NOT NULL,
duration INT NOT NULL,
flaky BOOLEAN NOT NULL,
arch arch DEFAULT 'X64',
build_type TEXT NOT NULL,
pg_version INT NOT NULL,
run_id BIGINT NOT NULL,
@@ -37,7 +35,7 @@ CREATE TABLE IF NOT EXISTS results (
reference TEXT NOT NULL,
revision CHAR(40) NOT NULL,
raw JSONB COMPRESSION lz4 NOT NULL,
UNIQUE (parent_suite, suite, name, arch, build_type, pg_version, started_at, stopped_at, run_id)
UNIQUE (parent_suite, suite, name, build_type, pg_version, started_at, stopped_at, run_id)
);
"""
@@ -52,7 +50,6 @@ class Row:
stopped_at: datetime
duration: int
flaky: bool
arch: str
build_type: str
pg_version: int
run_id: int
@@ -124,14 +121,6 @@ def ingest_test_result(
raw.pop("labels")
raw.pop("extra")
# All allure parameters are prefixed with "__", see test_runner/fixtures/parametrize.py
parameters = {
p["name"].removeprefix("__"): p["value"]
for p in test["parameters"]
if p["name"].startswith("__")
}
arch = parameters.get("arch", "UNKNOWN").strip("'")
build_type, pg_version, unparametrized_name = parse_test_name(test["name"])
labels = {label["name"]: label["value"] for label in test["labels"]}
row = Row(
@@ -143,7 +132,6 @@ def ingest_test_result(
stopped_at=datetime.fromtimestamp(test["time"]["stop"] / 1000, tz=timezone.utc),
duration=test["time"]["duration"],
flaky=test["flaky"] or test["retriesStatusChange"],
arch=arch,
build_type=build_type,
pg_version=pg_version,
run_id=run_id,

View File

@@ -1,7 +1,7 @@
import random
from dataclasses import dataclass
from functools import total_ordering
from typing import Any, Dict, Type, TypeVar, Union
from typing import Any, Type, TypeVar, Union
T = TypeVar("T", bound="Id")
@@ -147,19 +147,6 @@ class TimelineId(Id):
return self.id.hex()
@dataclass
class TenantTimelineId:
tenant_id: TenantId
timeline_id: TimelineId
@classmethod
def from_json(cls, d: Dict[str, Any]) -> "TenantTimelineId":
return TenantTimelineId(
tenant_id=TenantId(d["tenant_id"]),
timeline_id=TimelineId(d["timeline_id"]),
)
# Workaround for compat with python 3.9, which does not have `typing.Self`
TTenantShardId = TypeVar("TTenantShardId", bound="TenantShardId")

View File

@@ -61,6 +61,8 @@ from fixtures.pageserver.common_types import IndexPartDump, LayerName, parse_lay
from fixtures.pageserver.http import PageserverHttpClient
from fixtures.pageserver.utils import (
wait_for_last_record_lsn,
wait_for_upload,
wait_for_upload_queue_empty,
)
from fixtures.pg_version import PgVersion
from fixtures.port_distributor import PortDistributor
@@ -5345,7 +5347,9 @@ def last_flush_lsn_upload(
for tenant_shard_id, pageserver in shards:
ps_http = pageserver.http_client(auth_token=auth_token)
wait_for_last_record_lsn(ps_http, tenant_shard_id, timeline_id, last_flush_lsn)
ps_http.timeline_checkpoint(tenant_shard_id, timeline_id, wait_until_uploaded=True)
# force a checkpoint to trigger upload
ps_http.timeline_checkpoint(tenant_shard_id, timeline_id)
wait_for_upload(ps_http, tenant_shard_id, timeline_id, last_flush_lsn)
return last_flush_lsn
@@ -5430,5 +5434,9 @@ def generate_uploads_and_deletions(
# ensures that the pageserver is in a fully idle state: there will be no more
# background ingest, no more uploads pending, and therefore no non-determinism
# in subsequent actions like pageserver restarts.
flush_ep_to_pageserver(env, endpoint, tenant_id, timeline_id, pageserver.id)
ps_http.timeline_checkpoint(tenant_id, timeline_id, wait_until_uploaded=True)
final_lsn = flush_ep_to_pageserver(env, endpoint, tenant_id, timeline_id, pageserver.id)
ps_http.timeline_checkpoint(tenant_id, timeline_id)
# Finish uploads
wait_for_upload(ps_http, tenant_id, timeline_id, final_lsn)
# Finish all remote writes (including deletions)
wait_for_upload_queue_empty(ps_http, tenant_id, timeline_id)

View File

@@ -1,7 +1,6 @@
import os
from typing import Any, Dict, Optional
import allure
import pytest
import toml
from _pytest.python import Metafunc
@@ -92,23 +91,3 @@ def pytest_generate_tests(metafunc: Metafunc):
and (platform := os.getenv("PLATFORM")) is not None
):
metafunc.parametrize("platform", [platform.lower()])
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_makereport(*args, **kwargs):
# Add test parameters to the Allure report to distinguish the same tests with different parameters.
# Names have a `__` prefix to avoid conflicts with `pytest.mark.parametrize` parameters
# A mapping between `uname -m` and `RUNNER_ARCH` values.
# `RUNNER_ARCH` environment variable is set on GitHub Runners,
# possible values are X86, X64, ARM, or ARM64.
# See https://docs.github.com/en/actions/learn-github-actions/variables#default-environment-variables
uname_m = {
"aarch64": "ARM64",
"arm64": "ARM64",
"x86_64": "X64",
}.get(os.uname().machine, "UNKNOWN")
arch = os.getenv("RUNNER_ARCH", uname_m)
allure.dynamic.parameter("__arch", arch)
yield

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import pytest
import requests
from fixtures.common_types import Lsn, TenantId, TenantTimelineId, TimelineId
from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.log_helper import log
from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
@@ -144,12 +144,6 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter):
assert isinstance(res_json, dict)
return res_json
def timeline_list(self) -> List[TenantTimelineId]:
res = self.get(f"http://localhost:{self.port}/v1/tenant/timeline")
res.raise_for_status()
resj = res.json()
return [TenantTimelineId.from_json(ttidj) for ttidj in resj]
def timeline_create(
self,
tenant_id: TenantId,

View File

@@ -10,7 +10,7 @@ from fixtures.neon_fixtures import (
tenant_get_shards,
wait_for_last_flush_lsn,
)
from fixtures.pageserver.utils import wait_for_last_record_lsn
from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_for_upload
# neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex
# to ensure we don't do that: this enables running lots of Workloads in parallel safely.
@@ -174,9 +174,8 @@ class Workload:
if upload:
# Wait for written data to be uploaded to S3 (force a checkpoint to trigger upload)
ps_http.timeline_checkpoint(
tenant_shard_id, self.timeline_id, wait_until_uploaded=True
)
ps_http.timeline_checkpoint(tenant_shard_id, self.timeline_id)
wait_for_upload(ps_http, tenant_shard_id, self.timeline_id, last_flush_lsn)
log.info(f"Churn: waiting for remote LSN {last_flush_lsn}")
else:
log.info(f"Churn: not waiting for upload, disk LSN {last_flush_lsn}")

View File

@@ -5,12 +5,8 @@ from typing import Any, Dict, Tuple
import pytest
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
PgBin,
flush_ep_to_pageserver,
)
from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder, PgBin, wait_for_last_flush_lsn
from fixtures.pageserver.utils import wait_for_upload_queue_empty
from fixtures.remote_storage import s3_storage
from fixtures.utils import humantime_to_ms
@@ -66,6 +62,9 @@ def test_download_churn(
run_benchmark(env, pg_bin, record, io_engine, concurrency_per_target, duration)
# see https://github.com/neondatabase/neon/issues/8712
env.stop(immediate=True)
def setup_env(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin):
remote_storage_kind = s3_storage()
@@ -99,9 +98,9 @@ def setup_env(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin):
f"INSERT INTO data SELECT lpad(i::text, {bytes_per_row}, '0') FROM generate_series(1, {int(nrows)}) as i",
options="-c statement_timeout=0",
)
flush_ep_to_pageserver(env, ep, tenant_id, timeline_id)
client.timeline_checkpoint(tenant_id, timeline_id, compact=False, wait_until_uploaded=True)
wait_for_last_flush_lsn(env, ep, tenant_id, timeline_id)
# TODO: this is a bit imprecise; there could be frozen layers being written out that we don't observe here
wait_for_upload_queue_empty(client, tenant_id, timeline_id)
return env

View File

@@ -1,21 +1,20 @@
import time
from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver
from fixtures.neon_fixtures import NeonEnvBuilder
#
# Benchmark searching the layer map, when there are a lot of small layer files.
#
def test_layer_map(neon_env_builder: NeonEnvBuilder, zenbenchmark):
"""Benchmark searching the layer map, when there are a lot of small layer files."""
env = neon_env_builder.init_configs()
env = neon_env_builder.init_start()
n_iters = 10
n_records = 100000
env.start()
# We want to have a lot of layer files to exercise the layer map. Disable
# GC, and make checkpoint_distance very small, so that we get a lot of small layer
# files.
tenant, timeline = env.neon_cli.create_tenant(
tenant, _ = env.neon_cli.create_tenant(
conf={
"gc_period": "0s",
"checkpoint_distance": "16384",
@@ -25,7 +24,8 @@ def test_layer_map(neon_env_builder: NeonEnvBuilder, zenbenchmark):
}
)
endpoint = env.endpoints.create_start("main", tenant_id=tenant)
env.neon_cli.create_timeline("test_layer_map", tenant_id=tenant)
endpoint = env.endpoints.create_start("test_layer_map", tenant_id=tenant)
cur = endpoint.connect().cursor()
cur.execute("create table t(x integer)")
for _ in range(n_iters):
@@ -33,12 +33,9 @@ def test_layer_map(neon_env_builder: NeonEnvBuilder, zenbenchmark):
time.sleep(1)
cur.execute("vacuum t")
with zenbenchmark.record_duration("test_query"):
cur.execute("SELECT count(*) from t")
assert cur.fetchone() == (n_iters * n_records,)
flush_ep_to_pageserver(env, endpoint, tenant, timeline)
env.pageserver.http_client().timeline_checkpoint(
tenant, timeline, compact=False, wait_until_uploaded=True
)
# see https://github.com/neondatabase/neon/issues/8712
env.stop(immediate=True)

View File

@@ -1,4 +1,4 @@
from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver
from fixtures.neon_fixtures import NeonEnvBuilder
def do_combocid_op(neon_env_builder: NeonEnvBuilder, op):
@@ -34,7 +34,7 @@ def do_combocid_op(neon_env_builder: NeonEnvBuilder, op):
# Clear the cache, so that we exercise reconstructing the pages
# from WAL
endpoint.clear_shared_buffers()
cur.execute("SELECT clear_buffer_cache()")
# Check that the cursor opened earlier still works. If the
# combocids are not restored correctly, it won't.
@@ -43,10 +43,6 @@ def do_combocid_op(neon_env_builder: NeonEnvBuilder, op):
assert len(rows) == 500
cur.execute("rollback")
flush_ep_to_pageserver(env, endpoint, env.initial_tenant, env.initial_timeline)
env.pageserver.http_client().timeline_checkpoint(
env.initial_tenant, env.initial_timeline, compact=False, wait_until_uploaded=True
)
def test_combocid_delete(neon_env_builder: NeonEnvBuilder):
@@ -96,7 +92,7 @@ def test_combocid_multi_insert(neon_env_builder: NeonEnvBuilder):
cur.execute("delete from t")
# Clear the cache, so that we exercise reconstructing the pages
# from WAL
endpoint.clear_shared_buffers()
cur.execute("SELECT clear_buffer_cache()")
# Check that the cursor opened earlier still works. If the
# combocids are not restored correctly, it won't.
@@ -106,11 +102,6 @@ def test_combocid_multi_insert(neon_env_builder: NeonEnvBuilder):
cur.execute("rollback")
flush_ep_to_pageserver(env, endpoint, env.initial_tenant, env.initial_timeline)
env.pageserver.http_client().timeline_checkpoint(
env.initial_tenant, env.initial_timeline, compact=False, wait_until_uploaded=True
)
def test_combocid(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()
@@ -146,8 +137,3 @@ def test_combocid(neon_env_builder: NeonEnvBuilder):
assert cur.rowcount == n_records
cur.execute("rollback")
flush_ep_to_pageserver(env, endpoint, env.initial_tenant, env.initial_timeline)
env.pageserver.http_client().timeline_checkpoint(
env.initial_tenant, env.initial_timeline, compact=False, wait_until_uploaded=True
)

View File

@@ -9,17 +9,14 @@ from typing import List, Optional
import pytest
import toml
from fixtures.common_types import TenantId, TimelineId
from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
PgBin,
flush_ep_to_pageserver,
)
from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder, PgBin
from fixtures.pageserver.http import PageserverApiException
from fixtures.pageserver.utils import (
timeline_delete_wait_completed,
wait_for_last_record_lsn,
wait_for_upload,
)
from fixtures.pg_version import PgVersion
from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage
@@ -125,9 +122,11 @@ def test_create_snapshot(
timeline_id = dict(snapshot_config["branch_name_mappings"]["main"])[tenant_id]
pageserver_http = env.pageserver.http_client()
lsn = Lsn(endpoint.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0])
flush_ep_to_pageserver(env, endpoint, tenant_id, timeline_id)
pageserver_http.timeline_checkpoint(tenant_id, timeline_id, wait_until_uploaded=True)
wait_for_last_record_lsn(pageserver_http, tenant_id, timeline_id, lsn)
pageserver_http.timeline_checkpoint(tenant_id, timeline_id)
wait_for_upload(pageserver_http, tenant_id, timeline_id, lsn)
env.endpoints.stop_all()
for sk in env.safekeepers:
@@ -301,7 +300,7 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r
pg_version = env.pg_version
# Stop endpoint while we recreate timeline
flush_ep_to_pageserver(env, ep, tenant_id, timeline_id)
ep.stop()
try:
pageserver_http.timeline_preserve_initdb_archive(tenant_id, timeline_id)
@@ -349,11 +348,6 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r
assert not dump_from_wal_differs, "dump from WAL differs"
assert not initial_dump_differs, "initial dump differs"
flush_ep_to_pageserver(env, ep, tenant_id, timeline_id)
pageserver_http.timeline_checkpoint(
tenant_id, timeline_id, compact=False, wait_until_uploaded=True
)
def dump_differs(
first: Path, second: Path, output: Path, allowed_diffs: Optional[List[str]] = None

View File

@@ -18,6 +18,7 @@ from fixtures.neon_fixtures import (
from fixtures.pageserver.utils import (
timeline_delete_wait_completed,
wait_for_last_record_lsn,
wait_for_upload,
)
from fixtures.remote_storage import RemoteStorageKind
from fixtures.utils import assert_pageserver_backups_equal, subprocess_capture
@@ -143,7 +144,7 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build
# Wait for data to land in s3
wait_for_last_record_lsn(client, tenant, timeline, Lsn(end_lsn))
client.timeline_checkpoint(tenant, timeline, compact=False, wait_until_uploaded=True)
wait_for_upload(client, tenant, timeline, Lsn(end_lsn))
# Check it worked
endpoint = env.endpoints.create_start(branch_name, tenant_id=tenant)
@@ -289,7 +290,7 @@ def _import(
# Wait for data to land in s3
wait_for_last_record_lsn(client, tenant, timeline, lsn)
client.timeline_checkpoint(tenant, timeline, compact=False, wait_until_uploaded=True)
wait_for_upload(client, tenant, timeline, lsn)
# Check it worked
endpoint = env.endpoints.create_start(branch_name, tenant_id=tenant, lsn=lsn)
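Several hunks above toggle between two ways of making sure an endpoint's writes have reached the pageserver and been uploaded to remote storage before the test moves on. A minimal sketch of both styles, assuming the fixture helpers behave the way their call sites in this diff suggest (an illustration, not the canonical fixture API):

from fixtures.common_types import Lsn
from fixtures.neon_fixtures import NeonEnv, flush_ep_to_pageserver
from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_for_upload


def sync_to_remote_storage(env: NeonEnv, endpoint, tenant_id, timeline_id, explicit_lsn=False):
    http = env.pageserver.http_client()
    if explicit_lsn:
        # Style 1: note the current flush LSN, wait for the pageserver to ingest up to it,
        # checkpoint, then wait until that LSN has been uploaded to remote storage.
        lsn = Lsn(endpoint.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0])
        wait_for_last_record_lsn(http, tenant_id, timeline_id, lsn)
        http.timeline_checkpoint(tenant_id, timeline_id)
        wait_for_upload(http, tenant_id, timeline_id, lsn)
    else:
        # Style 2: flush everything the endpoint has produced, then checkpoint and block
        # until the resulting index is uploaded.
        flush_ep_to_pageserver(env, endpoint, tenant_id, timeline_id)
        http.timeline_checkpoint(tenant_id, timeline_id, compact=False, wait_until_uploaded=True)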

View File

@@ -1,31 +1,27 @@
import os
import time
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnvBuilder,
NeonEnv,
logical_replication_sync,
wait_for_last_flush_lsn,
)
from fixtures.pg_version import PgVersion
def test_layer_bloating(neon_env_builder: NeonEnvBuilder, vanilla_pg):
if neon_env_builder.pg_version != PgVersion.V16:
def test_layer_bloating(neon_simple_env: NeonEnv, vanilla_pg):
env = neon_simple_env
if env.pg_version != PgVersion.V16:
pytest.skip("pg_log_standby_snapshot() function is available only in PG16")
env = neon_env_builder.init_start(
initial_tenant_conf={
"gc_period": "0s",
"compaction_period": "0s",
"compaction_threshold": 99999,
"image_creation_threshold": 99999,
}
timeline = env.neon_cli.create_branch("test_logical_replication", "empty")
endpoint = env.endpoints.create_start(
"test_logical_replication", config_lines=["log_statement=all"]
)
timeline = env.initial_timeline
endpoint = env.endpoints.create_start("main", config_lines=["log_statement=all"])
pg_conn = endpoint.connect()
cur = pg_conn.cursor()
@@ -58,7 +54,7 @@ def test_layer_bloating(neon_env_builder: NeonEnvBuilder, vanilla_pg):
# Wait logical replication to sync
logical_replication_sync(vanilla_pg, endpoint)
wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, timeline)
env.pageserver.http_client().timeline_checkpoint(env.initial_tenant, timeline, compact=False)
time.sleep(10)
# Check layer file sizes
timeline_path = f"{env.pageserver.workdir}/tenants/{env.initial_tenant}/timelines/{timeline}/"
@@ -67,5 +63,3 @@ def test_layer_bloating(neon_env_builder: NeonEnvBuilder, vanilla_pg):
if filename.startswith("00000"):
log.info(f"layer {filename} size is {os.path.getsize(timeline_path + filename)}")
assert os.path.getsize(timeline_path + filename) < 512_000_000
env.stop(immediate=True)
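The tenant config in the hunk above exists to keep layer files stable while the test measures them: with background jobs switched off, nothing is compacted or garbage-collected behind the test's back. A sketch of that configuration with the knobs annotated (values copied from the hunk; the comments are my reading of the settings, not authoritative documentation):

# neon_env_builder is the pytest fixture used throughout these tests.
NO_BACKGROUND_JOBS = {
    "gc_period": "0s",                  # never run periodic garbage collection
    "compaction_period": "0s",          # never run the compaction loop on a timer
    "compaction_threshold": 99999,      # ...and don't trigger compaction by layer count
    "image_creation_threshold": 99999,  # avoid creating image layers as well
}

env = neon_env_builder.init_start(initial_tenant_conf=NO_BACKGROUND_JOBS)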

View File

@@ -152,9 +152,6 @@ def test_scrubber_physical_gc(neon_env_builder: NeonEnvBuilder, shard_count: Opt
# This write includes remote upload, will generate an index in this generation
workload.write_rows(1)
# We will use a min_age_secs=1 threshold for deletion, let it pass
time.sleep(2)
# With a high min_age, the scrubber should decline to delete anything
gc_summary = env.storage_scrubber.pageserver_physical_gc(min_age_secs=3600)
assert gc_summary["remote_storage_errors"] == 0
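The comments in this hunk describe a simple age threshold: physical GC only considers objects that have been around for at least min_age_secs, so with min_age_secs=3600 an index written moments earlier must survive. A tiny sketch of that rule (the helper below is illustrative, not the scrubber's actual implementation):

import time
from typing import Optional


def old_enough_for_gc(object_mtime: float, min_age_secs: float, now: Optional[float] = None) -> bool:
    # An object becomes a deletion candidate only once its age reaches min_age_secs.
    if now is None:
        now = time.time()
    return (now - object_mtime) >= min_age_secs


# With min_age_secs=3600, an index uploaded two seconds ago is left alone.
assert not old_enough_for_gc(object_mtime=time.time() - 2, min_age_secs=3600)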

View File

@@ -37,7 +37,9 @@ def test_subscriber_restart(neon_simple_env: NeonEnv):
scur.execute("CREATE TABLE t (pk integer primary key, sk integer)")
# scur.execute("CREATE INDEX on t(sk)") # slowdown applying WAL at replica
pub_conn = f"host=localhost port={pub.pg_port} dbname=postgres user=cloud_admin"
query = f"CREATE SUBSCRIPTION sub CONNECTION '{pub_conn}' PUBLICATION pub"
# synchronous_commit=on to test a hypothesis for why this test has been flaky.
# XXX: Add link to the issue
query = f"CREATE SUBSCRIPTION sub CONNECTION '{pub_conn}' PUBLICATION pub with (synchronous_commit=on)"
scur.execute(query)
time.sleep(2) # let initial table sync complete
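For context on the option being toggled here: synchronous_commit is a per-subscription parameter, and its value is used as the local synchronous_commit setting of the subscription's apply worker. A hedged sketch of creating the subscription with the option and flipping it back later without recreating the subscription (the cursor argument and connection string are placeholders):

def create_sub_with_sync_commit(scur, pub_conn: str) -> None:
    # WITH (synchronous_commit = on) makes the apply worker commit synchronously.
    scur.execute(
        f"CREATE SUBSCRIPTION sub CONNECTION '{pub_conn}' "
        "PUBLICATION pub WITH (synchronous_commit = on)"
    )


def revert_sync_commit(scur) -> None:
    # The parameter can be changed in place if the flakiness hypothesis is ruled out.
    scur.execute("ALTER SUBSCRIPTION sub SET (synchronous_commit = off)")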

View File

@@ -757,9 +757,6 @@ def test_lsn_lease_size(neon_env_builder: NeonEnvBuilder, test_output_dir: Path,
assert_size_approx_equal_for_lease_test(lease_res, ro_branch_res)
# we are writing a lot, and flushing all of that to disk is not important for this test
env.stop(immediate=True)
def insert_with_action(
env: NeonEnv,

View File

@@ -254,10 +254,6 @@ def test_many_timelines(neon_env_builder: NeonEnvBuilder):
assert max(init_m[2].flush_lsns) <= min(final_m[2].flush_lsns) < middle_lsn
assert max(init_m[2].commit_lsns) <= min(final_m[2].commit_lsns) < middle_lsn
# Test timeline_list endpoint.
http_cli = env.safekeepers[0].http_client()
assert len(http_cli.timeline_list()) == 3
# Check that dead minority doesn't prevent the commits: execute insert n_inserts
# times, with fault_probability chance of getting a wal acceptor down or up
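The comment above sketches the fault-injection shape of this part of the test: inserts keep flowing while individual safekeepers are randomly stopped and restarted, and a dead minority must not block commits. A generic sketch of that loop (function names are placeholders, not the test's actual helpers):

import random


def insert_with_random_faults(n_inserts, fault_probability, do_insert, toggle_random_safekeeper):
    for _ in range(n_inserts):
        if random.random() < fault_probability:
            # Stop a running safekeeper or restart a stopped one; with 3 safekeepers,
            # a quorum of 2 should keep accepting commits.
            toggle_random_safekeeper()
        do_insert()
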
@@ -1300,8 +1296,6 @@ def test_lagging_sk(neon_env_builder: NeonEnvBuilder):
# Check that WALs are the same.
cmp_sk_wal([sk1, sk2, sk3], tenant_id, timeline_id)
env.stop(immediate=True)
# Smaller version of test_one_sk_down testing peer recovery in isolation: that
# it works without compute at all.

View File

@@ -259,7 +259,7 @@ files:
from
(values ('5m'),('15m'),('1h')) as t (x);
- metric_name: compute_current_lsn
- metric_name: current_lsn
type: gauge
help: 'Current LSN of the database'
key_labels:
@@ -272,19 +272,6 @@ files:
else (pg_current_wal_lsn() - '0/0')::FLOAT8
end as lsn;
- metric_name: compute_receive_lsn
type: gauge
help: 'Returns the last write-ahead log location that has been received and synced to disk by streaming replication'
key_labels:
values: [lsn]
query: |
SELECT
CASE
WHEN pg_catalog.pg_is_in_recovery()
THEN (pg_last_wal_receive_lsn() - '0/0')::FLOAT8
ELSE 0
END AS lsn;
- metric_name: replication_delay_bytes
type: gauge
help: 'Bytes between received and replayed LSN'
@@ -325,22 +312,6 @@ files:
query: |
SELECT checkpoints_timed FROM pg_stat_bgwriter;
- metric_name: compute_logical_snapshot_files
type: gauge
help: 'Number of snapshot files in pg_logical/snapshot'
key_labels:
- tenant_id
- timeline_id
values: [num_logical_snapshot_files]
query: |
SELECT
(SELECT setting FROM pg_settings WHERE name = 'neon.tenant_id') AS tenant_id,
(SELECT setting FROM pg_settings WHERE name = 'neon.timeline_id') AS timeline_id,
-- Postgres creates temporary snapshot files of the form %X-%X.snap.%d.tmp. These
-- temporary snapshot files are renamed to the actual snapshot files after they are
-- completely built. We only WAL-log the completely built snapshot files.
(SELECT COUNT(*) FROM pg_ls_logicalsnapdir() WHERE name LIKE '%.snap') AS num_logical_snapshot_files;
# In all the below metrics, we cast LSNs to floats because Prometheus only supports floats.
# It's probably fine because float64 can store integers from -2^53 to +2^53 exactly.
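The claim in this comment is easy to check: IEEE-754 float64 carries 53 bits of significand, so every integer with magnitude up to 2**53 is representable exactly and precision is only lost beyond that. A quick illustration in Python:

# Integers up to 2**53 round-trip through float64 exactly...
lsn = 2**53 - 1
assert int(float(lsn)) == lsn
# ...but past 2**53 adjacent integers collapse onto the same float64 value.
assert float(2**53) + 1.0 == float(2**53)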

View File

@@ -66,6 +66,8 @@ regex-syntax = { version = "0.8" }
reqwest-5ef9efb8ec2df382 = { package = "reqwest", version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] }
reqwest-a6292c17cd707f01 = { package = "reqwest", version = "0.11", default-features = false, features = ["blocking", "rustls-tls", "stream"] }
rustls = { version = "0.21", features = ["dangerous_configuration"] }
rustls-pki-types = { version = "1", features = ["std"] }
rustls-webpki = { version = "0.102", default-features = false, features = ["aws_lc_rs", "ring", "std"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
@@ -80,6 +82,8 @@ time = { version = "0.3", features = ["macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
tokio-rustls = { version = "0.24" }
tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
toml_datetime = { version = "0.6", default-features = false, features = ["serde"] }
toml_edit = { version = "0.19", features = ["serde"] }
tonic = { version = "0.9", features = ["tls-roots"] }
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
tracing = { version = "0.1", features = ["log"] }
@@ -122,6 +126,7 @@ serde = { version = "1", features = ["alloc", "derive"] }
syn-dff4ba8e3ae991db = { package = "syn", version = "1", features = ["extra-traits", "full", "visit"] }
syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] }
time-macros = { version = "0.2", default-features = false, features = ["formatting", "parsing", "serde"] }
toml_datetime = { version = "0.6", default-features = false, features = ["serde"] }
zstd = { version = "0.13" }
zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }
zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] }